#include "smaug/core/backend.h"
#include "smaug/operators/common.h"
#include "smaug/operators/smv/smv_batch_norm_op.h"
#include "smaug/operators/smv/smv_batch_norm_tiling.h"
#include "smaug/operators/smv/smv_kernels.h"
#include "smaug/operators/smv/smv_accel_pool.h"
#include "smaug/utility/debug_stream.h"

namespace smaug {
namespace smv {
namespace bn {

const int kVectorSize = 8;

}  // namespace bn
}  // namespace smv
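
// Dispatches post-FC batch norm tiles to the accelerator. Inputs and outputs
// are NC-tiled; the weights hold the batch norm parameters (mean, variance,
// gamma, beta), tiled along the activation dimension.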
void SmvBatchNormOp::runNA(TiledTensor& inputs,
                           TiledTensor& weights,
                           TiledTensor& outputs) {
    int inputNumTiles = inputs.getShape()[0];
    int inputActTiles = inputs.getShape()[1];
    int weightActTiles = weights.getShape()[1];
    auto inputIdx = inputs.startIndex();
    auto weightIdx = weights.startIndex();
    auto outputIdx = outputs.startIndex();
    setArrayMemTypeIfSimulating(
            smv::kBatchNormHw, "host_inputs", getInputsMemType());
    setArrayMemTypeIfSimulating(
            smv::kBatchNormHw, "host_weights", getWeightsMemType());
    setArrayMemTypeIfSimulating(
            smv::kBatchNormHw, "host_results", getOutputsMemType());
    for (int N = 0; N < inputNumTiles; N++) {
        int iC = 0, wC = 0;
        int actOffset = 0;
        while (iC < inputActTiles && wC < weightActTiles) {
            int inputTileIdx = inputIdx(N, iC);
            int weightTileIdx = weightIdx(0, wC);
            int outputTileIdx = outputIdx(N, iC);
            dout(1) << "Input: " << inputIdx(N, iC)
                    << ", weight: " << weightIdx(0, wC)
                    << ", output: " << outputIdx(N, iC) << "\n";
            Tensor* inputTile = inputs[inputTileIdx];
            Tensor* weightsTile = weights[weightTileIdx];
            Tensor* outputTile = outputs[outputTileIdx];
            const TensorShape& inputShape = inputTile->getShape();
            const TensorShape& weightsShape = weightsTile->getShape();
            const TensorShape& outputShape = outputTile->getShape();
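            // Hand each tile's backing storage to the accelerator under the
            // array names the kernel expects.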
            mapArrayToAccel(smv::kBatchNormHw, "host_inputs",
                            inputTile->data<float16>(),
                            inputShape.storageSize() * sizeof(float16));
            mapArrayToAccel(smv::kBatchNormHw, "host_weights",
                            weightsTile->data<float16>(),
                            weightsShape.storageSize() * sizeof(float16));
            mapArrayToAccel(smv::kBatchNormHw, "host_results",
                            outputTile->data<float16>(),
                            outputShape.storageSize() * sizeof(float16));
            int inputDims[2] = { inputShape[0], inputShape[1] };
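            // When the input and weight tiles are aligned (iC == wC), both
            // are freshly loaded and work starts at activation 0; otherwise
            // the input tile is being reused and work resumes at actOffset.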
            int actStart = (iC == wC) ? 0 : actOffset;
            bool sendOutputs = iC == wC || wC == weightActTiles - 1;
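            // sendOutputs flushes results back to the host only on aligned
            // tiles or the final weight tile. The kernel below is the
            // post-FC NC batch norm kernel declared in smv_kernels.h.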
            invokeKernel(smv::kBatchNormHw, smv_batch_norm_post_fc_nc_vec_fxp,
                         inputTile->data<float16>(),
                         weightsTile->data<float16>(),
                         outputTile->data<float16>(), smv::spad0, smv::spad1,
                         smv::spad2, inputDims, weightsShape[1],
                         inputShape.getPadding(1), actStart, sendOutputs,
                         actInfo.function, actInfo.params);
            actOffset += weightsTile->getShape()[1];
            if (inputActTiles == weightActTiles) {
                iC++;
                wC++;
            } else if (inputActTiles == 1) {
                wC++;
            } else {
                assert(false &&
                       "The input/weight tiles can have different "
                       "number of channels only when the inputs "
                       "don't need activation-wise tiling.");
            }
        }
    }
}
void SmvBatchNormOp::runNHWC(TiledTensor& inputs,
                             TiledTensor& weights,
                             TiledTensor& outputs) {
    assert(weights.size() == 1);
    int inputNumTiles = inputs.getShape()[0];
    int inputRowTiles = inputs.getShape()[1];
    int inputColTiles = inputs.getShape()[2];
    int inputChanTiles = inputs.getShape()[3];
    auto inputIdx = inputs.startIndex();
    auto outputIdx = outputs.startIndex();
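    // The lone weight tile is shared by every invocation, so map it onto
    // each accelerator once, up front.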
    Tensor* weightTile = weights[0];
    const TensorShape& weightShape = weightTile->getShape();
    for (int i = 0; i < numAcceleratorsAvailable; i++) {
        mapArrayToAccel(smv::kBatchNormHw + i, "host_weights",
                        weightTile->data<float16>(),
                        weightShape.storageSize() * sizeof(float16));
        setArrayMemTypeIfSimulating(
                smv::kBatchNormHw + i, "host_inputs", getInputsMemType());
        setArrayMemTypeIfSimulating(
                smv::kBatchNormHw + i, "host_weights", getWeightsMemType());
        setArrayMemTypeIfSimulating(
                smv::kBatchNormHw + i, "host_results", getOutputsMemType());
    }
    SmvAcceleratorPool accelPool(numAcceleratorsAvailable);
    int currAccelIdx = 0;
    for (int N = 0; N < inputNumTiles; N++) {
        for (int H = 0; H < inputRowTiles; H++) {
            for (int W = 0; W < inputColTiles; W++) {
                int ifmapOffset = 0;
                for (int C = 0; C < inputChanTiles; C++) {
                    int inputTileIdx = inputIdx(N, H, W, C);
                    int outputTileIdx = outputIdx(N, H, W, C);
                    dout(1) << "Input: " << inputTileIdx << ", Weight: 0"
                            << ", output: " << outputTileIdx << "\n";
                    Tensor* inputTile = inputs[inputTileIdx];
                    Tensor* outputTile = outputs[outputTileIdx];
                    const TensorShape& inputShape = inputTile->getShape();
                    const TensorShape& outputShape = outputTile->getShape();
142 "host_inputs", inputTile->
data<float16>(),
143 inputShape.storageSize() *
sizeof(float16));
145 smv::kBatchNormHw + currAccelIdx,
"host_results",
146 outputTile->
data<float16>(),
147 outputShape.storageSize() *
sizeof(float16));
                    int inputDims[4] = { inputShape[0], inputShape[1],
                                         inputShape[2], inputShape[3] };
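                    // Launch the kernel without blocking so the other
                    // accelerators stay busy; the returned finish flag is
                    // handed to the pool, which waits on it later. The
                    // trailing &sampling argument follows the convention of
                    // the other SMV operators.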
                    std::unique_ptr<volatile int> finishFlag =
                            invokeKernelNoBlock(
                                    currAccelIdx,
                                    smv::kBatchNormHw + currAccelIdx,
                                    smv_batch_norm_post_conv_nhwc_vec_fxp,
                                    inputTile->data<float16>(),
                                    weightTile->data<float16>(),
                                    outputTile->data<float16>(), smv::spad0,
                                    smv::spad1, smv::spad2, inputDims,
                                    weightShape[1], inputShape.getPadding(3),
                                    weightShape.getPadding(1), ifmapOffset,
                                    actInfo.function, actInfo.params,
                                    &sampling);
                    accelPool.addFinishFlag(
                            currAccelIdx, std::move(finishFlag));
                    ifmapOffset += inputShape[3];
                    currAccelIdx = accelPool.getNextAvailableAccelerator(
                            currAccelIdx);
                }
            }
        }
    }
    // Wait for all the accelerators to finish.
    accelPool.joinAll();
}
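
// Tiling is delegated to the batch norm tiling optimizer, which splits the
// input, weight, and output tensors into tiles sized for the SMV scratchpads.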
void SmvBatchNormOp::tile() {
    tiledTensors = smv::bn::TilingOptimizer::doTiling(this);
}
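
// Top-level entry: copy the host data into the tiles, choose the NHWC
// (post-convolution) or NC (post-FC) dispatcher based on input rank, then
// untile the results.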
void SmvBatchNormOp::run() {
    auto input = getInput(Inputs);
    auto mean = getInput(Mean);
    auto variance = getInput(Variance);
    auto gamma = getInput(Gamma);
    auto beta = getInput(Beta);
    auto output = getOutput(Outputs);
    const TensorShape& inputShape = input->getShape();
    const TensorShape& kernelShape = mean->getShape();
    const TensorShape& outputShape = output->getShape();
    bool isPostConv = (input->ndims() == 4);
    dout(2) << *mean << "\n";
    dout(2) << *variance << "\n";
    dout(2) << *gamma << "\n";
    dout(2) << *beta << "\n";
    {
        auto stats = gem5::ScopedStats(
                stats::kTensorPrepStart, stats::kTensorPrepEnd);
        tiledTensors[0].copyDataToAllTiles();
        tiledTensors[1].copyDataToAllTiles();
    }
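    // A 4D input means this batch norm follows a convolution; a 2D input
    // means it follows a fully-connected layer.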
    if (isPostConv) {
        assert(inputShape.getLayout() == DataLayout::NHWC);
        assert(outputShape.getLayout() == DataLayout::NHWC);
        runNHWC(tiledTensors[0], tiledTensors[1], tiledTensors[2]);
    } else {
        assert(inputShape.getLayout() == DataLayout::NC);
        assert(outputShape.getLayout() == DataLayout::NC);
        runNA(tiledTensors[0], tiledTensors[1], tiledTensors[2]);
    }
    {
        auto stats = gem5::ScopedStats(
                stats::kTensorFinalStart, stats::kTensorFinalEnd);
        tiledTensors[2].untile();
    }
}

}  // namespace smaug