1 #ifndef _OPERATORS_REPEAT_OP_H_
2 #define _OPERATORS_REPEAT_OP_H_
5 #include <initializer_list>
7 #include "smaug/core/backend.h"
8 #include "smaug/core/operator.h"
// RepeatOp: an operator that tiles its single input tensor. The output's
// size along dimension i is the input's size along i times multiples[i].
// Templated on a Backend type, which supplies the tensor Alignment used when
// the output tensor is created.
// NOTE(review): this view of the file is missing lines (the embedded
// original line numbers skip, and closing braces / some statements are
// absent), so it is not compilable as shown.
23 template <
typename Backend>
24 class RepeatOp :
public Operator {
// Constructs a RepeatOp by forwarding `name`, OpType::Repeat and `workspace`
// to the Operator base. It creates exactly one input slot and one output
// slot, both initialized to nullptr.
26 RepeatOp(
const std::string& name, Workspace* workspace)
27 : Operator(name, OpType::Repeat, workspace) {
28 inputs.resize(1,
nullptr);
29 outputs.resize(1,
nullptr);
// Constructs a RepeatOp and stores the given per-dimension repeat counts in
// the `multiples` member. It creates one input slot and one output slot,
// both initialized to nullptr.
// NOTE(review): `workspace` is passed to the Operator base, but its
// parameter declaration is not visible in this view (original line 33 is
// missing). Confirm the signature is (name, workspace, _multiples).
// NOTE(review): `_multiples` is taken by const value. It is copied into the
// parameter and then copied again into the member.
32 RepeatOp(
const std::string& name,
34 const std::vector<int> _multiples)
35 : Operator(name, OpType::Repeat, workspace), multiples(_multiples) {
36 inputs.resize(1,
nullptr);
37 outputs.resize(1,
nullptr);
// Assignments of a caller-supplied `_multiples` to the `multiples` member.
// NOTE(review): the enclosing function signatures are not visible in this
// view. These are presumably setter overloads (e.g. one taking
// std::vector<int> and one taking std::initializer_list<int>, given the
// <initializer_list> include) — TODO confirm.
42 multiples = _multiples;
47 multiples = _multiples;
// Validates the operator by iterating over every repeat count in
// `multiples`. The loop body and the return statement are on lines not
// visible in this view; presumably non-positive counts are rejected — TODO
// confirm.
50 bool validate()
override {
51 for (
int multiple : multiples) {
// Creates the output tensor. Its dims are the input's dims with dims[i]
// multiplied by multiples[i] for each repeat count. It keeps the input's
// layout and uses Backend::Alignment. The new tensor is added to the
// workspace and stored as output 0 (ownership presumably passes to the
// workspace — TODO confirm).
// NOTE(review): `i < multiples.size()` compares an int with an unsigned
// size. Nothing here checks multiples.size() <= dims.size(), so a longer
// `multiples` would index past the end of `dims`.
58 void createAllTensors()
override {
59 Tensor* input = getInput(0);
60 std::vector<int> dims = input->getShape().dims();
// Scale each leading dimension by its repeat count.
61 for (
int i = 0; i < multiples.size(); i++)
62 dims[i] *= multiples[i];
// NOTE(review): the opening of the `shape` declaration (original line 63)
// is missing; the line below holds its arguments (dims, layout, alignment).
64 dims, input->getShape().getLayout(), Backend::Alignment);
65 Tensor* output =
new Tensor(name, shape);
66 workspace->addTensor(output);
67 outputs.at(0) = output;
// Body of the operator's execution routine. Its signature (original line 70)
// is not visible here; presumably it is run().
// Fills the output by repeatedly copying already-written output data to
// later positions. Dimensions are processed from innermost (ndims - 1) to
// outermost (0).
// NOTE(review): every copy below reads from `output`, so the input must
// already be placed at the output's origin before the loop. That step
// (original lines 77-78) is not visible here — confirm.
71 Tensor* input = getInput(0);
72 Tensor* output = getOutput(0);
73 int ndims = input->ndims();
74 std::vector<int> inputDims = input->getShape().dims();
75 std::vector<int> outputDims = output->getShape().dims();
76 std::vector<int> srcOrigin = std::vector<int>(ndims, 0);
// When dimension i is processed, the dimensions after i have already been
// expanded to their full output extent by earlier iterations. So the copy
// region uses outputDims for j > i and inputDims for the rest.
79 for (
int i = ndims - 1; i >= 0; i--) {
80 std::vector<int> currCopyRegion = inputDims;
81 for (
int j = i + 1; j < ndims; j++)
82 currCopyRegion[j] = outputDims[j];
// The destination along i starts right after the first input-sized block.
83 std::vector<int> dstOrigin(ndims, 0);
84 dstOrigin[i] = inputDims[i];
// Doubling copy. At the top of each pass, currCopyRegion[i] equals
// dstOrigin[i]. Each pass copies the filled prefix [0, dstOrigin[i]) to
// [dstOrigin[i], 2 * dstOrigin[i]) along i, so the number of passes grows
// logarithmically with multiples[i].
85 while (dstOrigin[i] + currCopyRegion[i] <= outputDims[i]) {
// NOTE(review): the callee of this copy (original line 86) is not visible.
// The arguments appear to be (dst tensor, src tensor, dst origin,
// src origin, region size) — presumably.
87 output, output, dstOrigin, srcOrigin, currCopyRegion);
88 dstOrigin[i] += currCopyRegion[i];
90 currCopyRegion[i] *= 2;
// After doubling, the filled extent along i is inputDims[i] * 2^k. If that
// is still short of the output extent (multiples[i] is not a power of two),
// copy the remaining tail from the start of the already-filled data.
93 if (dstOrigin[i] < outputDims[i]) {
94 currCopyRegion[i] = outputDims[i] - dstOrigin[i];
96 output, output, dstOrigin, srcOrigin, currCopyRegion);
// Per-dimension repeat counts: output dim i = input dim i * multiples[i].
102 std::vector<int> multiples;