SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
Public Types | Public Member Functions | Protected Member Functions | Protected Attributes | List of all members
smaug::ConvolutionOp< Backend > Class Template Reference

The base class for all 4D spatial convolution operators. More...

#include <convolution_op.h>

Detailed Description

template<typename Backend>
class smaug::ConvolutionOp< Backend >

The base class for all 4D spatial convolution operators.

Provides common functionality for writing convolution operators.

Template Parameters
Backend: The Backend specialization of this Operator.

Definition at line 33 of file backend.h.

Public Types

enum  { Inputs, Kernels, kNumInputs }
 
enum  { Outputs, kNumOutputs }
 

Public Member Functions

 ConvolutionOp (const std::string &name, Workspace *workspace)
 
void setWeightDims (int _weightRows, int _weightCols, int _numOfmaps)
 
void setStride (int _rowStride, int _colStride)
 
void setPadding (PaddingType padding)
 
bool validate () override
 
virtual TensorShape inferOutputShape () const
 
virtual TensorShape inferWeightsShape () const
 
void createWeightsTensors ()
 Create placeholder tensors for weights, assuming any data layout is okay.
 
void createOutputTensors ()
 
void createAllTensors () override
 
int getNumOfmaps () const
 
void run () override
 
int getNumParameters () const override
 
std::vector< TensorBase * > getParameterizableInputs () override
 
int getRowStride () const
 
int getColStride () const
 
int getWeightRows () const
 
int getWeightCols () const
 
PaddingType getPadding () const
 
std::vector< int > getInputPadding () const
 Compute padding sizes on the row/column boundaries of the input feature map.
 
bool isSamplingSupported () const override
 
void setSamplingInfo (const SamplingInfo &_sampling) override
 
void run ()
 

Protected Member Functions

int computeOutputDim (int inputDim, int weightDim, int stride, PaddingType pad) const
 
int computeOutputDim (int inputDim, int weightDim, int stride, int padding) const
 

Protected Attributes

int weightRows
 
int weightCols
 
int numOfmaps
 
int rowStride
 
int colStride
 
PaddingType paddingType
 
std::string weightsName
 
SamplingInfo sampling
 

The documentation for this class was generated from the following files: