1 #include "smaug/core/backend.h"
3 #include "smaug/operators/smv/smv_convolution_op.h"
4 #include "smaug/operators/smv/smv_convolution_tiling.h"
5 #include "smaug/operators/smv/smv_kernels.h"
6 #include "smaug/operators/smv/smv_accel_pool.h"
7 #include "smaug/utility/debug_stream.h"
// Datapath size constants for the SMV backend. The names suggest 8 processing
// elements with 32 MACC units each; NOTE(review): confirm against the SMV
// accelerator model — the uses of these constants are not visible in this chunk.
13 const int kNumPEs = 8;
14 const int kNumMaccsPerPE = 32;
// Body of SmvConvolutionOp's tiled NHWC convolution driver (the function
// signature falls on lines missing from this chunk; presumably
// runNHWC(TiledTensor& inputs, TiledTensor& weights, TiledTensor& outputs),
// given the call site in run()). NOTE(review): this chunk is a numbered-listing
// paste — each line begins with its original file line number, and the gaps in
// that numbering are genuinely missing lines, so several statements below are
// visibly truncated. Comments only; no code tokens were changed.
//
// Tile counts along each tiled dimension: dim 0 = batch/ifmap (or ofmap for
// weights) tiles, dim 1 = row tiles, dim 3 = channel tiles.
22 int inputIfmapTiles = inputs.getShape()[0];
23 int inputRowTiles = inputs.getShape()[1];
24 int inputChanTiles = inputs.getShape()[3];
25 int weightOfmapTiles = weights.getShape()[0];
26 int weightChanTiles = weights.getShape()[3];
27 int outputRowTiles = outputs.getShape()[1];
28 int outputChanTiles = outputs.getShape()[3];
// Index helpers mapping (n, h, w, c) tile coordinates to a flat tile index
// (invoked as e.g. inputIdx(N, H, 0, iC) below).
29 auto inputIdx = inputs.startIndex();
30 auto weightIdx = weights.startIndex();
31 auto outputIdx = outputs.startIndex();
// Convolution input padding split per side: {top, bottom, left, right}.
33 int topPad = inputPadding[0];
34 int bottomPad = inputPadding[1];
35 int leftPad = inputPadding[2];
36 int rightPad = inputPadding[3];
// Tail of a (missing) conditional expression selecting the accelerator ID; the
// fallback alternative is the SMV convolution block.
38 : smv::kConvolutionHw;
// Fragments of per-accelerator host-array registration calls (the callee name
// and the loop header over `i` are on missing lines): each accelerator in the
// pool is told where and with which memory type the host-side inputs, weights
// and results live.
44 accelId + i,
"host_inputs", getInputsMemType());
46 accelId + i,
"host_weights", getWeightsMemType());
48 accelId + i,
"host_results", getOutputsMemType());
// Outer loops: N over batch/ifmap tiles, H over output row tiles.
51 for (
int N = 0; N < inputIfmapTiles; N++) {
52 for (
int H = 0; H < outputRowTiles; H++) {
// Halo padding for this row tile: when the rows are tiled, only the boundary
// tiles keep the original top/bottom pad; interior tiles get zero. The guard
// structure is garbled here — parts of the if/else chain (e.g. the apparent
// `H == 0` case) are on missing lines.
53 int currentTileTopPad = topPad;
54 int currentTileBottomPad = bottomPad;
55 if (inputRowTiles > 1) {
57 currentTileBottomPad = 0;
58 }
else if (H == inputRowTiles - 1) {
59 currentTileTopPad = 0;
61 currentTileTopPad = 0;
62 currentTileBottomPad = 0;
// Per-tile halo padding handed to the kernel: {top, bottom, left, right};
// the left/right initializers continue on a missing line.
67 int inputHaloPad[4] = { currentTileTopPad, currentTileBottomPad,
// If the weights hold fewer ofmap tiles than the output holds channel tiles,
// one weight tile must serve several output channel tiles, so an extra output
// iteration (oC below) is required.
77 bool needOutputIteration = weightOfmapTiles < outputChanTiles;
83 int numOutputInvocations =
84 needOutputIteration ? outputChanTiles : 1;
// Sanity check: either a single weight tile is reused across all output
// channel tiles, or the two tilings agree one-to-one.
85 assert(numOutputInvocations > 1
86 ? weightOfmapTiles == 1
87 : weightOfmapTiles == outputChanTiles);
// W iterates over weight ofmap tiles; oC over the extra output-channel
// invocations (degenerates to a single pass when no output iteration is
// needed).
88 for (
int W = 0; W < weightOfmapTiles; W++) {
104 for (
int oC = 0; oC < numOutputInvocations; oC++) {
108 int outputTileIdx = outputIdx(N, H, 0, W + oC);
109 Tensor* outputTile = outputs[outputTileIdx];
110 const TensorShape& outputShape = outputTile->getShape();
// Fragment: register this output tile's fp16 buffer with the current
// accelerator (the callee name is on a missing line). currAccelIdx presumably
// rotates through the accelerator pool; its declaration is also missing here.
112 accelId + currAccelIdx,
"host_results",
113 outputTile->
data<float16>(),
114 outputShape.storageSize() *
sizeof(float16));
// Walk the input (iC) and weight (wC) channel tiles together; the loop
// variables and the inputTile/weightsTile fetches are declared on missing
// lines.
125 while (iC < inputChanTiles && wC < weightChanTiles) {
126 int inputTileIdx = inputIdx(N, H, 0, iC);
127 int weightTileIdx = weightIdx(W, 0, 0, wC);
// Verbosity-1 trace of the tile triple being dispatched.
128 dout(1) <<
"Input: " << inputTileIdx
129 <<
", weights: " << weightTileIdx
130 <<
", output: " << outputTileIdx <<
"\n";
135 const TensorShape& inputShape = inputTile->getShape();
137 weightsTile->getShape();
// Fragments: register the input and weight tile buffers with the current
// accelerator before launch (callee names on missing lines).
139 accelId + currAccelIdx,
"host_inputs",
140 inputTile->
data<float16>(),
141 inputShape.storageSize() *
sizeof(float16));
143 accelId + currAccelIdx,
"host_weights",
144 weightsTile->
data<float16>(),
145 weightsShape.storageSize() *
sizeof(float16));
// NHWC dims passed to the kernel; the weightsDims initializer continues on
// missing lines 149-150.
146 int inputDims[4] = { inputShape[0], inputShape[1],
147 inputShape[2], inputShape[3] };
148 int weightsDims[4] = { weightsShape[0], weightsShape[1],
151 int outputDims[4] = { outputShape[0], outputShape[1],
152 outputShape[2], outputShape[3] };
// When the input and weight channel tilings diverge, ifmapOffset (declared on
// a missing line) tracks where this weight tile starts within the input's
// channel range; otherwise start at 0.
157 int ifmapStart = (iC == wC) ? 0 : ifmapOffset;
// Accumulate partial products into the output for every channel tile after
// the first.
162 bool accumulate = wC > 0;
// Skip re-sending an input tile the accelerator already holds: compare
// against the per-accelerator record of the last tile sent (the enclosing
// `if` header is partly on missing lines).
165 bool readInputs =
false;
167 lastReadInputTileIdx[currAccelIdx]) {
169 lastReadInputTileIdx[currAccelIdx] = inputTileIdx;
// Identical caching for weight tiles.
171 bool readWeights =
false;
173 lastReadWeightTileIdx[currAccelIdx]) {
175 lastReadWeightTileIdx[currAccelIdx] = weightTileIdx;
// Only copy results back on the final weight channel tile, once accumulation
// is complete.
180 bool sendResults = wC == weightChanTiles - 1;
182 std::unique_ptr<volatile int> finishFlag;
// Launch path 1 (its guard is on a missing line — presumably a flag choosing
// the systolic-array model): invoke the systolic array kernel.
185 finishFlag = invokeSystolicArrayKernel(
186 accelId + currAccelIdx,
187 inputTile->
data<float16>(),
188 weightsTile->
data<float16>(),
189 outputTile->
data<float16>(), inputDims,
190 weightsDims, outputDims,
191 inputShape.getPadding(3),
192 weightsShape.getPadding(3),
193 outputShape.getPadding(3), inputHaloPad,
194 getRowStride(), ifmapStart, kernStart,
195 accumulate, readInputs, readWeights,
196 sendResults, &actInfo);
// Launch path 2 (the invoke call's name is on a missing line): the SMV
// convolution kernel, using the three scratchpads spad0-spad2.
200 currAccelIdx, accelId + currAccelIdx,
202 inputTile->
data<float16>(),
203 weightsTile->
data<float16>(),
204 outputTile->
data<float16>(), smv::spad0,
205 smv::spad1, smv::spad2, inputDims,
206 weightsDims, outputDims,
207 inputShape.getPadding(3),
208 weightsShape.getPadding(3),
209 outputShape.getPadding(3), inputHaloPad,
210 getRowStride(), getColStride(), ifmapStart,
211 kernStart, accumulate, readInputs,
212 readWeights, sendResults, actInfo.function,
213 actInfo.params, &sampling);
// Hand the finish flag over (presumably to the accelerator pool, so that
// completion can be awaited per accelerator — the call's name is on a
// missing line).
216 currAccelIdx, std::move(finishFlag));
// Advance the running input-channel offset by this weight tile's channels.
218 ifmapOffset += weightsTile->getShape()[3];
// Advance iC/wC (bodies on missing lines): in lockstep when the tilings
// match; only wC when the input has a single channel tile; otherwise the
// mismatch is reported with the message below (the raising construct is on a
// missing line).
219 if (inputChanTiles == weightChanTiles) {
222 }
else if (inputChanTiles == 1) {
226 "The input/weight tiles can have different "
227 "number of channels only when the inputs "
228 "don't need channelwise tiling.");
// When one weight tile feeds several output channel tiles, advance the kernel
// start offset by the channels just produced so the next oC iteration writes
// the following slice.
231 if (needOutputIteration)
232 kernStart += outputShape[3];
// Packs the kernel arguments into a systolic_array_params_t and launches the
// systolic-array accelerator, returning ownership of its finish-flag pointer.
// NOTE(review): the parameter list (original lines 244-266) is missing from
// this chunk; parameter meanings below are inferred from their uses.
243 std::unique_ptr<volatile int> SmvConvolutionOp::invokeSystolicArrayKernel(
267 systolic_array_params_t params;
// Raw device-visible base addresses of the three tile buffers.
268 params.input_base_addr = inputs;
269 params.weight_base_addr = weights;
270 params.output_base_addr = outputs;
271 memcpy(params.input_dims, inputsDims,
sizeof(
int) * 4);
272 memcpy(params.weight_dims, weightsDims,
sizeof(
int) * 4);
273 memcpy(params.output_dims, outputsDims,
sizeof(
int) * 4);
// Widen the innermost (channel) dims by their alignment padding so the
// accelerator sees the physical row pitch rather than the logical size.
274 params.input_dims[3] += inputsPad;
275 params.weight_dims[3] += weightsPad;
276 params.output_dims[3] += outputPad;
277 params.stride = stride;
278 memcpy(params.input_halo_pad, inputHaloPad,
sizeof(
int) * 4);
279 params.ifmap_start = ifmapStart;
280 params.kern_start = kernStart;
281 params.accum_results = accumulate;
282 params.read_inputs = readInputs;
283 params.read_weights = readWeights;
284 params.send_results = sendResults;
// NOTE(review): "¶ms" below is mojibake — an HTML-escaped "&params"
// ("&para;" + "ms"). As written this line cannot compile; it should read
// memcpy(&params.act_type, &(actInfo->function), sizeof(activation_type));
// Left byte-identical here because this chunk is a garbled paste.
287 memcpy(¶ms.act_type, &(actInfo->function),
sizeof(
activation_type));
// Hand the raw finish-flag pointer to a unique_ptr. NOTE(review): this uses
// default_delete on a volatile int — confirm the accelerator interface
// heap-allocates the flag with `new`.
289 return std::unique_ptr<volatile int>(
290 invokeSystolicArrayAndReturn(accelId, params));
296 void SmvConvolutionOp::tile() {
// Delegate tiling to the SMV convolution tiling optimizer; the resulting
// tiled input/weight/output tensors are kept in `tiledTensors` for run().
// NOTE(review): original lines 297-305 and the function's closing brace are
// missing from this chunk.
306 tiledTensors = smaug::smv::conv::TilingOptimizer::doTiling(
this);
// Executes the tiled NHWC convolution end to end: copy host data into the
// tiles, run the tiled compute loop, then merge the output tiles back.
// NOTE(review): the function's closing braces (beyond original line 333) lie
// past the last visible line of this chunk, and the scoped-stats blocks'
// braces fall on missing lines.
309 void SmvConvolutionOp::run() {
310 auto input = getInput(Inputs);
311 auto kernels = getInput(Kernels);
312 auto output = getOutput(Outputs);
313 const TensorShape& inputShape = input->getShape();
314 const TensorShape& kernelShape = kernels->getShape();
315 const TensorShape& outputShape = output->getShape();
// This operator only supports NHWC-laid-out tensors.
316 assert(inputShape.getLayout() == DataLayout::NHWC);
317 assert(kernelShape.getLayout() == DataLayout::NHWC);
318 assert(outputShape.getLayout() == DataLayout::NHWC);
// Verbosity-2 debug dump of the kernel tensor.
319 dout(2) << *kernels <<
"\n";
// Stats region covering tensor preparation: populate the input (index 0) and
// weight (index 1) tiles with the host data.
322 auto stats = gem5::ScopedStats(
323 stats::kTensorPrepStart, stats::kTensorPrepEnd);
324 tiledTensors[0].copyDataToAllTiles();
325 tiledTensors[1].copyDataToAllTiles();
// Main tiled compute loop over inputs, weights, outputs.
328 runNHWC(tiledTensors[0], tiledTensors[1], tiledTensors[2]);
// Stats region covering finalization: gather the output tiles (index 2) back
// into the single output tensor.
331 auto stats = gem5::ScopedStats(
332 stats::kTensorFinalStart, stats::kTensorFinalEnd);
333 tiledTensors[2].untile();