1 #ifndef _CORE_REORDER_OP_H_
2 #define _CORE_REORDER_OP_H_
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
6 #include "smaug/operators/reorder_op_impl.h"
17 template <
typename Backend>
18 class ReorderOp :
public Operator {
20 ReorderOp(
const std::string& name,
21 DataLayout _targetLayout,
23 : Operator(name, OpType::Reorder, workspace),
24 targetLayout(_targetLayout) {
25 inputs.resize(kNumInputs,
nullptr);
26 outputs.resize(kNumOutputs,
nullptr);
29 ReorderOp(
const std::string& name, Workspace* workspace)
30 : ReorderOp(name, DataLayout::UnknownLayout, workspace) {}
32 DataLayout getTargetDataLayout()
const {
return targetLayout; }
33 void setTargetLayout(DataLayout layout) { targetLayout = layout; }
36 auto stats = gem5::ScopedStats(
37 stats::kReorderingStart, stats::kReorderingEnd);
38 Tensor* input = getInput(Inputs);
39 Tensor* output = getOutput(Outputs);
40 DataLayout srcLayout = input->getShape().getLayout();
41 if (srcLayout == DataLayout::NCHW) {
42 if (targetLayout == DataLayout::NHWC) {
43 convertNchwToNhwc(input, output);
44 }
else if (output->getShape().ndims() == 2) {
45 flatten(input, output);
47 }
else if (srcLayout == DataLayout::NHWC) {
48 if (targetLayout == DataLayout::NCHW) {
49 convertNhwcToNchw(input, output);
50 }
else if (output->getShape().ndims() == 2) {
51 flatten(input, output);
53 }
else if (input->getShape().ndims() == 3) {
55 assert(srcLayout == DataLayout::NCT &&
56 targetLayout == DataLayout::NTC ||
57 srcLayout == DataLayout::NTC &&
58 targetLayout == DataLayout::NCT &&
59 "Only NCT->NTC or NCT->NTC is supported for 3D "
61 transpose3D(input, output);
62 }
else if (input->getShape().ndims() == 2) {
63 if (srcLayout == targetLayout) {
65 }
else if (output->getShape().ndims() == 2) {
66 transpose2D(input, output);
68 std::cerr <<
"Data layout reordering from "
69 << DataLayout_Name(srcLayout) <<
" to "
70 << DataLayout_Name(targetLayout)
71 <<
" is not supported!\n";
77 bool validate()
override {
80 DataLayout sourceLayout = inputs[Inputs]->getShape().getLayout();
81 if (sourceLayout == DataLayout::UnknownLayout) {
82 std::cerr <<
"[ERROR]: Reorder operation has unknown source "
86 if (targetLayout == DataLayout::UnknownLayout) {
87 std::cerr <<
"[ERROR]: Reorder operation has unknown target "
91 if (sourceLayout == targetLayout) {
92 std::cerr <<
"[ERROR]: Reorder operation does not change the data "
99 TensorShape inferOutputShape()
const {
100 TensorShape inputShape = getInput(Inputs)->getShape();
101 if (inputShape.ndims() == 4 && (targetLayout == DataLayout::NC ||
102 targetLayout == DataLayout::CN)) {
104 std::vector<int> dims(2, 1);
105 dims[0] = inputShape[0];
106 for (
int i = 1; i < inputShape.ndims(); ++i) {
107 dims[1] *= inputShape[i];
109 return TensorShape(dims, targetLayout, Backend::Alignment);
110 }
else if (targetLayout == DataLayout::NC ||
111 targetLayout == DataLayout::CN) {
113 return TensorShape({ inputShape[1], inputShape[0] }, targetLayout,
115 }
else if (targetLayout == DataLayout::NCT ||
116 targetLayout == DataLayout::NTC) {
118 return TensorShape({ inputShape[0], inputShape[2], inputShape[1] },
119 targetLayout, Backend::Alignment);
120 }
else if (targetLayout == DataLayout::NCHW) {
121 return TensorShape({ inputShape[0], inputShape[3], inputShape[1],
123 targetLayout, Backend::Alignment);
124 }
else if (targetLayout == DataLayout::NHWC) {
125 return TensorShape({ inputShape[0], inputShape[2], inputShape[3],
127 targetLayout, Backend::Alignment);
129 return TensorShape();
132 void createOutputTensors() {
133 assert(targetLayout != DataLayout::UnknownLayout &&
134 "Cannot create output tensor with unknown target data layout!");
135 TensorShape shape = inferOutputShape();
136 Tensor* output =
new Tensor(name, shape);
137 workspace->addTensor(output);
138 outputs.at(Outputs) = output;
140 void createAllTensors()
override {
141 createOutputTensors();
145 enum { Inputs, kNumInputs };
146 enum { Outputs, kNumOutputs };
147 DataLayout targetLayout;
157 template <
typename Backend>
158 class FlattenOp :
public ReorderOp<Backend> {
160 typedef ReorderOp<Backend> Parent;
162 FlattenOp(
const std::string& name, Workspace* workspace)
163 : ReorderOp<Backend>(name, DataLayout::NC, workspace) {}