SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
split_op.h
1 #ifndef _OPERATORS_SPLIT_OP_H_
2 #define _OPERATORS_SPLIT_OP_H_
3 
4 #include <vector>
5 #include <initializer_list>
6 
7 #include "smaug/core/backend.h"
8 #include "smaug/core/operator.h"
10 
11 namespace smaug {
12 
20 template <typename Backend>
21 class SplitOp : public Operator {
22  public:
23  SplitOp(const std::string& name, Workspace* workspace)
24  : Operator(name, OpType::Split, workspace) {
25  inputs.resize(1, nullptr);
26  }
27 
28  SplitOp(const std::string& name,
29  Workspace* workspace,
30  const std::vector<int>& _splits,
31  int axis = 0)
32  : Operator(name, OpType::Split, workspace), splits(_splits),
33  splitAxis(axis) {
34  inputs.resize(1, nullptr);
35  outputs.resize(splits.size());
36  }
37 
39  void setSplits(const std::vector<int>& _splits) {
40  splits = _splits;
41  outputs.resize(splits.size());
42  }
43  void setSplits(const std::initializer_list<int>& _splits) {
44  splits = _splits;
45  outputs.resize(splits.size());
46  }
47 
49  void setSplitAxis(int axis) { splitAxis = axis; }
50 
51  const std::vector<int>& getSplits() const { return splits; }
52  int getSplitAxis() const { return splitAxis; }
53 
54  bool validate() override {
55  int splitSum = 0;
56  for (int i = 0; i < splits.size(); i++)
57  splitSum += splits[i];
58  return (splitSum == inputs.at(0)->dim(splitAxis) &&
60  }
61 
62  void createAllTensors() override {
63  std::vector<int> dims = getInput(0)->getShape().dims();
64  DataLayout layout = getInput(0)->getShape().getLayout();
65  for (int i = 0; i < splits.size(); i++) {
66  dims[splitAxis] = splits[i];
67  TensorShape shape(dims, layout, Backend::Alignment);
68  Tensor* output = new Tensor(name + std::to_string(i), shape);
69  workspace->addTensor(output);
70  outputs.at(i) = output;
71  }
72  }
73 
74  void run() override {
75  Tensor* input = getInput(0);
76  int ndims = input->ndims();
77  std::vector<int> srcOrigin(ndims, 0);
78  for (int i = 0; i < getOutputs().size(); i++) {
79  Tensor* output = getOutput(i);
80  copyTensorRegion(output,
81  input,
82  std::vector<int>(ndims, 0),
83  srcOrigin,
84  output->getShape().dims());
85  srcOrigin[splitAxis] += output->dim(splitAxis);
86  }
87  }
88 
89  protected:
90  int splitAxis;
91  std::vector<int> splits;
92 };
93 
94 } // namespace smaug
95 
96 #endif
smaug::SplitOp::setSplits
void setSplits(const std::vector< int > &_splits)
Set the size (along the split axis) of each split Tensor.
Definition: split_op.h:39
smaug::SplitOp::setSplitAxis
void setSplitAxis(int axis)
Set the axis along which to split the input Tensor.
Definition: split_op.h:49
smaug::copyTensorRegion
void copyTensorRegion(Tensor *dest, Tensor *src, std::vector< int > destOrigin, std::vector< int > srcOrigin, std::vector< int > regionSize)
Copies a region of a source Tensor to a corresponding region in a destination Tensor.
Definition: tensor_utils.cpp:65
tensor_utils.h
Utility functions for copying/printing/tiling tensors.
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
smaug::Operator::validate
virtual bool validate()
Returns true if the parameters/tensors of this operator are all valid.
Definition: operator.h:47