3 #include "smaug/core/backend.h"
5 #include "smaug/operators/smv/smv_batch_norm_op.h"
6 #include "smaug/operators/smv/smv_batch_norm_tiling.h"
7 #include "smaug/utility/debug_stream.h"
20 if (inputShape.ndims() == 4) {
24 { 1, kVectorSize, kVectorSize, kVectorSize });
30 weights->getShape(), maxTileSize, { 4, kVectorSize });
32 return { bestInputTilingDims, bestWeightTilingDims };
// Enumerates candidate tiling configurations for the two-strategy case where
// the input tiling is restricted to None, DimN or DimNC and the weight tiling
// to None or DimNC (see the asserts below). Every accepted candidate is
// appended to `fullConfigs`; the function asserts that at least one was found.
//
// NOTE(review): this excerpt is missing lines (the remaining parameters, the
// DimNC input branch body, the inner loop increment and the closing braces
// are not visible), so the comments below only describe what is shown.
// Judging by the name, this presumably handles batch norm that follows a
// fully-connected layer - confirm against the header.
//
// strategies:  pair of tiling strategies; presumably unpacked into
//              inputTilingDims / weightTilingDims on lines not shown here.
// fullConfigs: output list that receives every viable TilingConfig.
35 void TilingOptimizer::enumPostFCTilingConfigs(
39 std::array<TilingDims, 2> strategies,
40 std::list<TilingConfig>& fullConfigs) {
// Only these strategy combinations are handled by this function.
52 assert(inputTilingDims == None || inputTilingDims == DimN ||
53 inputTilingDims == DimNC);
54 assert(weightTilingDims == None || weightTilingDims == DimNC);
// Step 1: collect the candidate input tile shapes.
55 std::vector<TensorShape> inputsConfigs;
56 if (inputTilingDims == DimN) {
// The minimum tile shape starts out as the full input dims; any adjustment
// and the callee that fills inputsConfigs are on lines not shown here.
57 std::vector<int> minShape = inputsShape.dims();
60 inputsShape, maxTileSize, minShape, { 1, 1 }, inputsConfigs);
61 }
else if (inputTilingDims == DimNC) {
// NOTE(review): the DimNC body is not visible. The push_back below presumably
// belongs to a trailing else (no input tiling) that keeps the whole input
// shape as the only candidate - confirm against the full file.
68 inputsConfigs.push_back(inputsShape);
70 assert(!inputsConfigs.empty() &&
"No tiling configurations found!");
// Step 2: pair every input candidate with a weight tile shape.
73 for (
auto it = inputsConfigs.begin(); it != inputsConfigs.end(); ++it) {
74 TensorShape& inputsConfig = *it;
75 if (weightTilingDims == DimNC) {
76 if (needsCwiseTiling(inputTilingDims)) {
// Index 1 of the weight tile is set to index 1 of the input tile, and the
// candidate is kept only if the weight tile fits in maxTileSize.
80 config.weights = weightsShape;
81 config.weights[1] = inputsConfig[1];
82 if (config.weights.storageSize() <= maxTileSize) {
// Outputs use the same tile shape as inputs.
83 config.inputs = inputsConfig;
84 config.outputs = inputsConfig;
85 fullConfigs.push_back(config);
// Presumably the else-branch of the needsCwiseTiling() check: try weight
// extents along index 1 from min(weightsShape[1], kVectorSize) up to
// weightsShape[1]. The loop increment is not visible in this excerpt.
90 int minChannels = std::min(weightsShape[1], kVectorSize);
91 for (
int c = minChannels; c <= weightsShape[1];
94 config.weights = weightsShape;
95 config.weights[1] = c;
// Keep the candidate only if the weight tile fits in maxTileSize.
96 if (config.weights.storageSize() <= maxTileSize) {
97 config.inputs = inputsConfig;
98 config.outputs = inputsConfig;
99 fullConfigs.push_back(config);
// Presumably the branch for untiled weights: the full weight shape is used
// and the candidate is added without a size check in the visible code.
106 TilingConfig config(inputsConfig, weightsShape, inputsConfig);
107 fullConfigs.push_back(config);
// At least one configuration must have been produced.
110 assert(!fullConfigs.empty() &&
"No tiling configurations found!");
// Enumerates candidate tiling configurations for 4-index input shapes (the
// minimum shapes below have four entries). Weights are never tiled here (see
// the assert on weightTilingDims), so each candidate input tile shape is
// combined with the full weight shape, and the output tile shape equals the
// input tile shape. Candidates are appended to `fullConfigs`; the function
// asserts that at least one was found.
//
// NOTE(review): this excerpt is missing lines (one parameter, the calls that
// fill inputsConfigs, and the closing braces are not visible). Judging by the
// name, this presumably handles batch norm that follows a convolution -
// confirm against the header.
//
// inputsShape:  shape of the input tensor.
// weightsShape: shape of the weights; used untiled in every candidate.
// strategies:   pair of tiling strategies; presumably unpacked into
//               inputTilingDims / weightTilingDims on lines not shown here.
// fullConfigs:  output list that receives every TilingConfig.
113 void TilingOptimizer::enumPostConvTilingConfigs(
114 TensorShape inputsShape,
115 TensorShape weightsShape,
117 std::array<TilingDims, 2> strategies,
118 std::list<TilingConfig>& fullConfigs) {
// Only these input strategies are handled; weight tiling must be None.
129 assert(inputTilingDims == None || inputTilingDims == DimN ||
130 inputTilingDims == DimNC || inputTilingDims == DimNH ||
131 inputTilingDims == DimNW || inputTilingDims == DimNHW ||
132 inputTilingDims == DimNCH || inputTilingDims == DimNCW);
133 assert(weightTilingDims == None);
// Step 1: collect candidate input tile shapes. Each branch prepares a minimum
// tile shape plus a four-entry brace list; the call that consumes them is not
// visible. NOTE(review): the brace lists are presumably per-index step sizes
// for the enumeration - confirm against the callee.
134 std::vector<TensorShape> inputsConfigs;
135 if (inputTilingDims == DimN) {
136 std::vector<int> minShape = inputsShape.dims();
143 }
else if (inputTilingDims == DimNC) {
// Index 3 of the minimum shape is set to kVectorSize.
144 std::vector<int> minShape = inputsShape.dims();
146 minShape[3] = kVectorSize;
150 { 1, 1, 1, kVectorSize },
152 }
else if (inputTilingDims == DimNH) {
// Index 1 of the minimum shape is set to kVectorSize.
153 std::vector<int> minShape = inputsShape.dims();
155 minShape[1] = kVectorSize;
159 { 1, kVectorSize, 1, 1 },
161 }
else if (inputTilingDims == DimNW) {
// Index 2 of the minimum shape is set to kVectorSize.
162 std::vector<int> minShape = inputsShape.dims();
164 minShape[2] = kVectorSize;
168 { 1, 1, kVectorSize, 1 },
170 }
else if (inputTilingDims == DimNHW) {
// The minimum shape starts { 1, kVectorSize, kVectorSize, ... }; the last
// entry is on a line not shown here.
171 std::vector<int> minShape = { 1, kVectorSize, kVectorSize,
176 { 1, kVectorSize, kVectorSize, 1 },
178 }
else if (inputTilingDims == DimNCH) {
// Index 2 of the minimum shape keeps the full input extent.
179 std::vector<int> minShape = { 1, kVectorSize, inputsShape[2],
184 { 1, kVectorSize, 1, kVectorSize },
186 }
else if (inputTilingDims == DimNCW) {
// Index 1 of the minimum shape keeps the full input extent.
187 std::vector<int> minShape = { 1, inputsShape[1], kVectorSize,
192 { 1, 1, kVectorSize, kVectorSize },
// NOTE(review): presumably the trailing else (no input tiling), which keeps
// the whole input shape as the only candidate - confirm against the full file.
195 inputsConfigs.push_back(inputsShape);
197 assert(!inputsConfigs.empty() &&
"No tiling configurations found!");
// Step 2: every input candidate becomes a config with the full weight shape
// and an output tile shape identical to the input tile shape.
200 for (
auto it = inputsConfigs.begin(); it != inputsConfigs.end(); ++it) {
201 TilingConfig config(*it, weightsShape, *it);
202 fullConfigs.push_back(config);
// At least one configuration must have been produced.
204 assert(!fullConfigs.empty() &&
"No tiling configurations found!");
210 int maxTileSize = SmvBackend::SpadSize() / inputs->getDataTypeSize();
212 assert(inputs->getShape() == outputs->getShape());
213 std::array<TilingDims, 2> strategies =
217 TilingDims outputTilingDims = inputTilingDims;
219 dout(2) <<
" Tiling dimensions chosen: \n"
220 <<
" input: " << inputTilingDims
221 <<
", weight: " << weightTilingDims
222 <<
", output: " << inputTilingDims <<
"\n";
226 std::list<TilingConfig> fullConfigs;
227 bool isPostConv = (inputs->ndims() == 4);
229 enumPostConvTilingConfigs(inputsShape,
235 enumPostFCTilingConfigs(inputsShape,
242 dout(2) <<
" Number of possible tiling configs: " << fullConfigs.size()
244 for (
auto& config : fullConfigs)
245 dout(2) <<
" " << config <<
"\n";
246 auto maxIt = std::max_element(
250 return c1.getTotalSize() < c2.getTotalSize();
252 assert(maxIt != fullConfigs.end() &&
"Failed to get best tiling config!");
254 (*maxIt).inputTilingDims = inputTilingDims;
255 (*maxIt).weightTilingDims = weightTilingDims;
256 (*maxIt).outputTilingDims = outputTilingDims;
261 auto inputs = op->getInput(SmvBatchNormOp::Inputs);
262 auto mean = op->getInput(SmvBatchNormOp::Mean);
263 auto variance = op->getInput(SmvBatchNormOp::Variance);
264 auto gamma = op->getInput(SmvBatchNormOp::Gamma);
265 auto beta = op->getInput(SmvBatchNormOp::Beta);
268 { mean, variance, gamma, beta }, 0, op->getWorkspace());
269 auto outputs = op->getOutput(SmvBatchNormOp::Outputs);
280 return { tiledInputs, tiledWeights, tiledOutputs };