1 #ifndef _OPERATORS_DEPTHWISE_CONVOLUTION_OP_H_
2 #define _OPERATORS_DEPTHWISE_CONVOLUTION_OP_H_
4 #include "smaug/operators/convolution_op.h"
// Depthwise convolution operator, templated on the execution Backend and
// derived from ConvolutionOp<Backend>. Its shape-inference overrides carry the
// input's channel count into the output shape and give the weights a leading
// dimension of 1.
// NOTE(review): this excerpt is missing lines (call arguments, the branches
// around the returns, closing braces); the comments below describe only the
// visible code.
14 template <
typename Backend>
15 class DepthwiseConvolutionOp :
public ConvolutionOp<Backend> {
// Alias for the base class; used below to name Parent::Inputs.
17 typedef ConvolutionOp<Backend> Parent;
// Forwards name/workspace to the ConvolutionOp constructor and tags this op
// as OpType::ConvolutionDepthwise.
20 DepthwiseConvolutionOp(
const std::string& name, Workspace* workspace)
21 : ConvolutionOp<Backend>(name, workspace) {
22 this->
template opType = OpType::ConvolutionDepthwise;
// Empty in this generic template.
// NOTE(review): presumably the real computation is provided per backend
// (REGISTER_SPECIAL_OP is invoked for ReferenceBackend after the class) —
// confirm.
25 void run()
override {}
// Infers the output shape from the op's input tensor (asserted non-null).
// The spatial dims are read at indices 2/3 for NCHW and 1/2 otherwise. Each
// one is passed through computeOutputDim. The returned shape is 4-D:
// shape[0] and the input channel count (shape[1] in the NCHW-ordered return,
// shape[3] in the channels-last return) are copied unchanged, and
// outputRows/outputCols fill the spatial slots.
27 TensorShape inferOutputShape()
const override {
28 Tensor* input = this->
template getInput(Parent::Inputs);
29 assert(input &&
"Unable to get input for convolution op!");
30 const TensorShape& shape = input->getShape();
31 DataLayout layout = shape.getLayout();
32 bool isNCHW = (layout == DataLayout::NCHW);
// Any layout other than NCHW is indexed as channels-last (rows at 1, cols at 2).
33 int rowIdx = isNCHW ? 2 : 1;
34 int colIdx = isNCHW ? 3 : 2;
35 int outputRows = this->computeOutputDim(shape[rowIdx],
39 int outputCols = this->computeOutputDim(shape[colIdx],
44 return TensorShape({ shape[0], shape[1], outputRows, outputCols },
48 return TensorShape({ shape[0], outputRows, outputCols, shape[3] },
// Infers the weights shape. It is 4-D with a leading 1 and a channel dimension
// equal to the input's channel count (shape[1] for NCHW, shape[3] otherwise).
// It is built with the input's own layout and Backend::Alignment.
54 TensorShape inferWeightsShape()
const override {
55 Tensor* input = this->
template getInput(Parent::Inputs);
// NOTE(review): unlike inferOutputShape, `input` is dereferenced here without
// a null check; consider adding the same assert.
56 const TensorShape& shape = input->getShape();
57 DataLayout layout = shape.getLayout();
58 bool isNCHW = (layout == DataLayout::NCHW);
59 int inputChannels = isNCHW ? shape[1] : shape[3];
62 { 1, inputChannels, this->weightRows, this->weightCols },
63 layout, Backend::Alignment);
66 { 1, this->weightRows, this->weightCols, inputChannels },
67 layout, Backend::Alignment);
// Invokes the project macro REGISTER_SPECIAL_OP for DepthwiseConvolutionOp
// with ReferenceBackend.
// NOTE(review): the macro is defined elsewhere; presumably it declares or
// registers a ReferenceBackend-specific implementation — confirm.
72 REGISTER_SPECIAL_OP(DepthwiseConvolutionOp, ReferenceBackend);