1 #include "smaug/core/backend.h"
3 #include "smaug/operators/smv/smv_inner_product_op.h"
4 #include "smaug/operators/smv/smv_inner_product_tiling.h"
5 #include "smaug/operators/smv/smv_kernels.h"
6 #include "smaug/operators/smv/smv_accel_pool.h"
7 #include "smaug/utility/debug_stream.h"
// Hardware geometry of the SMV inner-product accelerator:
// kNumPEs processing elements, each with kNumMaccsPerPE multiply-accumulate
// (MACC) units. (Per-cycle throughput is implied by the names but not shown
// in this fragment -- confirm against the kernel implementation.)
13 const int kNumPEs = 8;
14 const int kNumMaccsPerPE = 32;
// Tiled dispatch loop for the SMV inner product (evidently the body of
// runNWA; the enclosing signature lies outside this fragment). Iterates over
// input batch tiles (N), weight neuron tiles (W), and activation-wise tiles
// (iC/wC, declared outside this view), issuing one accelerator invocation per
// tile combination and round-robining over the available accelerators.
// NOTE(review): this extract has original line numbers fused into the text
// and is missing several declarations (inputTile, weightsTile, iC, wC,
// actOffset, the mapped-array call names) -- restore from the full file
// before editing logic.
// Output tiling is not supported: exactly one output tile is expected.
31 assert(outputs.size() == 1 &&
32 "Inner product outputs tiling not implemented yet!");
// Tile counts along each tiled dimension. For inputs/outputs dim 0 is the
// batch-tile axis; dim 1 is the activation-tile axis. For weights dim 0 is
// the neuron-tile axis and dim 1 the activation-tile axis.
33 int inputNumTiles = inputs.getShape()[0];
34 int inputActTiles = inputs.getShape()[1];
35 int weightActTiles = weights.getShape()[1];
36 int weightNeuronTiles = weights.getShape()[0];
// Flat-index generators used to locate a tile from its (row, col) tile
// coordinates within each tiled tensor.
37 auto inputIdx = inputs.startIndex();
38 auto weightIdx = weights.startIndex();
39 auto outputIdx = outputs.startIndex();
// Map the three host-side arrays ("host_a" inputs, "host_b" weights,
// "host_results" outputs) onto each available accelerator instance.
// (The mapping call's name is truncated out of this fragment.)
42 smv::kInnerProductHw + i,
"host_a", getInputsMemType());
44 smv::kInnerProductHw + i,
"host_b", getWeightsMemType());
46 smv::kInnerProductHw + i,
"host_results", getOutputsMemType());
// Outer loop: one pass per input batch tile.
51 for (
int N = 0; N < inputNumTiles; N++) {
// Running count of neurons already produced for this batch tile; passed
// to the kernel so each weight tile writes its slice of the output.
55 int finishedNeurons = 0;
56 for (
int W = 0; W < weightNeuronTiles; W++) {
// All weight neuron tiles accumulate into the single output tile of
// this batch row (activation-tile coordinate fixed at 0).
62 int outputTileIdx = outputIdx(N, 0);
63 Tensor* outputTile = outputs[outputTileIdx];
64 const TensorShape& outputShape = outputTile->getShape();
// Truncated call sized by the output tile's storage -- presumably a
// cache flush/invalidate of the fp16 output buffer before the
// accelerator writes it. TODO(review): confirm against the full file.
66 outputTile->data<float16>(),
67 outputShape.storageSize() *
sizeof(float16));
// Walk the activation-wise tiles of inputs and weights in lockstep
// (iC/wC are declared above this fragment).
71 while (iC < inputActTiles && wC < weightActTiles) {
72 int inputTileIdx = inputIdx(N, iC);
73 int weightTileIdx = weightIdx(W, wC);
// Debug trace of the tile triple being dispatched.
80 dout(1) <<
"Input: " << inputTileIdx
81 <<
", weights: " << weightTileIdx
82 <<
", output: " << outputTileIdx <<
"\n";
85 const TensorShape& inputShape = inputTile->getShape();
86 const TensorShape& weightsShape = weightsTile->getShape();
// Two more truncated storage-sized calls -- presumably cache
// flushes of the fp16 input and weight tile buffers before DMA.
88 inputTile->data<float16>(),
89 inputShape.storageSize() *
sizeof(float16));
91 weightsTile->data<float16>(),
92 weightsShape.storageSize() *
sizeof(float16));
// 2-D logical dimensions of each tile handed to the kernel.
93 int inputDims[2] = { inputShape[0], inputShape[1] };
94 int weightsDims[2] = { weightsShape[0], weightsShape[1] };
95 int outputDims[2] = { outputShape[0], outputShape[1] };
// When the input and weight activation tilings align (iC == wC) the
// weight tile starts at activation 0 of the input tile; otherwise use
// the running activation offset accumulated below.
100 int actStart = (iC == wC) ? 0 : actOffset;
// After the first activation tile, the kernel must accumulate into
// the existing partial sums rather than overwrite them.
104 bool accumulate = wC > 0;
// Skip re-reading the input tile if this accelerator already holds
// it from the previous invocation.
106 bool readInputs =
false;
107 if (inputTileIdx != lastReadInputTileIdx[currAccelIdx]) {
109 lastReadInputTileIdx[currAccelIdx] = inputTileIdx;
// Only the very last invocation over all three loop dimensions sends
// results back to the host.
113 bool sendOutputs = (N == inputNumTiles - 1) &&
114 (W == weightNeuronTiles - 1) &&
115 (wC == weightActTiles - 1);
// Kernel invocation (call name truncated): dispatch the tile triple
// to accelerator currAccelIdx with the scratchpads, padded row
// widths, activation-function info, and sampling config.
118 currAccelIdx, smv::kInnerProductHw + currAccelIdx,
120 inputTile->data<float16>(),
121 weightsTile->data<float16>(),
122 outputTile->data<float16>(), smv::spad0, smv::spad1,
123 smv::spad2, inputDims, weightsDims, outputDims,
124 inputShape.getPadding(1), weightsShape.getPadding(1),
125 outputShape.getPadding(1), actStart, finishedNeurons,
126 accumulate, readInputs, sendOutputs, actInfo.function,
127 actInfo.params, &sampling);
// Record the in-flight invocation so the pool can wait on it.
128 accelPool.addFinishFlag(currAccelIdx, std::move(finishFlag));
// Advance the running activation offset by this weight tile's width.
130 actOffset += weightsTile->getShape()[1];
// Advance iC/wC: when tilings match they advance together (bodies
// truncated here); a mismatch is only legal when the inputs need no
// activation-wise tiling at all (inputActTiles == 1).
131 if (inputActTiles == weightActTiles) {
134 }
else if (inputActTiles == 1) {
137 assert(
false &&
"The input/weight tiles can have different "
138 "number of channels only when the inputs "
139 "don't need activation-wise tiling.");
// All neurons of this weight tile row are now produced; column 0's
// shape gives the neuron count for the row.
142 finishedNeurons += weights[weightIdx(W, 0)]->getShape()[0];
// Round-robin to the next free accelerator for the next weight tile.
143 currAccelIdx = accelPool.getNextAvailableAccelerator(currAccelIdx);
// Compute the tiling configuration for this operator by delegating to the
// FC tiling optimizer. The resulting tiledTensors are indexed as
// { 0: inputs, 1: weights, 2: outputs } -- the order run() relies on.
// (Closing brace lies outside this fragment.)
150 void SmvInnerProductOp::tile() {
154 tiledTensors = smaug::smv::fc::TilingOptimizer::doTiling(
this);
// Execute the inner product: validate layouts, distribute data to the tiles
// produced by tile(), run the tiled accelerator dispatch, then gather the
// tiled results back into the output tensor.
// (Closing brace and the braces scoping the ScopedStats objects lie outside
// this fragment.)
157 void SmvInnerProductOp::run() {
158 auto inputs = getInput(Inputs);
159 auto weights = getInput(Weights);
160 auto outputs = getOutput(Outputs);
161 const TensorShape& inputsShape = inputs->getShape();
162 const TensorShape& weightsShape = weights->getShape();
163 const TensorShape& outputsShape = outputs->getShape();
// All three tensors must be 2-D in NC (batch x channels) layout.
164 assert(inputsShape.getLayout() == DataLayout::NC);
165 assert(weightsShape.getLayout() == DataLayout::NC);
166 assert(outputsShape.getLayout() == DataLayout::NC);
// Verbose debug dump of the weights tensor.
167 dout(2) << *weights <<
"\n";
// Bracket tile-preparation (copying source data into every input and
// weight tile) with gem5 tensor-prep stats.
170 auto stats = gem5::ScopedStats(
171 stats::kTensorPrepStart, stats::kTensorPrepEnd);
172 tiledTensors[0].copyDataToAllTiles();
173 tiledTensors[1].copyDataToAllTiles();
// Dispatch the tiled computation: inputs, weights, outputs.
176 runNWA(tiledTensors[0], tiledTensors[1], tiledTensors[2]);
// Bracket result gathering (untiling the output) with tensor-final stats.
179 auto stats = gem5::ScopedStats(
180 stats::kTensorFinalStart, stats::kTensorFinalEnd);
181 tiledTensors[2].untile();