SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
#include <tensor.h>
A tensor index iterator that stays within a specified rectangular region.
The rectangular region is specified by an origin coordinate and a region size. The iterator outputs linear indices in the same index space as the full TensorIndexIterator, but it skips every index that lies outside the region.
Example: consider a 3x3 tensor. The upper right 2x2 region's origin is at location (0,1). We can output just that block like so:
auto it = TensorRegionIndexIterator(tensor->getShape(), {0,1}, {2,2});
while (!it.end()) {
    std::cout << (int)it << "\n";
    ++it;  // advance to the next index inside the region
}
This produces: 1, 2, 4, 5
Public Member Functions
  TensorRegionIndexIterator (const TensorShape &shape, const std::vector<int> &_origin, const std::vector<int> &_regionSize)
Public Member Functions inherited from smaug::TensorIndexIterator
  TensorIndexIterator (const TensorShape &shape, bool _atEnd=false)
  operator int () const
  bool end () const
  void operator++ ()
  void operator+= (const std::vector<int> &region)
  template<typename... Args> int operator() (int i, Args... args)
  bool operator== (const TensorIndexIterator &other) const
  bool operator!= (const TensorIndexIterator &other) const
  int currentIndex (int dim) const
      Returns the current index of the iterator along the specified dimension.
Protected Member Functions
  virtual void advanceRegion (const std::vector<int> &advanceRegionSize)
      Advances the tensor region index by the specified region size.
Protected Member Functions inherited from smaug::TensorIndexIterator
  template<typename Container> int getIndex (Container indices) const
      Returns the linear index into the Tensor's underlying data container at the specified coordinates.
Protected Attributes
  std::vector<int> origin
  std::vector<int> regionSize
Protected Attributes inherited from smaug::TensorIndexIterator
  std::vector<int> state
      The current location of the iterator.
  std::vector<int> dims
      The dimensions of this iterator's Tensor.
  std::vector<int> padding
      Alignment padding of the Tensor.
  bool atEnd
      If true, we've reached the end of the Tensor.
  const std::vector<int> advanceOne
      A vector of all ones, used to implement operator++.