1 #ifndef _OPERATORS_CONVOLUTION_OP_H_
2 #define _OPERATORS_CONVOLUTION_OP_H_
6 #include "smaug/core/backend.h"
7 #include "smaug/core/operator.h"
8 #include "smaug/core/workspace.h"
9 #include "smaug/core/types.pb.h"
11 #include "smaug/operators/fused_activation_op.h"
// Convolution operator, parameterized on a Backend type and derived from
// FusedActivationOp. Within the visible code, Backend is only used to supply
// Backend::Alignment when tensor shapes are built.
// NOTE(review): this listing omits some lines of the class; comments on the
// members describe only what is visible.
23 template <
typename Backend>
24 class ConvolutionOp :
public FusedActivationOp {
// Builds a convolution operator named `name` inside `workspace`, passing both
// to FusedActivationOp along with OpType::Convolution3d.
// Weight dimensions, feature-map count, and strides start at 0 and the padding
// type at UnknownPadding, so the setters must be called before validate() can
// succeed (it requires each of those counts to be > 0).
// The kernel tensor name defaults to "<name>/kernels" and sampling to
// { NoSampling, 1 }.
26 ConvolutionOp(
const std::string& name, Workspace* workspace)
27 : FusedActivationOp(name, OpType::Convolution3d, workspace),
28 weightRows(0), weightCols(0), numOfmaps(0), rowStride(0),
29 colStride(0), paddingType(UnknownPadding),
30 weightsName(name +
"/kernels"), sampling({ NoSampling, 1 }) {
// Size the input/output slot vectors to kNumInputs / kNumOutputs, filling any
// newly created slots with nullptr until tensors are attached.
31 inputs.resize(kNumInputs,
nullptr);
32 outputs.resize(kNumOutputs,
nullptr);
// Sets the kernel height (_weightRows), kernel width (_weightCols), and the
// number of kernels (_numOfmaps). numOfmaps becomes the feature-map dimension
// of the inferred output shape and the leading dimension of the weights shape.
// No range checking happens here; validate() checks that all three are > 0.
35 void setWeightDims(
int _weightRows,
int _weightCols,
int _numOfmaps) {
36 weightRows = _weightRows;
37 weightCols = _weightCols;
38 numOfmaps = _numOfmaps;
// Sets the stride along the row and column dimensions. These values are the
// divisors used when the output rows/cols are computed, and validate()
// requires both to be > 0.
41 void setStride(
int _rowStride,
int _colStride) {
42 rowStride = _rowStride;
43 colStride = _colStride;
// Stores the padding mode. Elsewhere in this class SamePadding translates to
// a total padding of (weight dimension - 1); every other value is treated as
// zero padding.
46 void setPadding(PaddingType padding) {
47 paddingType = padding;
// Checks that the operator has been configured: the kernel dimensions, the
// number of feature maps, and both strides must all be strictly positive.
// NOTE(review): the expression continues past the trailing `&&` onto lines
// that are not visible in this listing, so additional conditions apply.
50 bool validate()
override {
51 return (weightRows > 0 && weightCols > 0 && numOfmaps > 0 &&
52 rowStride > 0 && colStride > 0 &&
// Derives the output tensor shape from the input tensor's layout and spatial
// dimensions together with the configured kernel size, stride, and padding.
// Both visible return statements produce a shape whose leading dimension is 1.
56 virtual TensorShape inferOutputShape()
const {
57 Tensor* input = getInput(Inputs);
58 assert(input &&
"Unable to get input for convolution op!");
59 DataLayout layout = input->getShape().getLayout();
// Rows/cols are read from dims 2/3 for NCHW and from dims 1/2 for any other
// layout (presumably NHWC; confirm).
60 bool isNCHW = (layout == DataLayout::NCHW);
61 int rowIdx = isNCHW ? 2 : 1;
62 int colIdx = isNCHW ? 3 : 2;
63 int outputRows = computeOutputDim(
64 input->dim(rowIdx), weightRows, rowStride, paddingType);
65 int outputCols = computeOutputDim(
66 input->dim(colIdx), weightCols, colStride, paddingType);
// Two alternative results: feature maps in dim 1, or feature maps in dim 3.
// NOTE(review): the condition selecting between them and the remaining
// constructor arguments are not visible here; presumably keyed on isNCHW.
68 return TensorShape({ 1, numOfmaps, outputRows, outputCols },
72 return TensorShape({ 1, outputRows, outputCols, numOfmaps },
// Derives the kernel tensor shape from the input tensor's layout and channel
// count plus the configured kernel dimensions and number of feature maps.
// NOTE(review): unlike inferOutputShape(), no visible line checks that the
// input pointer is non-null before it is dereferenced.
78 virtual TensorShape inferWeightsShape()
const {
79 Tensor* input = getInput(Inputs);
80 DataLayout layout = input->getShape().getLayout();
// Channels are read from dim 1 for NCHW and from dim 3 for any other layout.
81 bool isNCHW = (layout == DataLayout::NCHW);
82 int channelsIdx = isNCHW ? 1 : 3;
83 int inputChannels = input->dim(channelsIdx);
// Two alternative dimension lists, both led by numOfmaps: channels before the
// kernel rows/cols, or channels after them. Each is paired with the input's
// layout and Backend::Alignment.
// NOTE(review): the enclosing expressions and the branch choosing between
// them are not visible in this listing; presumably keyed on isNCHW.
86 { numOfmaps, inputChannels, weightRows, weightCols },
87 layout, Backend::Alignment);
90 { numOfmaps, weightRows, weightCols, inputChannels },
91 layout, Backend::Alignment);
// Fragment of the kernel-tensor creation routine (its signature is not visible
// in this listing). It first tests whether the Kernels input slot is already
// populated; the guarded statement is not shown (presumably an early return).
99 if (inputs.at(Kernels) !=
nullptr)
// Hand the `kernels` tensor to the workspace and install it in the Kernels
// slot. NOTE(review): the construction of `kernels` is on lines not shown.
103 workspace->addTensor(kernels);
104 inputs[Kernels] = kernels;
// Creates and registers the output tensor for this operator.
// It first tests whether the Outputs slot is already populated; the guarded
// statement is not visible in this listing (presumably an early return).
107 void createOutputTensors() {
108 if (outputs.at(Outputs) !=
nullptr)
// Hand the `output` tensor to the workspace and install it in the Outputs
// slot. NOTE(review): the construction of `output` is on lines not shown.
112 workspace->addTensor(output);
113 outputs.at(Outputs) = output;
// Override that creates the tensors this operator owns. The visible body
// calls createOutputTensors().
// NOTE(review): one body line before that call is not visible in this
// listing (presumably the kernel tensor creation; confirm).
116 void createAllTensors()
override {
118 createOutputTensors();
// Returns the configured number of feature maps (0 until setWeightDims() is
// called).
121 int getNumOfmaps()
const {
return numOfmaps; }
// Intentionally empty in this generic template.
// NOTE(review): presumably backend-specific specializations provide the
// actual computation; confirm against the backend sources.
123 void run()
override {}
// Returns size() of the kernel tensor's shape (presumably the total number of
// weight elements; confirm against TensorShape::size()).
// The Kernels slot is dereferenced without a null check, so the kernel tensor
// must already be attached when this is called.
125 int getNumParameters()
const override {
126 return inputs.at(Kernels)->getShape().size();
// Returns a one-element vector holding the kernel tensor pointer from the
// Kernels input slot. The entry is nullptr if the kernels have not been
// attached yet.
129 std::vector<TensorBase*> getParameterizableInputs()
override {
130 return { inputs[Kernels] };
// Returns the stride along the row dimension, as set by setStride().
133 int getRowStride()
const {
return rowStride; }
// Returns the stride along the column dimension, as set by setStride().
134 int getColStride()
const {
return colStride; }
// Returns the kernel height, as set by setWeightDims().
135 int getWeightRows()
const {
return weightRows; }
// Returns the kernel width, as set by setWeightDims().
136 int getWeightCols()
const {
return weightCols; }
// Returns the padding mode (UnknownPadding until setPadding() is called).
137 PaddingType getPadding()
const {
return paddingType; }
// Fragment of the input-padding computation (its signature and return are not
// visible in this listing). Fills a 4-entry vector: entries 0/1 split the
// total row padding and entries 2/3 split the total column padding
// (presumably ordered top, bottom, left, right; confirm).
144 std::vector<int> inputPadding(4);
// Total padding is (kernel dimension - 1) for SamePadding and 0 otherwise.
145 int totalRowPad = (paddingType == SamePadding) ? weightRows - 1 : 0;
146 int totalColPad = (paddingType == SamePadding) ? weightCols - 1 : 0;
// The first entry of each pair takes FRAC_CEIL(total, 2) (presumably the
// rounded-up half; confirm the macro) and the second takes the remainder,
// so each pair always sums to the total.
147 inputPadding[0] =
FRAC_CEIL(totalRowPad, 2);
148 inputPadding[1] = totalRowPad - inputPadding[0];
149 inputPadding[2] =
FRAC_CEIL(totalColPad, 2);
150 inputPadding[3] = totalColPad - inputPadding[2];
// Reports that this operator supports sampling; always returns true.
154 bool isSamplingSupported()
const override {
return true; }
// Replaces the stored sampling configuration (initialized to
// { NoSampling, 1 } by the constructor) with a copy of `_sampling`.
155 void setSamplingInfo(
const SamplingInfo& _sampling)
override {
156 sampling = _sampling;
// Computes one output dimension from a padding mode: SamePadding becomes a
// total padding of (weightDim - 1), any other mode becomes 0, and the result
// is delegated to the overload that takes a numeric padding amount.
// NOTE(review): the declarations of the `weightDim` and `stride` parameters
// are on lines not visible in this listing.
160 int computeOutputDim(
int inputDim,
163 PaddingType pad)
const {
164 int padding = (pad == SamePadding ? (weightDim - 1) : 0);
165 return computeOutputDim(inputDim, weightDim, stride, padding);
// Computes one output dimension from a numeric total padding amount:
// (inputDim - weightDim + padding) / stride + 1.
// NOTE(review): the remaining parameter declarations are not visible here;
// if they are all int (as `inputDim` is), the division truncates. A stride
// of 0 would divide by zero; validate() rejects non-positive strides.
167 int computeOutputDim(
int inputDim,
171 return (inputDim - weightDim + padding) / stride + 1;
// Slot indices into the `inputs` vector: Inputs = 0 (the data tensor),
// Kernels = 1 (the weights); kNumInputs = 2 is the slot count used to size it.
175 enum { Inputs, Kernels, kNumInputs };
// Slot index into the `outputs` vector: Outputs = 0; kNumOutputs = 1 is the
// slot count used to size it.
176 enum { Outputs, kNumOutputs };
// Padding mode; UnknownPadding until setPadding() is called.
184 PaddingType paddingType;
// Name for the kernel tensor; the constructor sets it to "<name>/kernels".
185 std::string weightsName;
// Invokes REGISTER_SPECIAL_OP for ConvolutionOp with ReferenceBackend.
// NOTE(review): the macro is defined outside this file; presumably it declares
// the backend-specific specialization of the operator. Confirm at its
// definition.
189 REGISTER_SPECIAL_OP(ConvolutionOp, ReferenceBackend);