1 #include "smaug/core/backend.h" 
    3 #include "smaug/operators/smv/smv_pooling_op.h" 
    4 #include "smaug/operators/smv/smv_pooling_tiling.h" 
    5 #include "smaug/operators/smv/smv_kernels.h" 
    6 #include "smaug/utility/debug_stream.h" 
   // Integer constant with the fixed value 8.
   // NOTE(review): presumably the number of elements processed per vector
   // operation by the SMV pooling kernels -- confirm against the declaration
   // in smv_pooling_op.h and its users before relying on that meaning.
   12 const int kVectorSize = 8;
 
   25     int inputIfmapTiles = 
inputs.getShape()[0];
 
   26     int inputRowTiles = 
inputs.getShape()[1];
 
   27     int inputColTiles = 
inputs.getShape()[2];
 
   28     int inputChanTiles = 
inputs.getShape()[3];
 
   29     int outputChanTiles = 
outputs.getShape()[3];
 
   30     auto inputIdx = 
inputs.startIndex();
 
   31     auto outputIdx = 
outputs.startIndex();
 
   33             smv::kPoolingHw, 
"host_inputs", getInputsMemType());
 
   35             smv::kPoolingHw, 
"host_results", getOutputsMemType());
 
   36     for (
int N = 0; N < inputIfmapTiles; N++) {
 
   37         for (
int H = 0; H < inputRowTiles; H++) {
 
   38             for (
int W = 0; W < inputColTiles; W++) {
 
   42                 while (iC < inputChanTiles && oC < outputChanTiles) {
 
   43                     int inputTileIdx = inputIdx(N, H, W, iC);
 
   44                     int outputTileIdx = outputIdx(N, H, W, oC);
 
   49                     dout(1) << 
"Input: " << inputTileIdx
 
   50                             << 
", output: " << outputTileIdx << 
"\n";
 
   51                     Tensor* inputTile = 
inputs.getTileWithData(inputTileIdx);
 
   53                     const TensorShape& inputShape = inputTile->getShape();
 
   54                     const TensorShape& outputShape = outputTile->getShape();
 
   56                                     inputTile->
data<float16>(),
 
   57                                     inputShape.storageSize() * 
sizeof(float16));
 
   59                             smv::kPoolingHw, 
"host_results",
 
   60                             outputTile->
data<float16>(),
 
   61                             outputShape.storageSize() * 
sizeof(float16));
 
   62                     int inputDims[4] = { inputShape[0], inputShape[1],
 
   63                                          inputShape[2], inputShape[3] };
 
   64                     int outputDims[4] = { outputShape[0], outputShape[1],
 
   65                                           outputShape[2], outputShape[3] };
 
   71                     int ofmapStart = (iC == oC) ? 0 : ofmapOffset;
 
   77                             inputTile->
data<float16>(),
 
   78                             outputTile->
data<float16>(), smv::spad0, smv::spad1,
 
   79                             inputDims, outputDims, inputShape.getPadding(3),
 
   80                             outputShape.getPadding(3), getPoolingSize().first,
 
   81                             getPoolingSize().second, getPoolingStride().first,
 
   82                             getPoolingStride().second, ofmapStart, &sampling);
 
   84                     ofmapOffset += inputTile->getShape()[3];
 
   85                     if (inputChanTiles == outputChanTiles) {
 
   88                     } 
else if (outputChanTiles == 1) {
 
   92                                "The inputs/outputs tiles can have different " 
   93                                "number of channels only when the outputs don't " 
   94                                "need channelwise tiling.");
 
  102 void SmvPoolingOp::tile() {
 
  106     tiledTensors = smaug::smv::pool::TilingOptimizer::doTiling(
this);
 
  110     auto input = getInput(Inputs);
 
  111     auto output = getOutput(Outputs);
 
  113     const TensorShape& outputShape = output->getShape();
 
  114     assert(inputShape.getLayout() == DataLayout::NHWC);
 
  115     assert(outputShape.getLayout() == DataLayout::NHWC);
 
  119                 stats::kTensorPrepStart, stats::kTensorPrepEnd);
 
  120         tiledTensors[0].copyDataToAllTiles();
 
  123     runNHWC(tiledTensors[0], tiledTensors[1]);
 
  127                 stats::kTensorFinalStart, stats::kTensorFinalEnd);
 
  128         tiledTensors[1].untile();
 
  132 void SmvMaxPoolingOp::tile() { SmvPoolingOp::tile(); }
 
  134 void SmvAvgPoolingOp::tile() { SmvPoolingOp::tile(); }