#ifndef _OPERATORS_BATCH_NORM_OP_H_
#define _OPERATORS_BATCH_NORM_OP_H_

#include <cassert>
#include <string>
#include <vector>

#include "smaug/core/backend.h"
#include "smaug/core/operator.h"
#include "smaug/core/tensor.h"
#include "smaug/core/workspace.h"
#include "smaug/operators/fused_activation_op.h"
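
/** \brief Implements batch normalization.
 *
 * The input is normalized using precomputed mean and variance tensors, then
 * scaled by gamma and shifted by beta. kEpsilon guards against division by
 * zero in the normalization. The mean, variance, gamma, and beta tensors all
 * share the shape returned by inferWeightsShape().
 *
 * \tparam Backend The backend that provides the alignment requirement for
 *         the weights tensors.
 */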
template <typename Backend>
class BatchNormOp : public FusedActivationOp {
  public:
    // Indices into the inputs array. Mean through Beta must stay contiguous
    // so that createWeightsTensors() can iterate over the weights tensors.
    enum { Inputs, Mean, Variance, Gamma, Beta, kNumInputs };
    enum { Outputs, kNumOutputs };
    static constexpr float kEpsilon = 1e-5;

    BatchNormOp(const std::string& name, Workspace* workspace)
            : FusedActivationOp(name, OpType::BatchNorm, workspace),
              meanName(name + "/mean"), varianceName(name + "/variance"),
              gammaName(name + "/gamma"), betaName(name + "/beta"),
              sampling({ NoSampling, 1 }) {
        inputs.resize(kNumInputs, nullptr);
        outputs.resize(kNumOutputs, nullptr);
    }

    // The base implementation is a no-op; backend-specific specializations
    // provide the actual computation (see REGISTER_SPECIAL_OP below).
    void run() override {}
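
    /** The output shape is identical to the input shape. */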
    TensorShape inferOutputShape() const {
        return getInput(Inputs)->getShape();
    }
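
    /** Returns the shape of the mean/variance/gamma/beta tensors, which
     * depends on whether the input is 4D (e.g. after a convolutional layer)
     * or 2D (e.g. after a fully-connected layer). */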
    TensorShape inferWeightsShape() const {
        TensorShape shape = getInput(Inputs)->getShape();
        DataLayout layout = shape.getLayout();
        int ndims = shape.ndims();
        if (ndims == 4) {
            // One set of parameters per feature map (channel).
            bool isNCHW = layout == DataLayout::NCHW;
            int fmaps = isNCHW ? shape[ndims - 3] : shape[ndims - 1];
            return TensorShape(
                    { 1, fmaps }, DataLayout::NC, Backend::Alignment);
        } else if (ndims == 2) {
            // One set of parameters per activation.
            if (layout == DataLayout::NC)
                return TensorShape(
                        { 1, shape[1] }, DataLayout::NC, Backend::Alignment);
            else
                assert(false && "Unexpected data layout for batch norm!");
        } else {
            assert(false && "Unexpected input dimensions for batch norm!");
        }
        return TensorShape();
    }

    /** Creates the mean, variance, gamma, and beta tensors if they have not
     * already been provided, and adds them to the workspace. */
    void createWeightsTensors() {
        if (inputs[Mean] && inputs[Variance] && inputs[Gamma] && inputs[Beta])
            return;
        TensorShape shape = inferWeightsShape();
        inputs[Mean] = new Tensor(meanName, shape);
        inputs[Variance] = new Tensor(varianceName, shape);
        inputs[Gamma] = new Tensor(gammaName, shape);
        inputs[Beta] = new Tensor(betaName, shape);
        for (int i = Mean; i <= Beta; i++)
            workspace->addTensor(static_cast<Tensor*>(inputs[i]));
    }

    void createOutputTensors() {
        if (outputs[Outputs])
            return;
        TensorShape shape = inferOutputShape();
        Tensor* output = new Tensor(name, shape);
        workspace->addTensor(output);
        outputs[Outputs] = output;
    }

    void createAllTensors() override {
        createWeightsTensors();
        createOutputTensors();
    }

    int getNumParameters() const override {
        return kNumInputs * inputs.at(Mean)->getShape().size();
    }

    std::vector<TensorBase*> getParameterizableInputs() override {
        return { inputs[Mean], inputs[Variance], inputs[Gamma], inputs[Beta] };
    }

    bool isSamplingSupported() const override { return true; }
    void setSamplingInfo(const SamplingInfo& _sampling) override {
        sampling = _sampling;
    }

  protected:
    const std::string meanName;
    const std::string varianceName;
    const std::string gammaName;
    const std::string betaName;
    SamplingInfo sampling;
};

REGISTER_SPECIAL_OP(BatchNormOp, ReferenceBackend);

#endif