SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
#include <tensor.h>
An iterator over a multidimensional tensor's indices, accounting for data alignment padding.
The iterator tracks the current location as a coordinate and outputs the linearized index so that the data in a tensor can be accessed. While most commonly used to iterate through the contents of a tensor one by one, it can also provide random access to any location in the tensor.
Example usage for simple iteration:
    auto iter = TensorIndexIterator(tensor->getShape());
Example usage for random access (assume 4D tensor):
    auto iter = TensorIndexIterator(tensor->getShape());
    float* data = tensor->data<float>();
    data[iter(1,2,3,4)] = 1.2;
    data[iter(3,4,0,0)] = 3.4;
The iterator skips over data alignment padding areas, if any exist.
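To illustrate how a linearized index can skip padding, here is a minimal sketch. It assumes a row-major layout in which each dimension's stride is derived from its padded extent (dimension size plus padding). The helpers `paddedLinearIndex` and `validIndices2D` are hypothetical names introduced for this example and are not part of SMAUG.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not SMAUG's API): row-major linear index where the
// stride of each dimension is built from the padded extent dims[i] + padding[i].
int paddedLinearIndex(const std::vector<int>& coords,
                      const std::vector<int>& dims,
                      const std::vector<int>& padding) {
    assert(coords.size() == dims.size() && dims.size() == padding.size());
    int index = 0;
    int stride = 1;
    // Walk from the innermost (last) dimension outward.
    for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
        index += coords[i] * stride;
        stride *= dims[i] + padding[i];
    }
    return index;
}

// Hypothetical helper: collects the linear index of every valid (non-padding)
// element of a 2D tensor whose last dimension carries colPadding elements.
std::vector<int> validIndices2D(int rows, int cols, int colPadding) {
    std::vector<int> out;
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            out.push_back(
                paddedLinearIndex({r, c}, {rows, cols}, {0, colPadding}));
    return out;
}
```

Under these assumptions, a 2x3 tensor with one element of padding per row yields the indices 0, 1, 2, 4, 5, 6: index 3 is the padding slot and is never produced.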
Public Member Functions | |
| TensorIndexIterator (const TensorShape &shape, bool _atEnd=false) | |
| operator int () const | |
| bool | end () const |
| void | operator++ () |
| void | operator+= (const std::vector< int > &region) |
| template<typename... Args> | |
| int | operator() (int i, Args... args) |
| bool | operator== (const TensorIndexIterator &other) const |
| bool | operator!= (const TensorIndexIterator &other) const |
| int | currentIndex (int dim) const |
| Returns the iterator's current index along the specified dimension. | |
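The variadic `operator()` listed above takes one integer per dimension. A minimal sketch of how such an operator might gather its arguments and forward them to a linearizing helper follows. `CoordinateIndexer` is a hypothetical stand-in written for this example, and its stride computation (padded row-major) is an assumption rather than SMAUG's confirmed implementation.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for the random-access part of the iterator.
class CoordinateIndexer {
  public:
    CoordinateIndexer(std::vector<int> dimsIn, std::vector<int> paddingIn)
            : dims(std::move(dimsIn)), padding(std::move(paddingIn)) {
        assert(dims.size() == padding.size());
    }

    // Accepts one int per dimension, e.g. indexer(1, 2, 3, 4) for 4D.
    template <typename... Args>
    int operator()(int i, Args... args) const {
        // Expand the parameter pack into a coordinate vector.
        std::vector<int> coords{ i, static_cast<int>(args)... };
        assert(coords.size() == dims.size());
        return getIndex(coords);
    }

  private:
    // Assumed linearization: row-major, strides use padded extents.
    template <typename Container>
    int getIndex(const Container& indices) const {
        int linearIndex = 0;
        int stride = 1;
        for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
            linearIndex += indices[d] * stride;
            stride *= dims[d] + padding[d];
        }
        return linearIndex;
    }

    std::vector<int> dims;
    std::vector<int> padding;
};
```

Passing the wrong number of coordinates trips the assertion at run time in this sketch. A stricter design could instead reject a mismatched count at compile time by making the rank a template parameter.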
Protected Member Functions | |
| template<typename Container > | |
| int | getIndex (Container indices) const |
| Returns the linear index into the Tensor's underlying data container at the specified coordinates. | |
| virtual void | advanceRegion (const std::vector< int > &region) |
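One plausible way to advance a coordinate by a region is an odometer-style update with carry, where `operator++` is the special case of a region of all ones. The sketch below assumes that behavior. `RegionCursor` is a hypothetical type written for illustration, and the carry rule (wrap to zero, then advance the next outer dimension by its own region size) is an assumption, not a statement of what SMAUG does.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical coordinate cursor (not SMAUG's class).
struct RegionCursor {
    std::vector<int> state;  // current coordinate, one entry per dimension
    std::vector<int> dims;   // extent of each dimension
    bool atEnd = false;      // set once the cursor runs off the tensor

    explicit RegionCursor(std::vector<int> d)
            : state(d.size(), 0), dims(std::move(d)) {}

    // Adds region[i] to state[i], innermost dimension first. A dimension
    // that reaches or passes its extent wraps to 0 and carries into the
    // next outer dimension, which then advances by its own region size.
    void advanceRegion(const std::vector<int>& region) {
        assert(region.size() == state.size());
        bool carry = true;
        for (int i = static_cast<int>(state.size()) - 1; i >= 0 && carry; --i) {
            int next = state[i] + region[i];
            carry = next >= dims[i];
            state[i] = carry ? 0 : next;
        }
        if (carry)
            atEnd = true;  // carried out of the outermost dimension
    }

    void operator+=(const std::vector<int>& region) { advanceRegion(region); }

    // Single-step advance: a region of all ones.
    void operator++() { advanceRegion(std::vector<int>(state.size(), 1)); }
};
```

With dimensions {2, 4} and a region of {1, 2}, this cursor visits (0,0), (0,2), (1,0), (1,2) and then sets `atEnd`, which corresponds to stepping through the origins of 1x2 tiles.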
Protected Attributes | |
| std::vector< int > | state |
| The current location of the iterator. | |
| std::vector< int > | dims |
| The dimensions of this iterator's Tensor. | |
| std::vector< int > | padding |
| Alignment padding of the Tensor. | |
| bool | atEnd |
| If true, we've reached the end of the Tensor. | |
| const std::vector< int > | advanceOne |
| A vector of all ones, used to implement operator++. | |
Friends | |
| std::ostream & | operator<< (std::ostream &os, const TensorIndexIterator &iter) |