#ifndef _OPERATORS_POOLING_OP_H_
#define _OPERATORS_POOLING_OP_H_

#include "smaug/core/backend.h"
#include "smaug/core/operator.h"
#include "smaug/core/tensor.h"
#include "smaug/core/workspace.h"

namespace smaug {
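/**
 * Base class for all pooling operators. It stores the pooling window size and
 * stride, validates them, infers the output shape for both NCHW and NHWC
 * layouts, and creates the output tensor; the max and average pooling
 * subclasses below supply the actual computation.
 *
 * @tparam Backend The backend that determines tensor alignment requirements.
 */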
template <typename Backend>
class PoolingOp : public Operator {
   public:
    PoolingOp(const std::string& name, OpType _opType, Workspace* workspace)
            : Operator(name, _opType, workspace), poolingRowSize(0),
              poolingColSize(0), poolingRowStride(0), poolingColStride(0),
              sampling({ NoSampling, 1 }) {
        inputs.resize(kNumInputs, nullptr);
        outputs.resize(kNumOutputs, nullptr);
    }

    std::pair<int, int> getPoolingSize() const {
        return std::make_pair(poolingRowSize, poolingColSize);
    }

    std::pair<int, int> getPoolingStride() const {
        return std::make_pair(poolingRowStride, poolingColStride);
    }

    void setPoolingSize(int rowSize, int colSize) {
        poolingRowSize = rowSize;
        poolingColSize = colSize;
    }

    void setPoolingStride(int rowStride, int colStride) {
        poolingRowStride = rowStride;
        poolingColStride = colStride;
    }

    bool validate() override {
        return (poolingRowSize > 0 && poolingColSize > 0 &&
                poolingRowStride > 0 && poolingColStride > 0 &&
                Operator::validate());
    }

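    /** Returns the number of output feature maps (the channel dimension),
     * which pooling leaves unchanged. */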
    int getNumOfmaps() const {
        Tensor* input = getInput(0);
        assert(input && "Unable to find input for pooling layer!");
        const TensorShape& inputShape = input->getShape();
        bool isNCHW = inputShape.getLayout() == DataLayout::NCHW;
        int chanIdx = isNCHW ? 1 : 3;
        return input->dim(chanIdx);
    }

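    /** Computes the output shape from the input shape, the pooling window
     * size, and the stride, preserving the input's NCHW or NHWC layout. */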
    TensorShape inferOutputShape() const {
        const TensorShape& inputShape = inputs.at(Inputs)->getShape();
        bool isNCHW = inputShape.getLayout() == DataLayout::NCHW;
        int inputRows = isNCHW ? inputShape[2] : inputShape[1];
        int inputCols = isNCHW ? inputShape[3] : inputShape[2];
        int inputChans = isNCHW ? inputShape[1] : inputShape[3];
        int outputRows = calcOutputRows(inputRows);
        int outputCols = calcOutputCols(inputCols);
        assert(outputRows > 0 && outputCols > 0 &&
               "Pooling layer field size exceeds the input image dimensions!");
        if (isNCHW) {
            return TensorShape(
                    { inputShape[0], inputChans, outputRows, outputCols },
                    inputShape.getLayout(), Backend::Alignment);
        } else {
            return TensorShape(
                    { inputShape[0], outputRows, outputCols, inputChans },
                    inputShape.getLayout(), Backend::Alignment);
        }
    }

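    /** Allocates the output tensor in the Workspace if it does not already
     * exist. */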
    void createOutputTensors() {
        if (outputs.at(Outputs))
            return;
        TensorShape shape = inferOutputShape();
        Tensor* output = new Tensor(name, shape);
        workspace->addTensor(output);
        outputs.at(Outputs) = output;
    }

    void createAllTensors() override { createOutputTensors(); }

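    /** Pooling supports sampled execution; the provided SamplingInfo is
     * stored for backend implementations to consult. */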
    bool isSamplingSupported() const override { return true; }

    void setSamplingInfo(const SamplingInfo& _sampling) override {
        sampling = _sampling;
    }

   protected:
    int calcOutputRows(int inputRows) const {
        return computeOutputDim(inputRows, poolingRowSize, poolingRowStride);
    }

    int calcOutputCols(int inputCols) const {
        return computeOutputDim(inputCols, poolingColSize, poolingColStride);
    }

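    /** Computes one output dimension for a valid (unpadded) pooling window:
     * outputDim = (inputDim - poolSize) / poolStride + 1. For example, an
     * 8-wide input with a 2-wide window and a stride of 2 yields
     * (8 - 2) / 2 + 1 = 4 output columns. */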
    int computeOutputDim(int inputDims, int poolSize, int poolStride) const {
        return (inputDims - poolSize) / poolStride + 1;
    }

    enum { Inputs, kNumInputs };
    enum { Outputs, kNumOutputs };

    int poolingRowSize;
    int poolingColSize;
    int poolingRowStride;
    int poolingColStride;
    SamplingInfo sampling;
};

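/**
 * Performs max pooling over each pooling window. The generic run() below is a
 * placeholder; backend-specific implementations are hooked in through the
 * REGISTER_SPECIAL_OP declarations at the bottom of this file.
 *
 * A minimal usage sketch (hypothetical operator name, assuming a Workspace*
 * named workspace is already set up):
 *
 *     auto* pool = new MaxPoolingOp<ReferenceBackend>("pool0", workspace);
 *     pool->setPoolingSize(2, 2);
 *     pool->setPoolingStride(2, 2);
 */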
template <typename Backend>
class MaxPoolingOp : public PoolingOp<Backend> {
   protected:
    typedef PoolingOp<Backend> Parent;

   public:
    MaxPoolingOp(const std::string& name, Workspace* workspace)
            : PoolingOp<Backend>(name, OpType::MaxPooling, workspace) {}
    void run() override {}
};

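/**
 * Performs average pooling over each pooling window. As with MaxPoolingOp,
 * the generic run() is a placeholder; see the REGISTER_SPECIAL_OP
 * declarations below.
 */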
template <typename Backend>
class AvgPoolingOp : public PoolingOp<Backend> {
   protected:
    typedef PoolingOp<Backend> Parent;

   public:
    AvgPoolingOp(const std::string& name, Workspace* workspace)
            : PoolingOp<Backend>(name, OpType::AveragePooling, workspace) {}
    void run() override {}
};

REGISTER_SPECIAL_OP(MaxPoolingOp, ReferenceBackend);
REGISTER_SPECIAL_OP(AvgPoolingOp, ReferenceBackend);

}  // namespace smaug

#endif