SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
smaug::TensorIndexIterator Class Reference

An iterator over a multidimensional tensor's indices, accounting for data alignment padding.

#include <tensor.h>

Inheritance diagram for smaug::TensorIndexIterator:
smaug::TensorRegionIndexIterator (derived class)

Detailed Description

An iterator over a multidimensional tensor's indices, accounting for data alignment padding.

The iterator tracks the current location as a coordinate and outputs the linearized index so that the data in a tensor can be accessed. While most commonly used to iterate through the contents of a tensor one by one, it can also provide random access to any location in the tensor.

Example usage for simple iteration:

    auto iter = TensorIndexIterator(tensor->getShape());
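The simple-iteration example constructs the iterator but does not show the traversal loop. A loop built from the documented members end(), operator++ and the implicit conversion to int could look like the sketch below. It uses a hypothetical SimpleIndexIterator over plain std::vector<int> dimensions, with no SMAUG types and no alignment padding. It borrows the documented member names but is not SMAUG's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for TensorIndexIterator: plain int dimensions, no
// alignment padding. It stores the current coordinate in `state` and turns it
// into a row-major linear index on demand.
class SimpleIndexIterator {
  public:
    explicit SimpleIndexIterator(const std::vector<int>& shape)
            : state(shape.size(), 0), dims(shape), atEnd(false) {
        assert(!shape.empty() && "need at least one dimension");
    }

    // Linear index of the current coordinate; the last dimension varies fastest.
    operator int() const {
        int index = 0;
        for (std::size_t i = 0; i < dims.size(); i++)
            index = index * dims[i] + state[i];
        return index;
    }

    bool end() const { return atEnd; }

    // Odometer increment: bump the innermost dimension, carrying outward on wrap.
    void operator++() {
        for (int i = static_cast<int>(state.size()) - 1; i >= 0; i--) {
            if (++state[i] < dims[i])
                return;
            state[i] = 0;  // this dimension wrapped; carry into the next one out
        }
        atEnd = true;  // carried past the outermost dimension
    }

  private:
    std::vector<int> state;  // current coordinate
    std::vector<int> dims;   // extent of each dimension
    bool atEnd;
};

// Walks every coordinate with the end()/++ loop and records the linear index
// produced at each step.
std::vector<int> visitOrder(const std::vector<int>& shape) {
    std::vector<int> order;
    for (SimpleIndexIterator iter(shape); !iter.end(); ++iter)
        order.push_back(static_cast<int>(iter));
    return order;
}
```

In this sketch, a 2x3 shape yields the linear indices 0 through 5 in order. Because the iterator converts to int, the same loop body can index a raw buffer directly, as in `data[iter]`.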

Example usage for random access (assume a 4D tensor):

    auto iter = TensorIndexIterator(tensor->getShape());
    float* data = tensor->data<float>();
    data[iter(1,2,3,4)] = 1.2;
    data[iter(3,4,0,0)] = 3.4;
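The random-access calls depend on the variadic `template<typename... Args> int operator()(int i, Args... args)` member. One way such an operator can work is to pack its arguments into a coordinate vector and linearize it. The sketch below shows this with a hypothetical RandomAccessIndex class. It assumes row-major order, ignores alignment padding, and is not taken from tensor.h.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of variadic random access in the style of
// `template <typename... Args> int operator()(int i, Args... args)`.
// The arguments are packed into a coordinate vector and linearized row-major.
// Alignment padding is ignored to keep the example short.
class RandomAccessIndex {
  public:
    explicit RandomAccessIndex(const std::vector<int>& shape) : dims(shape) {}

    template <typename... Args>
    int operator()(int i, Args... args) const {
        // Pack expansion: {i, args...}, each trailing argument converted to int.
        std::vector<int> coords{i, static_cast<int>(args)...};
        assert(coords.size() == dims.size() && "one coordinate per dimension");
        int index = 0;
        for (std::size_t d = 0; d < dims.size(); d++) {
            assert(coords[d] >= 0 && coords[d] < dims[d] && "coordinate out of range");
            index = index * dims[d] + coords[d];
        }
        return index;
    }

  private:
    std::vector<int> dims;
};
```

For a 5x5x5x5 shape, this padding-free sketch maps (1,2,3,4) to 1*125 + 2*25 + 3*5 + 4 = 194 and (3,4,0,0) to 475. A write such as `data[iter(1, 2, 3, 4)] = 1.2f;` therefore touches element 194 of the buffer.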

The iterator skips over data alignment padding areas, if any exist.
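Strides make the padding behavior concrete. The sketch below uses hypothetical helpers, not SMAUG's getIndex, and assumes that dimension d occupies dims[d] + padding[d] slots in storage. Strides are built from the innermost dimension outward over those padded extents, while the coordinates only range over the logical dims[d]. As a result, padded slots are never produced as indices.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: linear index of `coords` in a buffer whose dimension d
// is stored with dims[d] + padding[d] slots. Strides grow from the innermost
// (last) dimension outward, so padding on the last dimension makes each row
// longer than its logical width.
int paddedLinearIndex(const std::vector<int>& coords,
                      const std::vector<int>& dims,
                      const std::vector<int>& padding) {
    assert(coords.size() == dims.size() && dims.size() == padding.size());
    int index = 0;
    int stride = 1;
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; d--) {
        index += coords[d] * stride;
        stride *= dims[d] + padding[d];
    }
    return index;
}

// Visits every logical coordinate in row-major order and returns the storage
// indices touched. Padded slots are skipped because coords never reach them.
std::vector<int> logicalIndices(const std::vector<int>& dims,
                                const std::vector<int>& padding) {
    std::vector<int> out;
    std::vector<int> coords(dims.size(), 0);
    while (true) {
        out.push_back(paddedLinearIndex(coords, dims, padding));
        int d = static_cast<int>(dims.size()) - 1;
        // Odometer step: wrap full dimensions to 0 and carry outward.
        while (d >= 0 && ++coords[d] == dims[d]) {
            coords[d] = 0;
            d--;
        }
        if (d < 0)
            break;  // carried out of the outermost dimension
    }
    return out;
}
```

For a 2x3 tensor with one padding element on its last dimension, each stored row is 4 wide. The six logical elements land at storage indices 0, 1, 2, 4, 5 and 6, and index 3 (the padding slot of row 0) is never visited.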

Definition at line 128 of file tensor.h.

Public Member Functions

 TensorIndexIterator (const TensorShape &shape, bool _atEnd=false)
 
 operator int () const
 
bool end () const
 
void operator++ ()
 
void operator+= (const std::vector< int > &region)
 
template<typename... Args>
int operator() (int i, Args... args)
 
bool operator== (const TensorIndexIterator &other) const
 
bool operator!= (const TensorIndexIterator &other) const
 
int currentIndex (int dim) const
 Returns the iterator's current coordinate along dimension dim.
 

Protected Member Functions

template<typename Container >
int getIndex (Container indices) const
 Returns the linear index into the Tensor's underlying data container at the specified coordinates.
 
virtual void advanceRegion (const std::vector< int > &region)
 

Protected Attributes

std::vector< int > state
 The current location of the iterator.
 
std::vector< int > dims
 The dimensions of this iterator's Tensor.
 
std::vector< int > padding
 Alignment padding of the Tensor.
 
bool atEnd
 If true, we've reached the end of the Tensor.
 
const std::vector< int > advanceOne
 A vector of all ones, used to implement operator++.
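The attributes state that advanceOne, a vector of all ones, is used to implement operator++. Under stated assumptions, one way this can fit together is an odometer-style advanceRegion() that adds a per-dimension step to state with carry. With that design, operator++ becomes advanceRegion(advanceOne) and operator+= passes an arbitrary region. The IndexIteratorBase class and its carry rule below are hypothetical illustrations and are not copied from tensor.h.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical base iterator (not copied from tensor.h): all stepping goes
// through one virtual advanceRegion(), so operator++ is just
// advanceRegion(advanceOne) and operator+= forwards an arbitrary region.
class IndexIteratorBase {
  public:
    explicit IndexIteratorBase(const std::vector<int>& shape)
            : state(shape.size(), 0), dims(shape),
              advanceOne(shape.size(), 1), atEnd(false) {}
    virtual ~IndexIteratorBase() = default;

    bool end() const { return atEnd; }
    void operator++() { advanceRegion(advanceOne); }
    void operator+=(const std::vector<int>& region) {
        assert(region.size() == state.size() && "one step per dimension");
        advanceRegion(region);
    }

    // Row-major linear index of the current coordinate (padding omitted).
    operator int() const {
        int index = 0;
        for (std::size_t i = 0; i < dims.size(); i++)
            index = index * dims[i] + state[i];
        return index;
    }

  protected:
    // Adds region[i] to the innermost dimension first. A dimension that
    // overflows resets to 0 and lets the next outer dimension take its step;
    // outer dimensions only move when a carry reaches them.
    virtual void advanceRegion(const std::vector<int>& region) {
        bool carry = true;
        for (int i = static_cast<int>(state.size()) - 1; i >= 0 && carry; i--) {
            int next = state[i] + region[i];
            carry = next >= dims[i];
            state[i] = carry ? 0 : next;
        }
        if (carry)
            atEnd = true;
    }

    std::vector<int> state;             // current coordinate
    std::vector<int> dims;              // extent of each dimension
    const std::vector<int> advanceOne;  // all ones: the step used by operator++
    bool atEnd;                         // set once a carry leaves dimension 0
};

// Records the linear index at each position reached by `+= step`.
std::vector<int> visitWithStep(const std::vector<int>& shape,
                               const std::vector<int>& step) {
    std::vector<int> out;
    for (IndexIteratorBase it(shape); !it.end(); it += step)
        out.push_back(static_cast<int>(it));
    return out;
}

// Same walk, but using operator++ (a step of all ones).
std::vector<int> visitWithIncrement(const std::vector<int>& shape) {
    std::vector<int> out;
    for (IndexIteratorBase it(shape); !it.end(); ++it)
        out.push_back(static_cast<int>(it));
    return out;
}
```

In this sketch, stepping a 4x4 shape by {2, 2} lands on the origins of the four 2x2 tiles (linear indices 0, 2, 8 and 10). Stepping by all ones reproduces plain operator++ traversal. Keeping the stepping rule in one virtual function lets a derived iterator change traversal by overriding a single method.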
 

Friends

std::ostream & operator<< (std::ostream &os, const TensorIndexIterator &iter)
 
