SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
reorder_op.h
1 #ifndef _CORE_REORDER_OP_H_
2 #define _CORE_REORDER_OP_H_
3 
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
6 #include "smaug/operators/reorder_op_impl.h"
7 
8 namespace smaug {
9 
17 template <typename Backend>
18 class ReorderOp : public Operator {
19  public:
20  ReorderOp(const std::string& name,
21  DataLayout _targetLayout,
22  Workspace* workspace)
23  : Operator(name, OpType::Reorder, workspace),
24  targetLayout(_targetLayout) {
25  inputs.resize(kNumInputs, nullptr);
26  outputs.resize(kNumOutputs, nullptr);
27  }
28 
29  ReorderOp(const std::string& name, Workspace* workspace)
30  : ReorderOp(name, DataLayout::UnknownLayout, workspace) {}
31 
32  DataLayout getTargetDataLayout() const { return targetLayout; }
33  void setTargetLayout(DataLayout layout) { targetLayout = layout; }
34 
35  void run() override {
36  auto stats = gem5::ScopedStats(
37  stats::kReorderingStart, stats::kReorderingEnd);
38  Tensor* input = getInput(Inputs);
39  Tensor* output = getOutput(Outputs);
40  DataLayout srcLayout = input->getShape().getLayout();
41  if (srcLayout == DataLayout::NCHW) {
42  if (targetLayout == DataLayout::NHWC) {
43  convertNchwToNhwc(input, output);
44  } else if (output->getShape().ndims() == 2) {
45  flatten(input, output);
46  }
47  } else if (srcLayout == DataLayout::NHWC) {
48  if (targetLayout == DataLayout::NCHW) {
49  convertNhwcToNchw(input, output);
50  } else if (output->getShape().ndims() == 2) {
51  flatten(input, output);
52  }
53  } else if (input->getShape().ndims() == 3) {
54  // NTC->NCT or NCT->NTC.
55  assert(srcLayout == DataLayout::NCT &&
56  targetLayout == DataLayout::NTC ||
57  srcLayout == DataLayout::NTC &&
58  targetLayout == DataLayout::NCT &&
59  "Only NCT->NTC or NCT->NTC is supported for 3D "
60  "reorderings!");
61  transpose3D(input, output);
62  } else if (input->getShape().ndims() == 2) {
63  if (srcLayout == targetLayout) {
64  return;
65  } else if (output->getShape().ndims() == 2) {
66  transpose2D(input, output);
67  } else {
68  std::cerr << "Data layout reordering from "
69  << DataLayout_Name(srcLayout) << " to "
70  << DataLayout_Name(targetLayout)
71  << " is not supported!\n";
72  exit(1);
73  }
74  }
75  }
76 
77  bool validate() override {
78  if (!Operator::validate())
79  return false;
80  DataLayout sourceLayout = inputs[Inputs]->getShape().getLayout();
81  if (sourceLayout == DataLayout::UnknownLayout) {
82  std::cerr << "[ERROR]: Reorder operation has unknown source "
83  "layout!\n";
84  return false;
85  }
86  if (targetLayout == DataLayout::UnknownLayout) {
87  std::cerr << "[ERROR]: Reorder operation has unknown target "
88  "layout!\n";
89  return false;
90  }
91  if (sourceLayout == targetLayout) {
92  std::cerr << "[ERROR]: Reorder operation does not change the data "
93  "layout!\n";
94  return false;
95  }
96  return true;
97  }
98 
99  TensorShape inferOutputShape() const {
100  TensorShape inputShape = getInput(Inputs)->getShape();
101  if (inputShape.ndims() == 4 && (targetLayout == DataLayout::NC ||
102  targetLayout == DataLayout::CN)) {
103  // Flatten a 4D tensor to 2D.
104  std::vector<int> dims(2, 1);
105  dims[0] = inputShape[0];
106  for (int i = 1; i < inputShape.ndims(); ++i) {
107  dims[1] *= inputShape[i];
108  }
109  return TensorShape(dims, targetLayout, Backend::Alignment);
110  } else if (targetLayout == DataLayout::NC ||
111  targetLayout == DataLayout::CN) {
112  // Transpose a 2D tensor.
113  return TensorShape({ inputShape[1], inputShape[0] }, targetLayout,
114  Backend::Alignment);
115  } else if (targetLayout == DataLayout::NCT ||
116  targetLayout == DataLayout::NTC) {
117  // Transpose a 3D tensor.
118  return TensorShape({ inputShape[0], inputShape[2], inputShape[1] },
119  targetLayout, Backend::Alignment);
120  } else if (targetLayout == DataLayout::NCHW) {
121  return TensorShape({ inputShape[0], inputShape[3], inputShape[1],
122  inputShape[2] },
123  targetLayout, Backend::Alignment);
124  } else if (targetLayout == DataLayout::NHWC) {
125  return TensorShape({ inputShape[0], inputShape[2], inputShape[3],
126  inputShape[1] },
127  targetLayout, Backend::Alignment);
128  }
129  return TensorShape();
130  }
131 
132  void createOutputTensors() {
133  assert(targetLayout != DataLayout::UnknownLayout &&
134  "Cannot create output tensor with unknown target data layout!");
135  TensorShape shape = inferOutputShape();
136  Tensor* output = new Tensor(name, shape);
137  workspace->addTensor(output);
138  outputs.at(Outputs) = output;
139  }
140  void createAllTensors() override {
141  createOutputTensors();
142  }
143 
144  protected:
145  enum { Inputs, kNumInputs };
146  enum { Outputs, kNumOutputs };
147  DataLayout targetLayout;
148 };
149 
157 template <typename Backend>
158 class FlattenOp : public ReorderOp<Backend> {
159  public:
160  typedef ReorderOp<Backend> Parent;
161 
162  FlattenOp(const std::string& name, Workspace* workspace)
163  : ReorderOp<Backend>(name, DataLayout::NC, workspace) {}
164 };
165 
166 } // namespace smaug
167 
168 #endif
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
smaug::Operator::validate
virtual bool validate()
Returns true if the parameters/tensors of this operator are all valid.
Definition: operator.h:47