SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
pooling_op.h
1 #ifndef _OPERATORS_POOLING_OP_H_
2 #define _OPERATORS_POOLING_OP_H_
3 
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
6 #include "smaug/core/tensor.h"
7 #include "smaug/core/workspace.h"
8 
9 namespace smaug {
10 
21 template <typename Backend>
22 class PoolingOp : public Operator {
23  public:
24     PoolingOp(const std::string& name, OpType _opType, Workspace* workspace)
25             : Operator(name, _opType, workspace), poolingRowSize(0),
26               poolingColSize(0), poolingRowStride(0), poolingColStride(0),
27               sampling({ NoSampling, 1 }) {
28         inputs.resize(kNumInputs, nullptr);
29         outputs.resize(kNumOutputs, nullptr);
30     }
31 
32     std::pair<int, int> getPoolingSize() const {
33         return std::make_pair(poolingRowSize, poolingColSize);
34     }
35     std::pair<int, int> getPoolingStride() const {
36         return std::make_pair(poolingRowStride, poolingColStride);
37     }
38 
39     void setPoolingSize(int rowSize, int colSize) {
40         poolingRowSize = rowSize;
41         poolingColSize = colSize;
42     }
43 
44     void setPoolingStride(int rowStride, int colStride) {
45         poolingRowStride = rowStride;
46         poolingColStride = colStride;
47     }
48 
49     bool validate() override {
50         return (poolingRowSize > 0 && poolingColSize > 0 &&
51                 poolingRowStride > 0 && poolingColStride > 0 && Operator::validate());
52     }
53 
54     int getNumOfmaps() const {
55         Tensor* input = getInput(0);
56         assert(input && "Unable to find input for pooling layer!");
57         const TensorShape& inputShape = inputs.at(Inputs)->getShape();
58         bool isNCHW = inputShape.getLayout() == DataLayout::NCHW;
59         int chanIdx = isNCHW ? 1 : 3;
60         return input->dim(chanIdx);
61     }
62 
63     TensorShape inferOutputShape() const {
64         const TensorShape& inputShape = inputs.at(Inputs)->getShape();
65         bool isNCHW = inputShape.getLayout() == DataLayout::NCHW;
66         int inputRows = isNCHW ? inputShape[2] : inputShape[1];
67         int inputCols = isNCHW ? inputShape[3] : inputShape[2];
68         int inputChans = isNCHW ? inputShape[1] : inputShape[3];
69         int outputRows = calcOutputRows(inputRows);
70         int outputCols = calcOutputCols(inputCols);
71         assert(outputRows > 0 && outputCols > 0 &&
72                "Pooling layer field size exceeds the input image dimensions!");
73         if (isNCHW) {
74             return TensorShape(
75                     { inputShape[0], inputChans, outputRows, outputCols },
76                     inputShape.getLayout(), Backend::Alignment);
77         } else {
78             return TensorShape(
79                     { inputShape[0], outputRows, outputCols, inputChans },
80                     inputShape.getLayout(), Backend::Alignment);
81         }
82     }
83 
84     void createOutputTensors() {
85         if (outputs.at(Outputs))
86             return;
87         TensorShape shape = inferOutputShape();
88         Tensor* output = new Tensor(name, shape);
89         workspace->addTensor(output);
90         outputs.at(Outputs) = output;
91     }
92 
93     void createAllTensors() override { createOutputTensors(); }
94 
95     bool isSamplingSupported() const override { return true; }
96     void setSamplingInfo(const SamplingInfo& _sampling) override {
97         sampling = _sampling;
98     }
99 
100  protected:
101     int calcOutputRows(int inputRows) const {
102         return computeOutputDim(inputRows, poolingRowSize, poolingRowStride);
103     }
104     int calcOutputCols(int inputCols) const {
105         return computeOutputDim(inputCols, poolingColSize, poolingColStride);
106     }
107 
108     int computeOutputDim(int inputDims, int poolSize, int poolStride) const {
109         return (inputDims - poolSize) / poolStride + 1;
110     }
111 
112     enum { Inputs, kNumInputs };
113     enum { Outputs, kNumOutputs };
114 
115     int poolingRowSize;
116     int poolingColSize;
117     int poolingRowStride;
118     int poolingColStride;
119     SamplingInfo sampling;
120 };
121 
126 template <typename Backend>
127 class MaxPoolingOp : public PoolingOp<Backend> {
128  protected:
129     typedef PoolingOp<Backend> Parent;
130 
131  public:
132     MaxPoolingOp(const std::string& name, Workspace* workspace)
133             : PoolingOp<Backend>(name, OpType::MaxPooling, workspace) {}
134     void run() override {}
135 };
136 
141 template <typename Backend>
142 class AvgPoolingOp : public PoolingOp<Backend> {
143  protected:
144     typedef PoolingOp<Backend> Parent;
145 
146  public:
147     AvgPoolingOp(const std::string& name, Workspace* workspace)
148             : PoolingOp<Backend>(name, OpType::AveragePooling, workspace) {}
149     void run() override {}
150 };
151 
152 REGISTER_SPECIAL_OP(MaxPoolingOp, ReferenceBackend);
153 REGISTER_SPECIAL_OP(AvgPoolingOp, ReferenceBackend);
154 
155 } // namespace smaug
156 
157 #endif
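The output spatial dimensions follow directly from computeOutputDim: outputDim = (inputDim - poolSize) / poolStride + 1, so a 32x32 input pooled with a 2x2 window at stride 2 yields a 16x16 output. The sketch below shows how this header might be exercised on its own; it is illustrative only. It assumes the base Operator class exposes a setInput(tensor, index) style setter (not declared in this header), that the input tensor already lives in the Workspace, and it omits the Network registration SMAUG normally performs when building a model. The function name, tensor name, and dimensions are made up for the example.

#include <cassert>

#include "smaug/core/backend.h"
#include "smaug/core/tensor.h"
#include "smaug/core/workspace.h"
#include "smaug/operators/pooling_op.h"

using namespace smaug;

// Illustrative sketch: 2x2 max pooling with stride 2 over a 1x32x32x8 NHWC input.
// Each spatial output dim = (32 - 2) / 2 + 1 = 16, so the output is 1x16x16x8.
void buildMaxPoolSketch(Workspace* workspace, Tensor* input) {
    auto* pool = new MaxPoolingOp<ReferenceBackend>("pool0", workspace);
    pool->setInput(input, 0);      // slot 0 matches the Inputs enum in PoolingOp
    pool->setPoolingSize(2, 2);    // window rows, cols
    pool->setPoolingStride(2, 2);  // row stride, col stride
    pool->createAllTensors();      // infers the output shape and adds the output
                                   // tensor to the workspace
    assert(pool->validate());      // all four pooling parameters must be positive
}

Note that createAllTensors() is a no-op if the output tensor already exists, and validate() also defers to Operator::validate() for the generic input/output checks.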
Referenced symbols:
smaug::Tensor (tensor.h:344): Tensor represents a single multi-dimensional array of data.
_SamplingInfo (common.h:262): Simulation sampling information maintained by the Operator and passed to the accelerated kernel.
smaug::Workspace (workspace.h:17): Workspace is the container and owner of all Tensors and Operators in the Network.
smaug::PoolingOp::createAllTensors (pooling_op.h:93): void createAllTensors() override. For tests: creates all input and output tensors for this operator.
smaug::TensorShape (tensor.h:35): TensorShape describes the shape of a Tensor.
smaug::PoolingOp::validate (pooling_op.h:49): bool validate() override. Returns true if the parameters/tensors of this operator are all valid.
smaug::Operator::outputs (operator.h:141): std::vector<TensorBase*> outputs. An ordered list of output tensors produced by this operator.
smaug::Operator (operator.h:28): Operator is the base class for all graph operators supported by SMAUG.
smaug (backend.cpp:38): The smaug namespace is the parent namespace of all C++ code in SMAUG.
smaug::Operator::validate (operator.h:47): virtual bool validate(). Returns true if the parameters/tensors of this operator are all valid.
smaug::PoolingOp (pooling_op.h:22): Implements a pooling operator.
smaug::Operator::inputs (operator.h:134): std::vector<TensorBase*> inputs. An ordered list of input tensors consumed by this operator.