3 #include "smaug/core/backend.h"
5 #include "smaug/operators/smv/smv_pooling_op.h"
6 #include "smaug/operators/smv/smv_pooling_tiling.h"
7 #include "smaug/utility/debug_stream.h"
17 std::pair<int, int> poolSize) {
23 { 1, poolSize.first, poolSize.second, kVectorSize });
25 outputs->getShape(), maxTileSize, { 1, 1, 1, kVectorSize });
32 if (needsHwiseTiling(bestInputTilingDims)) {
33 if (needsCwiseTiling(bestOutputTilingDims))
34 bestOutputTilingDims = DimNCH;
35 else if (needsWwiseTiling(bestOutputTilingDims))
36 bestOutputTilingDims = DimNHW;
38 bestOutputTilingDims = DimNH;
40 if (needsWwiseTiling(bestInputTilingDims)) {
41 if (needsCwiseTiling(bestOutputTilingDims))
42 bestOutputTilingDims = DimNCW;
43 else if (needsHwiseTiling(bestOutputTilingDims))
44 bestOutputTilingDims = DimNHW;
46 bestOutputTilingDims = DimNW;
49 return { bestInputTilingDims, bestOutputTilingDims };
53 Tensor* inputs = op->getInput(op->Inputs);
54 Tensor* outputs = op->getOutput(op->Outputs);
55 int maxTileSize = SmvBackend::SpadSize() / inputs->getDataTypeSize();
56 std::pair<int, int> poolSize = op->getPoolingSize();
57 std::pair<int, int> poolStride = op->getPoolingStride();
58 std::array<TilingDims, 2> strategies =
63 dout(2) <<
" Tiling dimensions chosen: \n"
64 <<
" input: " << inputTilingDims
65 <<
", output: " << outputTilingDims <<
"\n";
78 std::vector<TensorShape> inputConfigs;
79 if (inputTilingDims == DimN) {
80 std::vector<int> minShape = inputsShape.dims();
87 }
else if (inputTilingDims == DimNC) {
88 std::vector<int> minShape = inputsShape.dims();
90 minShape[3] = kVectorSize;
94 { 1, 1, 1, kVectorSize },
96 }
else if (inputTilingDims == DimNH) {
97 std::vector<int> minShape = inputsShape.dims();
99 minShape[1] = poolSize.first;
103 { 1, poolStride.first, 1, 1 },
105 }
else if (inputTilingDims == DimNW) {
106 std::vector<int> minShape = inputsShape.dims();
108 minShape[2] = poolSize.second;
112 { 1, 1, poolStride.second, 1 },
114 }
else if (inputTilingDims == DimNHW) {
115 std::vector<int> minShape = { 1, poolSize.first, poolSize.second,
117 std::vector<int> strides = { 1, poolStride.first, poolStride.second,
120 inputsShape, maxTileSize, minShape, strides, inputConfigs);
121 }
else if (inputTilingDims == DimNCH) {
122 std::vector<int> minShape = { 1, poolSize.first, inputsShape[2],
124 std::vector<int> strides = { 1, poolStride.first, 1, kVectorSize };
126 inputsShape, maxTileSize, minShape, strides, inputConfigs);
127 }
else if (inputTilingDims == DimNCW) {
128 std::vector<int> minShape = { 1, inputsShape[1], poolSize.second,
130 std::vector<int> strides = { 1, 1, poolStride.second, kVectorSize };
132 inputsShape, maxTileSize, minShape, strides, inputConfigs);
134 inputConfigs.push_back(inputsShape);
136 assert(!inputConfigs.empty() &&
"No tiling configurations found!");
139 std::vector<TilingConfig> fullConfigs;
140 for (
auto it = inputConfigs.begin(); it != inputConfigs.end(); ++it) {
142 config.outputs = outputsShape;
143 config.outputs[0] = config.inputs[0];
144 if (needsHwiseTiling(outputTilingDims)) {
145 config.outputs[1] = op->calcOutputRows(config.inputs[1]);
147 if (needsWwiseTiling(outputTilingDims)) {
148 config.outputs[2] = op->calcOutputCols(config.inputs[2]);
152 if (needsCwiseTiling(inputTilingDims) &&
153 needsCwiseTiling(outputTilingDims)) {
154 config.outputs[3] = config.inputs[3];
156 if (config.outputs.storageSize() <= maxTileSize) {
157 fullConfigs.push_back(config);
160 dout(2) <<
" Number of possible tiling configs: " << fullConfigs.size()
162 for (
auto& config : fullConfigs)
163 dout(2) <<
" " << config <<
"\n";
164 auto maxIt = std::max_element(
168 return c1.getTotalSize() < c2.getTotalSize();
170 assert(maxIt != fullConfigs.end() &&
"Failed to get best tiling config!");
172 maxIt->inputTilingDims = inputTilingDims;
173 maxIt->outputTilingDims = outputTilingDims;
177 std::array<TiledTensor, 2> TilingOptimizer::doTiling(
SmvPoolingOp* op) {
178 auto input = op->getInput(SmvPoolingOp::Inputs);
179 auto output = op->getOutput(SmvPoolingOp::Outputs);
181 int poolRowSize, poolColSize, poolRowStride, poolColStride;
182 std::tie(poolRowSize, poolColSize) = op->getPoolingSize();
183 std::tie(poolRowStride, poolColStride) = op->getPoolingStride();
195 output, tileConfig.outputs, op);
196 return { tiledInputs, tiledOutputs };