3 #include "smaug/core/backend.h"
5 #include "smaug/operators/smv/smv_inner_product_op.h"
6 #include "smaug/operators/smv/smv_inner_product_tiling.h"
7 #include "smaug/utility/debug_stream.h"
// NOTE(review): sparse excerpt. The enclosing function's signature and the
// call expressions these argument lists belong to are not visible here.
// The next three lines are the trailing arguments of three separate calls,
// one each for the inputs, weights and outputs shapes. All three pass the
// same maxTileSize; only the brace list differs per tensor (presumably a
// minimum tile shape derived from kNumPEs / kNumMaccsPerPE; TODO confirm
// against the callee).
18 inputs->getShape(), maxTileSize, { 1, kNumMaccsPerPE });
20 weights->getShape(), maxTileSize, { kNumPEs, kNumMaccsPerPE });
22 outputs->getShape(), maxTileSize, { 1, kNumPEs });
// If needsNwiseTiling() holds for the chosen weight tiling and the output
// tiling is anything other than None, the output tiling is overridden to
// DimNC.
29 if (needsNwiseTiling(bestWeightTilingDims) && bestOutputTilingDims != None)
30 bestOutputTilingDims = DimNC;
// The result lists the chosen dimensions in input, weight, output order.
32 return { bestInputTilingDims, bestWeightTilingDims, bestOutputTilingDims };
// NOTE(review): sparse excerpt. The enclosing function's signature, several
// declarations (e.g. inputsShape, weightsShape, outputsShape, config), some
// else-branches and most closing braces are not visible here. Comments below
// describe only what the visible lines do.
//
// Look up the operator's input, weight and output tensors.
36 Tensor* inputs = op->getInput(op->Inputs);
37 Tensor* weights = op->getInput(op->Weights);
38 Tensor* outputs = op->getOutput(op->Outputs);
// Tile budget: SmvBackend::SpadSize() divided by the size of one input
// element (presumably an element count; TODO confirm the units of SpadSize).
39 int maxTileSize = SmvBackend::SpadSize() / inputs->getDataTypeSize();
// Three tiling choices; the initializer is on lines not shown here.
40 std::array<TilingDims, 3> strategies =
// Level-1 debug output listing the chosen input/weight/output tiling dims.
46 dout(1) <<
" Tiling dimensions chosen:\n"
47 <<
" input: " << inputTilingDims
48 <<
", weight: " << weightTilingDims
49 <<
", output: " << outputTilingDims <<
"\n";
// Stage 1: collect candidate input tile shapes, branching on inputTilingDims
// (DimN, DimNC, and a remaining case whose condition is not visible that
// pushes inputsShape unchanged).
66 std::vector<TensorShape> inputConfigs;
67 if (inputTilingDims == DimN) {
70 { 1, inputsShape[1] },
73 }
else if (inputTilingDims == DimNC) {
74 std::vector<int> minShape = inputsShape.dims();
77 { 1, kNumMaccsPerPE },
78 { 1, kNumMaccsPerPE },
82 inputConfigs.push_back(inputsShape);
// At least one input candidate must exist.
84 assert(!inputConfigs.empty() &&
"No tiling configurations found!");
// Stage 2: for every input candidate, collect compatible weight tile shapes.
// A candidate is kept only if config.weights.storageSize() <= maxTileSize.
// NOTE(review): inputsShape is used inside this loop; it may be rebound per
// iteration on a line not shown. Verify against the full file.
87 std::list<TilingConfig> inputWeightConfigs;
88 for (
auto it = inputConfigs.begin(); it != inputConfigs.end(); ++it) {
// DimN weights: n starts at min(weightsShape[0], kNumPEs) and advances in
// steps of kNumPEs up to weightsShape[0].
90 if (weightTilingDims == DimN) {
91 int minOfmaps = std::min(weightsShape[0], kNumPEs);
92 for (
int n = minOfmaps; n <= weightsShape[0]; n += kNumPEs) {
95 inputsShape.getLayout(),
96 SmvBackend::Alignment);
97 if (config.weights.storageSize() <= maxTileSize) {
98 config.inputs = inputsShape;
99 inputWeightConfigs.push_back(config);
104 }
else if (weightTilingDims == DimNC) {
// DimNC weights: dimension 0 is stepped by kNumPEs as above. When
// needsCwiseTiling(inputTilingDims) holds, dimension 1 is taken from
// inputsShape[1]; the later loop steps dimension 1 by kNumMaccsPerPE
// (presumably the else-branch, which is not visible; TODO confirm).
105 int minNeurons = std::min(weightsShape[0], kNumPEs);
106 int minActs = std::min(weightsShape[1], kNumMaccsPerPE);
107 for (
int n = minNeurons; n <= weightsShape[0]; n += kNumPEs) {
109 config.weights = weightsShape;
110 config.weights[0] = n;
111 if (needsCwiseTiling(inputTilingDims)) {
114 config.weights[1] = inputsShape[1];
115 if (config.weights.storageSize() <= maxTileSize) {
116 config.inputs = inputsShape;
117 inputWeightConfigs.push_back(config);
124 for (
int c = minActs; c <= weightsShape[1];
125 c += kNumMaccsPerPE) {
126 config.weights[1] = c;
127 if (config.weights.storageSize() <= maxTileSize) {
128 config.inputs = inputsShape;
129 inputWeightConfigs.push_back(config);
// Remaining case (its condition is not visible): the full weightsShape is
// used, with dimension 1 replaced by inputsShape[1] when
// needsCwiseTiling(inputTilingDims) holds. This config is pushed without a
// size check on the visible lines.
138 config.inputs = inputsShape;
139 config.weights = weightsShape;
140 if (needsCwiseTiling(inputTilingDims)) {
144 config.weights[1] = inputsShape[1];
146 inputWeightConfigs.push_back(config);
// At least one input/weight candidate must exist.
149 assert(!inputWeightConfigs.empty() &&
"No tiling configurations found!");
// Stage 3: extend each input/weight candidate with an output tile shape.
// outputs[0] follows inputs[0]. outputs[1] follows weights[0] when both
// weights and outputs are tiled, or the loop variable c when only outputs
// are tiled. A config is kept only if its outputs fit within maxTileSize.
152 std::vector<TilingConfig> fullConfigs;
153 for (
auto it = inputWeightConfigs.begin(); it != inputWeightConfigs.end();
155 int minChannels = std::min(it->weights[0], kNumPEs);
156 bool weightsNeedTiling = (weightTilingDims != None);
157 bool outputsNeedTiling = (outputTilingDims != None);
158 for (
int c = minChannels; c <= weightsShape[0]; c += kNumPEs) {
160 config.outputs = outputsShape;
161 config.outputs[0] = config.inputs[0];
162 if (weightsNeedTiling && outputsNeedTiling) {
163 config.outputs[1] = config.weights[0];
164 }
else if (outputsNeedTiling) {
169 config.outputs[1] = c;
171 if (config.outputs.storageSize() <= maxTileSize) {
172 fullConfigs.push_back(config);
// NOTE(review): the statement guarded by this condition is not visible;
// presumably it leaves the inner loop early because c does not affect the
// output shape in these cases. TODO confirm.
176 if (weightsNeedTiling || outputTilingDims == None)
// Level-2 debug output: the candidate count, then each candidate.
180 dout(2) <<
" Number of possible tiling configs: " << fullConfigs.size()
182 for (
auto& config : fullConfigs)
183 dout(2) <<
" " << config <<
"\n";
// Stage 4: pick the candidate with the largest getTotalSize(). The assert
// guards against an empty candidate list, for which std::max_element
// returns the end iterator.
184 auto maxIt = std::max_element(
188 return c1.getTotalSize() < c2.getTotalSize();
190 assert(maxIt != fullConfigs.end() &&
"Failed to get best tiling config!");
// Record the tiling dimensions on the selected config.
192 maxIt->inputTilingDims = inputTilingDims;
193 maxIt->weightTilingDims = weightTilingDims;
194 maxIt->outputTilingDims = outputTilingDims;
// NOTE(review): sparse excerpt. The enclosing function's signature and the
// statements that build tiledInputs, tiledWeights and tiledOutputs are not
// visible here.
// Look up the operator's input, weight ("kernels") and output tensors via
// the SmvInnerProductOp index constants.
199 auto input = op->getInput(SmvInnerProductOp::Inputs);
200 auto kernels = op->getInput(SmvInnerProductOp::Weights);
201 auto output = op->getOutput(SmvInnerProductOp::Outputs);
// The result lists the tiled tensors in inputs, weights, outputs order.
211 return { tiledInputs, tiledWeights, tiledOutputs };