1 #include "smaug/core/backend.h"
4 #include "smaug/operators/smv/smv_eltwise_add_op.h"
5 #include "smaug/operators/smv/smv_unary_op_common.h"
6 #include "smaug/operators/smv/smv_kernels.h"
7 #include "smaug/utility/debug_stream.h"
// Walks the tiles of two input tiled tensors and one output tiled tensor in
// lockstep (same index for all three) and hands each triple of tile buffers,
// viewed as float16, to the element-wise hardware block (smv::kEltwiseOpHw).
//
// NOTE(review): this view of the file is missing lines -- among them what
// looks like the `inputs1` parameter used below and the callee names in front
// of several argument lists. Comments here describe only the visible tokens.
12 void SmvEltwiseAddOp::runX(TiledTensor& inputs0,
14 TiledTensor& outputs) {
// A single index is used for all three tiled tensors in the loop below, so
// they must all report the same size().
15 assert(inputs0.size() == inputs1.size() &&
16 inputs0.size() == outputs.size());
// Three argument lists, one per host array name ("host_inputs0",
// "host_inputs1", "host_results"), each pairing the name with the memory type
// reported by getInputsMemType() / getOutputsMemType(). The function being
// called is not visible here -- TODO confirm against the full file.
18 smv::kEltwiseOpHw,
"host_inputs0", getInputsMemType());
20 smv::kEltwiseOpHw,
"host_inputs1", getInputsMemType());
22 smv::kEltwiseOpHw,
"host_results", getOutputsMemType());
// NOTE(review): `i` is a signed int compared against inputs0.size(); if
// size() returns an unsigned type this is a signed/unsigned comparison --
// confirm.
23 for (
int i = 0; i < inputs0.size(); i++) {
// Debug trace of the tile indices in use; all three are the same `i`.
24 dout(1) <<
"Input0: " << i <<
", input1: " << i <<
", output: " << i
// Inputs are fetched through getTileWithData(i); the output tile is taken by
// plain indexing.
26 Tensor* input0Tile = inputs0.getTileWithData(i);
27 Tensor* input1Tile = inputs1.getTileWithData(i);
28 Tensor* outputTile = outputs[i];
29 const TensorShape& inputShape = input0Tile->getShape();
30 const TensorShape& outputShape = outputTile->getShape();
// Each tile's float16 data pointer is paired with a byte count computed as
// storageSize() * sizeof(float16). Both input tiles use input0's shape for the
// size. The callee is not visible -- presumably a host-array-to-accelerator
// mapping call; verify.
32 input0Tile->data<float16>(),
33 inputShape.storageSize() *
sizeof(float16));
35 input1Tile->data<float16>(),
36 inputShape.storageSize() *
sizeof(float16));
38 outputTile->data<float16>(),
39 outputShape.storageSize() *
sizeof(float16));
// Final argument list for this tile: the three host buffers, the three
// scratchpads smv::spad0/spad1/spad2, and the element count taken from the
// input tile's storageSize(). Presumably the kernel invocation -- callee and
// kernel name are not visible here.
42 input0Tile->data<float16>(), input1Tile->data<float16>(),
43 outputTile->data<float16>(), smv::spad0, smv::spad1,
44 smv::spad2, inputShape.storageSize());
// Chooses one tile shape from the first input and applies the same shape,
// with the same trailing arguments, to both inputs and the output.
//
// NOTE(review): several lines are missing from this view (the variable that
// receives the std::min result, the callee in front of the three argument
// lists, and whatever stores their results).
48 void SmvEltwiseAddOp::tile() {
51 auto inputs0 = getInput(Input0);
52 auto inputs1 = getInput(Input1);
53 auto outputs = getOutput(Outputs);
// Tile size bound: the smaller of SmvBackend::SpadSize() divided by the
// element size of inputs0, and the total storage size of inputs0. Presumably
// assigned to `maxTileSize`, which is used just below -- the declaration is
// not visible here.
55 std::min(SmvBackend::SpadSize() / inputs0->getDataTypeSize(),
56 inputs0->getShape().storageSize());
// Tiles are described as 2-D { 1, maxTileSize } shapes in NC layout, built
// with SmvBackend::Alignment.
57 TensorShape tileShape(
58 { 1, maxTileSize }, DataLayout::NC, SmvBackend::Alignment);
// The same (tensor, tileShape, this, false) argument pattern is applied to
// inputs0, inputs1 and outputs. The callee and the meaning of the trailing
// `false` are not visible -- TODO confirm against the tiling helper's
// declaration.
60 inputs0, tileShape,
this,
false);
62 inputs1, tileShape,
this,
false);
64 outputs, tileShape,
this,
false);
// Operator entry point: checks shapes, fills the input tiles with data, then
// delegates the per-tile work to runX().
//
// NOTE(review): the statements inside the final stats scope and the closing
// lines of this function are not visible in this view.
67 void SmvEltwiseAddOp::run() {
68 auto inputs0 = getInput(Input0);
69 auto inputs1 = getInput(Input1);
70 auto outputs = getOutput(Outputs);
71 const TensorShape& inputs0Shape = inputs0->getShape();
72 const TensorShape& inputs1Shape = inputs1->getShape();
73 const TensorShape& outputsShape = outputs->getShape();
// Both inputs and the output must have identical shapes.
74 assert(inputs0Shape == inputs1Shape && inputs0Shape == outputsShape);
// Input preparation is wrapped in a ScopedStats object constructed with the
// kTensorPrepStart / kTensorPrepEnd markers.
77 auto stats = gem5::ScopedStats(
78 stats::kTensorPrepStart, stats::kTensorPrepEnd);
// tiledTensors[0] and [1] are passed to runX as its first two arguments
// below; copyDataToAllTiles() is called on each before that.
79 tiledTensors[0].copyDataToAllTiles();
80 tiledTensors[1].copyDataToAllTiles();
// tiledTensors[2] is passed in the position of runX's `outputs` parameter.
83 runX(tiledTensors[0], tiledTensors[1], tiledTensors[2]);
// A second ScopedStats, constructed with kTensorFinalStart / kTensorFinalEnd;
// the work it covers is not visible here.
86 auto stats = gem5::ScopedStats(
87 stats::kTensorFinalStart, stats::kTensorFinalEnd);