SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
concat_op.h
1 #ifndef _OPERATORS_CONCAT_OP_H_
2 #define _OPERATORS_CONCAT_OP_H_
3 
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
7 
8 namespace smaug {
9 
17 template <typename Backend>
18 class ConcatOp : public Operator {
19  public:
20  ConcatOp(const std::string& name, Workspace* workspace)
21  : Operator(name, OpType::Concat, workspace) {
22  outputs.resize(1, nullptr);
23  }
24 
33  ConcatOp(const std::string& name,
34  Workspace* workspace,
35  int num,
36  int axis = 0)
37  : Operator(name, OpType::Concat, workspace), concatAxis(axis) {
38  inputs.resize(num);
39  outputs.resize(1, nullptr);
40  }
41 
43  void setNumInputs(int num) { inputs.resize(num); }
45  void setConcatAxis(int axis) { concatAxis = axis; }
46 
47  TensorShape inferOutputShape() const {
48  assert(getInputs().size() > 0 && "Unable to get inputs for concat op!");
49  std::vector<int> dims = getInput(0)->getShape().dims();
50  DataLayout layout = getInput(0)->getShape().getLayout();
51  int dim = 0;
52  for (int i = 0; i < getInputs().size(); i++) {
53  dim += getInput(i)->dim(concatAxis);
54  }
55  dims[concatAxis] = dim;
56  return TensorShape(dims, layout, Backend::Alignment);
57  }
58 
59  void createOutputTensor() {
60  TensorShape shape = inferOutputShape();
61  Tensor* output = new Tensor(name, shape);
62  workspace->addTensor(output);
63  outputs.at(0) = output;
64  }
65 
66  void createAllTensors() override{
67  createOutputTensor();
68  }
69 
70  void run() override {
71  Tensor* output = getOutput(0);
72  int ndims = output->ndims();
73  std::vector<int> dstOrigin(ndims, 0);
74  for (int i = 0; i < getInputs().size(); i++) {
75  Tensor* input = getInput(i);
76  copyTensorRegion(output,
77  input,
78  dstOrigin,
79  std::vector<int>(ndims, 0),
80  input->getShape().dims());
81  dstOrigin[concatAxis] += input->dim(concatAxis);
82  }
83  }
84 
85  int getConcatAxis() const { return concatAxis; }
86 
87  protected:
88  int concatAxis;
89 };
90 
91 } // namespace smaug
92 
93 #endif
smaug::copyTensorRegion
void copyTensorRegion(Tensor *dest, Tensor *src, std::vector< int > destOrigin, std::vector< int > srcOrigin, std::vector< int > regionSize)
Copies a region of a source Tensor to a corresponding region in a destination Tensor.
Definition: tensor_utils.cpp:65
smaug::ConcatOp::setNumInputs
void setNumInputs(int num)
Set the number of Tensors to concatenate.
Definition: concat_op.h:43
smaug::Workspace
Workspace is the container and owner of all Tensors and Operators in the Network.
Definition: workspace.h:17
tensor_utils.h
Utility functions for copying/printing/tiling tensors.
smaug::ConcatOp::ConcatOp
ConcatOp(const std::string &name, Workspace *workspace, int num, int axis=0)
Create a ConcatOp.
Definition: concat_op.h:33
smaug::ConcatOp::setConcatAxis
void setConcatAxis(int axis)
Set the axis along which to concatenate.
Definition: concat_op.h:45
smaug::TensorShape
TensorShape describes the shape of a Tensor.
Definition: tensor.h:35
smaug::Operator
Operator is the base class for all graph operators supported by SMAUG.
Definition: operator.h:28
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38