SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
inner_product_op.h
1 #ifndef _OPERATORS_INNER_PRODUCT_OP_H_
2 #define _OPERATORS_INNER_PRODUCT_OP_H_
3 
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
6 #include "smaug/core/tensor.h"
7 #include "smaug/core/workspace.h"
9 #include "smaug/operators/fused_activation_op.h"
10 
11 namespace smaug {
12 
19 template <typename Backend>
20 class InnerProductOp : public FusedActivationOp {
21  public:
22  InnerProductOp(const std::string& name, Workspace* workspace)
23  : FusedActivationOp(name, OpType::InnerProduct, workspace),
24  numOutputs(0), weightsTensorsCreated(false),
25  outputTensorsCreated(false), weightsName(name + "/weights"),
26  sampling({ NoSampling, 1 }) {
27  inputs.resize(kNumInputs, nullptr);
28  outputs.resize(kNumOutputs, nullptr);
29  }
30 
31  void setNumOutputs(int _outputs) { numOutputs = _outputs; }
32 
33  void run() override {}
34  bool validate() override { return numOutputs > 0 && Operator::validate(); }
35  TensorShape inferOutputShape() const {
36  const TensorShape& shape = getInput(Inputs)->getShape();
37  assert(shape.getLayout() == DataLayout::NC);
38  return TensorShape(
39  { shape[0], numOutputs }, DataLayout::NC, Backend::Alignment);
40  }
41 
42  TensorShape inferWeightsShape() const {
43  const TensorShape& shape = getInput(Inputs)->getShape();
44  assert(shape.getLayout() == DataLayout::NC);
45  std::vector<int> outputDims;
46  DataLayout outLayout;
47  if (Backend::TransposeFCWeights) {
48  outputDims = { numOutputs, shape[1] };
49  outLayout = DataLayout::NC;
50  } else {
51  outputDims = { shape[1], numOutputs };
52  outLayout = DataLayout::CN;
53  }
54  return TensorShape(outputDims, outLayout, Backend::Alignment);
55  }
56 
57  void createWeightsTensors() {
58  if (inputs.at(Weights))
59  return;
60  TensorShape shape = inferWeightsShape();
61  Tensor* weights = new Tensor(weightsName, shape);
62  workspace->addTensor(weights);
63  inputs.at(Weights) = weights;
64  weightsTensorsCreated = true;
65  }
66 
67  void createOutputTensors() {
68  if (outputs.at(Outputs))
69  return;
70  TensorShape shape = inferOutputShape();
71  Tensor* output = new Tensor(name, shape);
72  workspace->addTensor(output);
73  outputs[Outputs] = output;
74  }
75 
76  void createAllTensors() override {
77  createWeightsTensors();
78  createOutputTensors();
79  }
80 
81  int getNumOutputs() const { return numOutputs; }
82 
83  int getNumParameters() const override {
84  return inputs.at(Weights)->getShape().size();
85  }
86 
87  std::vector<TensorBase*> getParameterizableInputs() override {
88  return { inputs[Weights] };
89  }
90 
91  bool isSamplingSupported() const override { return true; }
92  void setSamplingInfo(const SamplingInfo& _sampling) override {
93  sampling = _sampling;
94  }
95 
96  public:
97  enum { Inputs, Weights, kNumInputs };
98  enum { Outputs, kNumOutputs };
99 
100  protected:
101  int numOutputs;
102  bool weightsTensorsCreated;
103  bool outputTensorsCreated;
104  std::string weightsName;
105  SamplingInfo sampling;
106 };
107 
108 REGISTER_SPECIAL_OP(InnerProductOp, ReferenceBackend);
109 
110 
111 } // namespace smaug
112 
113 #endif
_SamplingInfo
Simulation sampling information maintained by the Operator and passed to the accelerated kernel.
Definition: common.h:262
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
common.h
Utilities for writing and invoking Aladdin kernels from Operators.
smaug::Operator::validate
virtual bool validate()
Returns true if the parameters/tensors of this operator are all valid.
Definition: operator.h:47