SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
batch_norm_op.h
1 #ifndef _OPERATORS_BATCH_NORM_OP_H_
2 #define _OPERATORS_BATCH_NORM_OP_H_
3 
#include "smaug/core/backend.h"
#include "smaug/core/operator.h"
#include "smaug/core/tensor.h"
#include "smaug/core/workspace.h"
#include "smaug/operators/common.h"
#include "smaug/operators/fused_activation_op.h"
10 
11 namespace smaug {
12 
24 template <typename Backend>
25 class BatchNormOp : public FusedActivationOp {
26  public:
27  enum {
28  Inputs,
29  Mean,
30  Variance,
31  Gamma,
32  Scaleshift = Gamma, // for MKL.
33  Beta,
34  kNumInputs
35  };
36  enum { Outputs, kNumOutputs };
37  static constexpr float kEpsilon = 1e-5;
38 
39  BatchNormOp(const std::string& name, Workspace* workspace)
40  : FusedActivationOp(name, OpType::BatchNorm, workspace),
41  meanName(name + "/mean"), varianceName(name + "/variance"),
42  gammaName(name + "/gamma"), betaName(name + "/beta"),
43  sampling({ NoSampling, 1 }) {
44  inputs.resize(kNumInputs, nullptr);
45  outputs.resize(kNumOutputs, nullptr);
46  }
47 
48  void run() override {}
49  TensorShape inferOutputShape() const {
50  return getInput(Inputs)->getShape();
51  }
52  TensorShape inferWeightsShape() const {
53  TensorShape shape = getInput(Inputs)->getShape();
54  DataLayout layout = shape.getLayout();
55  int ndims = shape.ndims();
56  if (ndims >= 4) {
57  // This is a volume which should be batch norm'ed by feature map.
58  bool isNCHW = layout == DataLayout::NCHW;
59  int fmaps = isNCHW ? shape[ndims - 3] : shape[ndims - 1];
60  return TensorShape(
61  { 1, fmaps }, DataLayout::NC, Backend::Alignment);
62  } else if (ndims == 2) {
63  if (layout == DataLayout::NC)
64  return TensorShape(
65  { 1, shape[1] }, DataLayout::NC, Backend::Alignment);
66  else
67  assert(false && "Unexpected data layout for batch norm!");
68  } else {
69  assert(false && "Unexpected input dimensions for batch norm!");
70  }
71  return TensorShape();
72  }
73 
74  void createWeightsTensors() {
75  if (inputs[Mean] && inputs[Variance] && inputs[Gamma] && inputs[Beta])
76  return;
77  TensorShape shape = inferWeightsShape();
78  inputs[Mean] = new Tensor(meanName, shape);
79  inputs[Variance] = new Tensor(varianceName, shape);
80  inputs[Gamma] = new Tensor(gammaName, shape);
81  inputs[Beta] = new Tensor(betaName, shape);
82  for (int i = Mean; i <= Beta; i++)
83  workspace->addTensor(static_cast<Tensor*>(inputs[i]));
84  }
85 
86  void createOutputTensors() {
87  if (outputs[Outputs])
88  return;
89  TensorShape shape = inferOutputShape();
90  Tensor* output = new Tensor(name, shape);
91  workspace->addTensor(output);
92  outputs[Outputs] = output;
93  }
94 
95  void createAllTensors() override {
96  createWeightsTensors();
97  createOutputTensors();
98  }
99 
100  int getNumParameters() const override {
101  return kNumInputs * inputs.at(Mean)->getShape().size();
102  }
103 
104  std::vector<TensorBase*> getParameterizableInputs() override {
105  return { inputs[Mean], inputs[Variance], inputs[Gamma], inputs[Beta] };
106  }
107 
108  bool isSamplingSupported() const override { return true; }
109  void setSamplingInfo(const SamplingInfo& _sampling) override {
110  sampling = _sampling;
111  }
112 
113  protected:
114  const std::string meanName;
115  const std::string varianceName;
116  const std::string gammaName;
117  const std::string betaName;
118  SamplingInfo sampling;
119 };
120 
// Declares a backend-specific implementation of BatchNormOp for the
// reference backend (presumably expands to a specialization declaration of
// run() -- confirm against the macro's definition in core/backend.h).
REGISTER_SPECIAL_OP(BatchNormOp, ReferenceBackend);
122 
123 } // namespace smaug
124 
125 #endif
_SamplingInfo
Simulation sampling information maintained by the Operator and passed to the accelerated kernel.
Definition: common.h:262
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
common.h
Utilities for writing and invoking Aladdin kernels from Operators.