SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
repeat_op.h
1 #ifndef _OPERATORS_REPEAT_OP_H_
2 #define _OPERATORS_REPEAT_OP_H_
3 
4 #include <vector>
5 #include <initializer_list>
6 
7 #include "smaug/core/backend.h"
8 #include "smaug/core/operator.h"
10 
11 namespace smaug {
12 
23 template <typename Backend>
24 class RepeatOp : public Operator {
25  public:
26  RepeatOp(const std::string& name, Workspace* workspace)
27  : Operator(name, OpType::Repeat, workspace) {
28  inputs.resize(1, nullptr);
29  outputs.resize(1, nullptr);
30  }
31 
32  RepeatOp(const std::string& name,
33  Workspace* workspace,
34  const std::vector<int> _multiples)
35  : Operator(name, OpType::Repeat, workspace), multiples(_multiples) {
36  inputs.resize(1, nullptr);
37  outputs.resize(1, nullptr);
38  }
39 
41  void setMultiples(const std::vector<int>& _multiples) {
42  multiples = _multiples;
43  }
44 
46  void setMultiples(const std::initializer_list<int>& _multiples) {
47  multiples = _multiples;
48  }
49 
50  bool validate() override {
51  for (int multiple : multiples) {
52  if (multiple == 0)
53  return false;
54  }
55  return Operator::validate();
56  }
57 
58  void createAllTensors() override {
59  Tensor* input = getInput(0);
60  std::vector<int> dims = input->getShape().dims();
61  for (int i = 0; i < multiples.size(); i++)
62  dims[i] *= multiples[i];
63  TensorShape shape(
64  dims, input->getShape().getLayout(), Backend::Alignment);
65  Tensor* output = new Tensor(name, shape);
66  workspace->addTensor(output);
67  outputs.at(0) = output;
68  }
69 
70  void run() override {
71  Tensor* input = getInput(0);
72  Tensor* output = getOutput(0);
73  int ndims = input->ndims();
74  std::vector<int> inputDims = input->getShape().dims();
75  std::vector<int> outputDims = output->getShape().dims();
76  std::vector<int> srcOrigin = std::vector<int>(ndims, 0);
77  // Copy the first piece of input into output.
78  copyTensorRegion(output, input, srcOrigin, srcOrigin, inputDims);
79  for (int i = ndims - 1; i >= 0; i--) {
80  std::vector<int> currCopyRegion = inputDims;
81  for (int j = i + 1; j < ndims; j++)
82  currCopyRegion[j] = outputDims[j];
83  std::vector<int> dstOrigin(ndims, 0);
84  dstOrigin[i] = inputDims[i];
85  while (dstOrigin[i] + currCopyRegion[i] <= outputDims[i]) {
87  output, output, dstOrigin, srcOrigin, currCopyRegion);
88  dstOrigin[i] += currCopyRegion[i];
89  // Double the copy size for the next iteration.
90  currCopyRegion[i] *= 2;
91  }
92  // Copy the remaining part if there's any.
93  if (dstOrigin[i] < outputDims[i]) {
94  currCopyRegion[i] = outputDims[i] - dstOrigin[i];
96  output, output, dstOrigin, srcOrigin, currCopyRegion);
97  }
98  }
99  }
100 
101  protected:
102  std::vector<int> multiples;
103 };
104 
105 } // namespace smaug
106 
107 #endif
smaug::copyTensorRegion
void copyTensorRegion(Tensor *dest, Tensor *src, std::vector< int > destOrigin, std::vector< int > srcOrigin, std::vector< int > regionSize)
Copies a region of a source Tensor to a corresponding region in a destination Tensor.
Definition: tensor_utils.cpp:65
tensor_utils.h
Utility functions for copying/printing/tiling tensors.
smaug::RepeatOp::setMultiples
void setMultiples(const std::initializer_list< int > &_multiples)
Set the number of copies of the Tensor along each dimension.
Definition: repeat_op.h:46
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
smaug::Operator::validate
virtual bool validate()
Returns true if the parameters/tensors of this operator are all valid.
Definition: operator.h:47
smaug::RepeatOp::setMultiples
void setMultiples(const std::vector< int > &_multiples)
Set the number of copies of the Tensor along each dimension.
Definition: repeat_op.h:41