SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
depthwise_convolution_op.h
1 #ifndef _OPERATORS_DEPTHWISE_CONVOLUTION_OP_H_
2 #define _OPERATORS_DEPTHWISE_CONVOLUTION_OP_H_
3 
4 #include "smaug/operators/convolution_op.h"
5 
6 namespace smaug {
7 
14 template <typename Backend>
15 class DepthwiseConvolutionOp : public ConvolutionOp<Backend> {
16  protected:
17  typedef ConvolutionOp<Backend> Parent;
18 
19  public:
20  DepthwiseConvolutionOp(const std::string& name, Workspace* workspace)
21  : ConvolutionOp<Backend>(name, workspace) {
22  this->template opType = OpType::ConvolutionDepthwise;
23  }
24 
25  void run() override {}
26 
27  TensorShape inferOutputShape() const override {
28  Tensor* input = this->template getInput(Parent::Inputs);
29  assert(input && "Unable to get input for convolution op!");
30  const TensorShape& shape = input->getShape();
31  DataLayout layout = shape.getLayout();
32  bool isNCHW = (layout == DataLayout::NCHW);
33  int rowIdx = isNCHW ? 2 : 1;
34  int colIdx = isNCHW ? 3 : 2;
35  int outputRows = this->computeOutputDim(shape[rowIdx],
36  this->weightRows,
37  this->rowStride,
38  this->paddingType);
39  int outputCols = this->computeOutputDim(shape[colIdx],
40  this->weightCols,
41  this->colStride,
42  this->paddingType);
43  if (isNCHW) {
44  return TensorShape({ shape[0], shape[1], outputRows, outputCols },
45  layout,
46  Backend::Alignment);
47  } else {
48  return TensorShape({ shape[0], outputRows, outputCols, shape[3] },
49  layout,
50  Backend::Alignment);
51  }
52  }
53 
54  TensorShape inferWeightsShape() const override {
55  Tensor* input = this->template getInput(Parent::Inputs);
56  const TensorShape& shape = input->getShape();
57  DataLayout layout = shape.getLayout();
58  bool isNCHW = (layout == DataLayout::NCHW);
59  int inputChannels = isNCHW ? shape[1] : shape[3];
60  if (isNCHW) {
61  return TensorShape(
62  { 1, inputChannels, this->weightRows, this->weightCols },
63  layout, Backend::Alignment);
64  } else {
65  return TensorShape(
66  { 1, this->weightRows, this->weightCols, inputChannels },
67  layout, Backend::Alignment);
68  }
69  }
70 };
71 
// Registers DepthwiseConvolutionOp with the ReferenceBackend.
// NOTE(review): presumably this makes the ReferenceBackend instantiation
// available to the operator registry (and pairs with a run() specialization
// defined elsewhere) -- confirm against the REGISTER_SPECIAL_OP definition.
72 REGISTER_SPECIAL_OP(DepthwiseConvolutionOp, ReferenceBackend);
73 
74 } // namespace smaug
75 
76 #endif
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38