1 #include "smaug/core/backend.h"
4 #include "smaug/operators/smv/smv_tiling_base.h"
5 #include "smaug/utility/debug_stream.h"
13 const std::vector<int>& minShape) {
14 DataLayout layout = shape.getLayout();
15 assert(layout == NHWC || layout == NCHW || layout == NC);
16 if (shape.storageSize() <= maxTileSize)
17 return TilingDims::None;
18 int minN = std::min(shape[0], minShape[0]);
19 bool isNHWC = layout == NHWC;
20 int cIdx = isNHWC ? 3 : 1;
21 int minC = std::min(shape[cIdx], minShape[cIdx]);
22 int sizePerN = shape.storageSize() / shape[0];
23 if (sizePerN * minN <= maxTileSize)
24 return TilingDims::DimN;
25 if (sizePerN * (minC * 1.0 / shape[cIdx]) <= maxTileSize)
26 return TilingDims::DimNC;
27 if (shape.ndims() == 2) {
28 std::cerr <<
"[ERROR]: Unable to find a supported set of tiling "
29 "dimensions for 2D tensor with shape "
33 int hIdx = isNHWC ? 1 : 2;
34 int wIdx = isNHWC ? 2 : 3;
35 int minH = std::min(shape[hIdx], minShape[hIdx]);
36 int minW = std::min(shape[wIdx], minShape[wIdx]);
37 if (sizePerN * (minH * 1.0 / shape[hIdx]) <= maxTileSize)
38 return TilingDims::DimNH;
39 if (sizePerN * (minW * 1.0 / shape[wIdx]) <= maxTileSize)
40 return TilingDims::DimNW;
41 if (sizePerN * (minH * 1.0 / shape[hIdx]) * (minW * 1.0 / shape[wIdx]) <=
43 return TilingDims::DimNHW;
44 if (sizePerN * (minC * 1.0 / shape[cIdx]) * (minH * 1.0 / shape[hIdx]) <=
46 return TilingDims::DimNCH;
47 if (sizePerN * (minC * 1.0 / shape[cIdx]) * (minW * 1.0 / shape[wIdx]) <=
49 return TilingDims::DimNCW;
50 std::cerr <<
"[ERROR]: Unable to find a supported set of tiling dimensions "
51 "for 4D tensor with shape "
59 const std::vector<int>& minShape,
60 const std::vector<int>& strides,
61 std::vector<TensorShape>& configs) {
62 int minN = std::min(minShape[0], shape[0]);
63 int minC = std::min(minShape[1], shape[1]);
64 int strideN = strides[0];
65 int strideC = strides[1];
66 for (
int n = minN; n <= shape[0]; n += strideN) {
67 for (
int c = minC; c <= shape[1]; c += strideC) {
69 { n, c }, shape.getLayout(), shape.getAlignment());
70 if (config.storageSize() <= maxTileSize)
71 configs.push_back(config);
81 const std::vector<int>& minShape,
82 const std::vector<int>& strides,
83 std::vector<TensorShape>& configs) {
84 bool isNHWC = shape.getLayout() == NHWC;
85 int idxH = isNHWC ? 1 : 2;
86 int idxW = isNHWC ? 2 : 3;
87 int idxC = isNHWC ? 3 : 1;
88 int minN = std::min(minShape[0], shape[0]);
89 int minH = std::min(minShape[idxH], shape[idxH]);
90 int minW = std::min(minShape[idxW], shape[idxW]);
91 int minC = std::min(minShape[idxC], shape[idxC]);
92 int strideN = strides[0];
93 int strideH = strides[idxH];
94 int strideW = strides[idxW];
95 int strideC = strides[idxC];
96 for (
int n = minN; n <= shape[0]; n += strideN) {
97 for (
int c = minC; c <= shape[idxC]; c += strideC) {
98 for (
int h = minH; h <= shape[idxH]; h += strideH) {
99 for (
int w = minW; w <= shape[idxW]; w += strideW) {
104 shape.getAlignment());
108 shape.getAlignment());
110 if (config.storageSize() <= maxTileSize)
111 configs.push_back(config);