SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
An iterator over a multidimensional tensor's indices, accounting for data alignment padding.
#include <tensor.h>
The iterator tracks the current location as a coordinate and outputs the linearized index so that the data in a tensor can be accessed. While most commonly used to iterate through the contents of a tensor one by one, it can also provide random access to any location in the tensor.
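The page does not spell out the memory layout, so the following is only an illustration. It assumes row-major order in which dimension i occupies dims[i] + padding[i] slots. Under that assumption, the linearized index of a coordinate s = (s_0, ..., s_{n-1}) is given below, followed by a worked 2-by-3 example whose rows are padded to a width of 4.

```latex
\text{index}(s) = \sum_{i=0}^{n-1} s_i \prod_{j=i+1}^{n-1} \left(\text{dims}_j + \text{padding}_j\right)

% Worked example: dims = (2, 3), padding = (0, 1), so the row stride is 3 + 1 = 4.
\text{index}(1, 2) = 1 \cdot 4 + 2 \cdot 1 = 6
```

In this example, slots 3 and 7 are padding. No in-bounds coordinate maps to them, which is one way such an iterator could skip alignment padding.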
Example usage for simple iteration:
    auto iter = TensorIndexIterator(tensor->getShape());
    float* data = tensor->data<float>();
    for (; !iter.end(); ++iter)
        data[iter] = 1.2;
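The following minimal standalone sketch makes the iteration mechanics concrete without SMAUG's headers. SketchIndexIterator and visitedIndices are hypothetical names introduced for illustration. The padded row-major layout and the carry-based increment are assumptions about how such an iterator could work; they are not a description of SMAUG's code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for TensorIndexIterator, written only to illustrate
// the iteration pattern. Member names mirror the documented ones, but the
// bodies are assumptions, not SMAUG's implementation.
class SketchIndexIterator {
  public:
    SketchIndexIterator(std::vector<int> dims, std::vector<int> padding)
            : state(dims.size(), 0), dims(std::move(dims)),
              padding(std::move(padding)), atEnd(false) {
        assert(this->dims.size() == this->padding.size());
    }

    // Linear index of the current coordinate, assuming row-major order with
    // padding[i] unused slots appended to dimension i.
    operator int() const {
        int index = 0, stride = 1;
        for (int i = static_cast<int>(state.size()) - 1; i >= 0; --i) {
            index += state[i] * stride;
            stride *= dims[i] + padding[i];
        }
        return index;
    }

    bool end() const { return atEnd; }

    // Odometer-style increment: bump the innermost coordinate and carry into
    // outer dimensions when one wraps around. Padding slots are never visited
    // because each coordinate only ranges over [0, dims[i]).
    void operator++() {
        for (int i = static_cast<int>(state.size()) - 1; i >= 0; --i) {
            if (++state[i] < dims[i])
                return;
            state[i] = 0;
        }
        atEnd = true;
    }

  private:
    std::vector<int> state, dims, padding;
    bool atEnd;
};

// Collects the linear index of every element visited by a full traversal.
std::vector<int> visitedIndices(const std::vector<int>& dims,
                                const std::vector<int>& padding) {
    std::vector<int> out;
    for (SketchIndexIterator it(dims, padding); !it.end(); ++it)
        out.push_back(static_cast<int>(it));
    return out;
}
```

With dims {2, 3} and padding {0, 1}, visitedIndices returns {0, 1, 2, 4, 5, 6}. Slots 3 and 7 belong to the padding and are never produced, because each coordinate only ranges over its logical extent.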
Example usage for random access (assume a 4D tensor):
    auto iter = TensorIndexIterator(tensor->getShape());
    float* data = tensor->data<float>();
    data[iter(1,2,3,4)] = 1.2;
    data[iter(3,4,0,0)] = 3.4;
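A variadic operator() like the one documented here can be approximated by expanding the parameter pack into a coordinate vector. SketchRandomAccess is a hypothetical class. The outermost-first argument order and the padded row-major layout are both assumptions.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical sketch of a variadic coordinate-to-index operator in the style
// of TensorIndexIterator::operator()(int i, Args... args). It assumes a
// row-major layout with padding[d] extra slots after dimension d.
class SketchRandomAccess {
  public:
    SketchRandomAccess(std::vector<int> dims, std::vector<int> padding)
            : dims(std::move(dims)), padding(std::move(padding)) {
        assert(this->dims.size() == this->padding.size());
    }

    // Packs the coordinates (outermost first) into a vector and linearizes
    // them. Exactly one argument per dimension is expected.
    template <typename... Args>
    int operator()(int i, Args... args) const {
        std::vector<int> coords{ i, static_cast<int>(args)... };
        assert(coords.size() == dims.size());
        int index = 0, stride = 1;
        for (int d = static_cast<int>(coords.size()) - 1; d >= 0; --d) {
            assert(coords[d] >= 0 && coords[d] < dims[d]);
            index += coords[d] * stride;
            stride *= dims[d] + padding[d];
        }
        return index;
    }

  private:
    std::vector<int> dims, padding;
};
```

For example, SketchRandomAccess at({2, 3}, {0, 1}) gives at(1, 2) == 6. The documented signature takes the first coordinate as a separate int i. One effect of this is that a call with no coordinates does not compile.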
The iterator skips over data alignment padding areas, if any exist.
Public Member Functions
    TensorIndexIterator(const TensorShape& shape, bool _atEnd = false)
    operator int() const
    bool end() const
    void operator++()
    void operator+=(const std::vector<int>& region)
    template <typename... Args> int operator()(int i, Args... args)
    bool operator==(const TensorIndexIterator& other) const
    bool operator!=(const TensorIndexIterator& other) const
    int currentIndex(int dim) const
        Returns the current index of the iterator along the specified dimension.
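This page does not say how operator+= interprets region. One rule consistent with operator++ being implemented through advanceOne (a vector of all ones) is an odometer. Under that rule, each dimension is advanced by region[i], and a dimension that overflows resets and carries outward. The sketch below assumes that rule. advanceByRegion and regionWalk are hypothetical helpers and are not SMAUG functions.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of advancing a coordinate by a multidimensional region.
// Assumed rule: add region[i] to the innermost dimension first; a dimension
// that reaches dims[i] resets to 0 and the step moves outward, adding
// region[i - 1] to the next dimension. With region = all ones this reduces to
// an ordinary single-step increment. Returns false once the outermost
// dimension overflows. Every region entry is assumed to be >= 1; under this
// rule a zero entry would make the traversal never terminate.
bool advanceByRegion(std::vector<int>& state, const std::vector<int>& dims,
                     const std::vector<int>& region) {
    assert(state.size() == dims.size() && region.size() == dims.size());
    for (int i = static_cast<int>(state.size()) - 1; i >= 0; --i) {
        state[i] += region[i];
        if (state[i] < dims[i])
            return true;  // no overflow: outer dimensions are untouched
        state[i] = 0;     // overflow: wrap and carry into dimension i - 1
    }
    return false;  // outermost dimension overflowed: traversal is finished
}

// Records every coordinate reached from (0, ..., 0) by repeated region steps.
std::vector<std::vector<int>> regionWalk(const std::vector<int>& dims,
                                         const std::vector<int>& region) {
    std::vector<std::vector<int>> visited;
    std::vector<int> state(dims.size(), 0);
    do {
        visited.push_back(state);
    } while (advanceByRegion(state, dims, region));
    return visited;
}
```

With dims {2, 4} and region {1, 2}, regionWalk visits (0,0), (0,2), (1,0), (1,2), a strided walk over every other column. With region {1, 1}, it visits all six coordinates of a 2-by-3 shape in row-major order.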
Protected Member Functions
    template <typename Container> int getIndex(Container indices) const
        Returns the linear index into the Tensor's underlying data container at the specified coordinates.
    virtual void advanceRegion(const std::vector<int>& region)
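Because getIndex is a template over Container, the coordinates need not arrive as a std::vector. The free function linearIndex below is a hypothetical analogue, not SMAUG's code. The layout it assumes (row-major, with padding[i] extra slots after dimension i) is also an assumption.

```cpp
#include <array>  // lets callers pass fixed-size coordinates
#include <cassert>
#include <vector>

// Hypothetical sketch in the spirit of getIndex(Container indices): any
// container with size() and operator[] can supply the coordinates, ordered
// outermost first.
template <typename Container>
int linearIndex(const Container& indices, const std::vector<int>& dims,
                const std::vector<int>& padding) {
    assert(indices.size() == dims.size() && dims.size() == padding.size());
    int index = 0, stride = 1;
    for (int i = static_cast<int>(indices.size()) - 1; i >= 0; --i) {
        index += indices[i] * stride;      // contribution of dimension i
        stride *= dims[i] + padding[i];    // padded extent of dimension i
    }
    return index;
}
```

For dims {2, 3} and padding {0, 1}, linearIndex(std::vector<int>{1, 2}, dims, padding) returns 6 and linearIndex(std::array<int, 2>{1, 0}, dims, padding) returns 4. The same function body serves both container types.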
Protected Attributes
    std::vector<int> state
        The current location of the iterator.
    std::vector<int> dims
        The dimensions of this iterator's Tensor.
    std::vector<int> padding
        Alignment padding of the Tensor.
    bool atEnd
        If true, we've reached the end of the Tensor.
    const std::vector<int> advanceOne
        A vector of all ones, used to implement operator++.
Friends
    std::ostream& operator<<(std::ostream& os, const TensorIndexIterator& iter)