1 #include "smaug/operators/smv/smv_eltwise_mul_op.h"
2 #include "smaug/core/backend.h"
4 #include "smaug/operators/smv/smv_kernels.h"
5 #include "smaug/operators/smv/smv_unary_op_common.h"
6 #include "smaug/utility/debug_stream.h"
// Tile dispatcher for the SMV elementwise multiplication. For each tile
// index i, it pairs tile i of inputs0 and inputs1 with tile i of outputs and
// passes their float16 buffers, the scratchpads smv::spad0..spad2 and the
// tile's element count to the hardware block smv::kEltwiseOpHw.
//
// NOTE(review): this copy of the file appears to have lost lines. The
// original line numbers embedded at the start of each line skip values.
// The second parameter is missing: `inputs1` is used below, but its
// declaration is not visible. The call heads preceding each
// `smv::kEltwiseOpHw,` argument are also missing. Restore them from
// upstream before building.
11 void SmvEltwiseMulOp::runX(TiledTensor& inputs0,
13 TiledTensor& outputs) {
// Tile i of every tensor is processed together, so all three tiled tensors
// must have been split into the same number of tiles.
14 assert(inputs0.size() == inputs1.size() &&
15 inputs0.size() == outputs.size());
// For each host array name used by the accelerator ("host_inputs0",
// "host_inputs1", "host_results"), pass along this op's input/output
// memory type. NOTE(review): the callee on each missing line above is not
// visible. Presumably it sets the array's memory type for simulation;
// confirm against upstream.
17 smv::kEltwiseOpHw,
"host_inputs0", getInputsMemType());
19 smv::kEltwiseOpHw,
"host_inputs1", getInputsMemType());
21 smv::kEltwiseOpHw,
"host_results", getOutputsMemType());
// NOTE(review): if TiledTensor::size() returns an unsigned type, this
// `int` index mixes signedness in the comparison. Verify.
22 for (
int i = 0; i < inputs0.size(); i++) {
23 dout(1) <<
"Input0: " << i <<
", input1: " << i <<
", output: " << i
// Input tiles are fetched with getTileWithData(), while the output tile is
// indexed directly. Presumably getTileWithData() returns a tile whose data
// is populated; run() in this file calls copyDataToAllTiles() on both input
// tiled tensors before calling runX. Confirm.
25 Tensor* input0Tile = inputs0.getTileWithData(i);
26 Tensor* input1Tile = inputs1.getTileWithData(i);
27 Tensor* outputTile = outputs[i];
28 const TensorShape& inputShape = input0Tile->getShape();
29 const TensorShape& outputShape = outputTile->getShape();
// Each of the three buffers below is described by its float16 data pointer
// and a byte count of storageSize() * sizeof(float16). Both inputs use
// input0's tile shape for the size. The call heads (and any array-name
// arguments) for these three statements are not visible here.
31 input0Tile->data<float16>(),
32 inputShape.storageSize() *
sizeof(float16));
34 input1Tile->data<float16>(),
35 inputShape.storageSize() *
sizeof(float16));
37 outputTile->data<float16>(),
38 outputShape.storageSize() *
sizeof(float16));
// Kernel call (head line not visible): input0, input1 and output buffers,
// the scratchpads smv::spad0/spad1/spad2, and the element count. The count
// is taken from the input tile shape. run() asserts that the untiled input
// and output shapes are equal, and this presumes their tiles match too.
41 input0Tile->data<float16>(), input1Tile->data<float16>(),
42 outputTile->data<float16>(), smv::spad0, smv::spad1,
43 smv::spad2, inputShape.storageSize());
// Chooses the tile shape used to split the operands of the SMV elementwise
// multiply. Tiles are flat {1, maxTileSize} shapes in NC layout, aligned to
// SmvBackend::Alignment. maxTileSize is the smaller of:
//   - SmvBackend::SpadSize() divided by inputs0's element size (presumably
//     the scratchpad capacity in elements), and
//   - the total storage size of inputs0.
// So when the whole tensor fits, it becomes a single tile.
//
// NOTE(review): lines are missing from this copy. Not visible here are:
//   - the declaration that receives the std::min(...) result (presumably
//     `maxTileSize`, which is used below), and
//   - the remainder of the body after `tileShape`, where inputs1 and
//     outputs are presumably tiled.
// Restore them from upstream.
47 void SmvEltwiseMulOp::tile() {
50 auto inputs0 = getInput(Input0);
51 auto inputs1 = getInput(Input1);
52 auto outputs = getOutput(Outputs);
54 std::min(SmvBackend::SpadSize() / inputs0->getDataTypeSize(),
55 inputs0->getShape().storageSize());
56 TensorShape tileShape(
57 { 1, maxTileSize }, DataLayout::NC, SmvBackend::Alignment);
66 void SmvEltwiseMulOp::run() {
67 auto inputs0 = getInput(Input0);
68 auto inputs1 = getInput(Input1);
69 auto outputs = getOutput(Outputs);
70 const TensorShape& inputs0Shape = inputs0->getShape();
71 const TensorShape& inputs1Shape = inputs1->getShape();
72 const TensorShape& outputsShape = outputs->getShape();
73 assert(inputs0Shape == inputs1Shape && inputs0Shape == outputsShape);
76 auto stats = gem5::ScopedStats(
77 stats::kTensorPrepStart, stats::kTensorPrepEnd);
78 tiledTensors[0].copyDataToAllTiles();
79 tiledTensors[1].copyDataToAllTiles();
82 runX(tiledTensors[0], tiledTensors[1], tiledTensors[2]);
85 auto stats = gem5::ScopedStats(
86 stats::kTensorFinalStart, stats::kTensorFinalEnd);