1 #ifndef _OPERATORS_INNER_PRODUCT_OP_H_
2 #define _OPERATORS_INNER_PRODUCT_OP_H_
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
6 #include "smaug/core/tensor.h"
7 #include "smaug/core/workspace.h"
9 #include "smaug/operators/fused_activation_op.h"
/// InnerProductOp: an inner-product layer operator. The input tensor (slot
/// Inputs) must use DataLayout::NC with shape (N, C); it is combined with a
/// weights tensor holding C x numOutputs values to produce an output of shape
/// (N, numOutputs) in DataLayout::NC.
///
/// Requirements on the template parameter Backend:
///   - Backend::Alignment: passed to every TensorShape this op infers.
///   - Backend::TransposeFCWeights: selects the weights layout, either
///     (numOutputs, C) in NC or (C, numOutputs) in CN.
///
/// Derives from FusedActivationOp; presumably this lets an activation be
/// fused onto the output — confirm against FusedActivationOp.
19 template <
typename Backend>
20 class InnerProductOp :
public FusedActivationOp {
/// Constructs the operator and registers OpType::InnerProduct with the
/// FusedActivationOp base.
///
/// Initial state:
///   - numOutputs is 0. Call setNumOutputs() before inferring shapes or
///     creating tensors, because both inferred shapes depend on it.
///   - weightsTensorsCreated / outputTensorsCreated are false.
///   - weightsName is name + "/weights". This is the name given to a weights
///     tensor created by this op.
///   - sampling is { NoSampling, 1 }.
///   - The input/output slot vectors are resized to kNumInputs / kNumOutputs
///     and every slot is nullptr until a tensor is attached or created.
///
/// @param name      Operator name, forwarded to the base class.
/// @param workspace Workspace forwarded to the base class.
22 InnerProductOp(
const std::string& name, Workspace* workspace)
23 : FusedActivationOp(name, OpType::InnerProduct, workspace),
24 numOutputs(0), weightsTensorsCreated(false),
25 outputTensorsCreated(false), weightsName(name +
"/weights"),
26 sampling({ NoSampling, 1 }) {
27 inputs.resize(kNumInputs,
nullptr);
28 outputs.resize(kNumOutputs,
nullptr);
31 void setNumOutputs(
int _outputs) { numOutputs = _outputs; }
33 void run()
override {}
/// Infers the output shape from the tensor attached to the Inputs slot.
///
/// The input layout must be DataLayout::NC. This is checked with assert, so
/// it is not checked when NDEBUG is defined. The result has dimensions
/// { shape[0], numOutputs } in DataLayout::NC and is built with
/// Backend::Alignment.
35 TensorShape inferOutputShape()
const {
36 const TensorShape& shape = getInput(Inputs)->getShape();
37 assert(shape.getLayout() == DataLayout::NC);
39 { shape[0], numOutputs }, DataLayout::NC, Backend::Alignment);
/// Infers the weights shape from the tensor attached to the Inputs slot.
///
/// The input layout must be DataLayout::NC (assert only). With C = shape[1]:
///   - Backend::TransposeFCWeights true:  { numOutputs, C } in DataLayout::NC
///   - otherwise (presumably the else-branch): { C, numOutputs } in
///     DataLayout::CN
/// The resulting shape is built with Backend::Alignment.
42 TensorShape inferWeightsShape()
const {
43 const TensorShape& shape = getInput(Inputs)->getShape();
44 assert(shape.getLayout() == DataLayout::NC);
45 std::vector<int> outputDims;
// The backend decides whether weights are stored transposed.
47 if (Backend::TransposeFCWeights) {
48 outputDims = { numOutputs, shape[1] };
49 outLayout = DataLayout::NC;
51 outputDims = { shape[1], numOutputs };
52 outLayout = DataLayout::CN;
54 return TensorShape(outputDims, outLayout, Backend::Alignment);
/// Creates the weights tensor and attaches it to the Weights input slot.
///
/// The guard on inputs.at(Weights) presumably skips creation when a weights
/// tensor is already attached — confirm. On creation, the new tensor:
///   - is named weightsName,
///   - is shaped by inferWeightsShape(),
///   - is added to the workspace,
///   - is stored in the Weights slot.
/// weightsTensorsCreated is then set to true.
/// NOTE(review): the tensor is allocated with raw `new`; ownership presumably
/// passes to the workspace via addTensor() — confirm.
57 void createWeightsTensors() {
58 if (inputs.at(Weights))
60 TensorShape shape = inferWeightsShape();
61 Tensor* weights =
new Tensor(weightsName, shape);
62 workspace->addTensor(weights);
63 inputs.at(Weights) = weights;
64 weightsTensorsCreated =
true;
/// Creates the output tensor and attaches it to the Outputs slot.
///
/// The guard on outputs.at(Outputs) presumably skips creation when an output
/// tensor is already attached — confirm. On creation, the new tensor is named
/// after the op (`name`), shaped by inferOutputShape(), added to the
/// workspace, and stored in the Outputs slot.
/// NOTE(review): the store uses unchecked operator[] while the guard uses
/// bounds-checked .at(); consider .at() for consistency.
/// NOTE(review): confirm outputTensorsCreated is set to true after creation,
/// mirroring how createWeightsTensors() sets weightsTensorsCreated.
/// NOTE(review): raw `new`; ownership presumably passes to the workspace.
67 void createOutputTensors() {
68 if (outputs.at(Outputs))
70 TensorShape shape = inferOutputShape();
71 Tensor* output =
new Tensor(name, shape);
72 workspace->addTensor(output);
73 outputs[Outputs] = output;
/// Creates this op's tensors: first the weights tensor, then the output
/// tensor.
76 void createAllTensors()
override {
77 createWeightsTensors();
78 createOutputTensors();
81 int getNumOutputs()
const {
return numOutputs; }
/// Returns getShape().size() of the tensor in the Weights slot. This is
/// presumably the number of weight elements — confirm TensorShape::size().
/// NOTE(review): .at() bounds-checks the slot index, but the constructor
/// fills the slot with nullptr. Call this only after weights are attached or
/// createAllTensors() has run, or the pointer dereference is undefined.
83 int getNumParameters()
const override {
84 return inputs.at(Weights)->getShape().size();
/// Returns a one-element list containing the tensor in the Weights slot.
/// The slot is read with unchecked operator[]. The element is nullptr until
/// weights are attached or created.
87 std::vector<TensorBase*> getParameterizableInputs()
override {
88 return { inputs[Weights] };
91 bool isSamplingSupported()
const override {
return true; }
/// Override of the base-class hook that receives sampling configuration.
/// NOTE(review): presumably stores _sampling into the `sampling` member
/// (initialized to { NoSampling, 1 } in the constructor) — confirm.
92 void setSamplingInfo(
const SamplingInfo& _sampling)
override {
/// Input slot indices; kNumInputs is the number of input slots.
enum { Inputs = 0, Weights = 1, kNumInputs = 2 };
/// Output slot indices; kNumOutputs is the number of output slots.
enum { Outputs = 0, kNumOutputs = 1 };
// false after construction; set to true by createWeightsTensors() once it
// has created the weights tensor.
102 bool weightsTensorsCreated;
// false after construction; presumably set once createOutputTensors() has
// created the output tensor — confirm.
103 bool outputTensorsCreated;
// Name for a weights tensor created by this op: the op name + "/weights".
104 std::string weightsName;
// Presumably declares a ReferenceBackend-specific specialization of
// InnerProductOp (e.g. replacing the empty generic run()) — confirm against
// the REGISTER_SPECIAL_OP macro definition.
108 REGISTER_SPECIAL_OP(InnerProductOp, ReferenceBackend);