SMAUG — Simulating Machine Learning Applications on gem5-Aladdin
File: reorder_op_impl.h
1 #include "smaug/core/tensor.h"
3 
4 namespace smaug {
5 
6 template <typename DType>
7 void convertNchwToNhwcImpl(Tensor* input, Tensor* output) {
8  TensorIndexIterator inputIdx = input->startIndex();
9  TensorIndexIterator outputIdx = output->startIndex();
10  const TensorShape& inputShape = input->getShape();
11  DType* inputData = input->template data<DType>();
12  DType* outputData = output->template data<DType>();
13  for (int n = 0; n < inputShape[0]; n++) {
14  for (int c = 0; c < inputShape[1]; c++) {
15  for (int h = 0; h < inputShape[2]; h++) {
16  for (int w = 0; w < inputShape[3]; w++) {
17  outputData[outputIdx(n, h, w, c)] =
18  inputData[inputIdx(n, c, h, w)];
19  }
20  }
21  }
22  }
23 }
24 
25 template <typename DType>
26 void convertNhwcToNchwImpl(Tensor* input, Tensor* output) {
27  TensorIndexIterator inputIdx = input->startIndex();
28  TensorIndexIterator outputIdx = output->startIndex();
29  const TensorShape& inputShape = input->getShape();
30  DType* inputData = input->template data<DType>();
31  DType* outputData = output->template data<DType>();
32  for (int n = 0; n < inputShape[0]; n++) {
33  for (int h = 0; h < inputShape[1]; h++) {
34  for (int w = 0; w < inputShape[2]; w++) {
35  for (int c = 0; c < inputShape[3]; c++) {
36  outputData[outputIdx(n, c, h, w)] =
37  inputData[inputIdx(n, h, w, c)];
38  }
39  }
40  }
41  }
42 }
43 
44 template <typename DType>
45 void flattenImpl(Tensor* input, Tensor* output) {
46  TensorIndexIterator inputIdx = input->startIndex();
47  TensorIndexIterator outputIdx = output->startIndex();
48  const TensorShape& inputShape = input->getShape();
49  const TensorShape& outputShape = output->getShape();
50  DType* inputData = input->template data<DType>();
51  DType* outputData = output->template data<DType>();
52  bool targetNC = outputShape.getLayout() == NC;
53  for (int n = 0; n < inputShape[0]; n++) {
54  int out_i = 0;
55  // At this point, it doesn't matter whether the layout is NCHW or NHWC.
56  // We just need to flatten the HWC part, which is dictated by the size
57  // of each dimension and not the logical meaning of each dim.
58  for (int i = 0; i < inputShape[1]; i++) {
59  for (int j = 0; j < inputShape[2]; j++) {
60  for (int k = 0; k < inputShape[3]; k++) {
61  if (targetNC) {
62  outputData[outputIdx(n, out_i++)] =
63  inputData[inputIdx(n, i, j, k)];
64  } else {
65  outputData[outputIdx(out_i++, n)] =
66  inputData[inputIdx(n, i, j, k)];
67  }
68  }
69  }
70  }
71  }
72 }
73 
74 template <typename DType>
75 void transpose3DImpl(Tensor* input, Tensor* output) {
76  TensorIndexIterator inputIdx = input->startIndex();
77  TensorIndexIterator outputIdx = output->startIndex();
78  const TensorShape& inputShape = input->getShape();
79  auto inputData = input->template data<DType>();
80  auto outputData = output->template data<DType>();
81  for (int i = 0; i < inputShape[0]; i++) {
82  for (int j = 0; j < inputShape[1]; j++) {
83  for (int k = 0; k < inputShape[2]; k++) {
84  outputData[outputIdx(i, k, j)] = inputData[inputIdx(i, j, k)];
85  }
86  }
87  }
88 }
89 
90 template <typename DType>
91 void transpose2DImpl(Tensor* input, Tensor* output) {
92  TensorIndexIterator inputIdx = input->startIndex();
93  TensorIndexIterator outputIdx = output->startIndex();
94  const TensorShape& inputShape = input->getShape();
95  auto inputData = input->template data<DType>();
96  auto outputData = output->template data<DType>();
97  for (int n = 0; n < inputShape[0]; n++) {
98  for (int c = 0; c < inputShape[1]; c++) {
99  outputData[outputIdx(c, n)] = inputData[inputIdx(n, c)];
100  }
101  }
102 }
103 
104 void convertNchwToNhwc(Tensor* input, Tensor* output);
105 
106 void convertNhwcToNchw(Tensor* input, Tensor* output);
107 
108 void flatten(Tensor* input, Tensor* output);
109 
110 void transpose3D(Tensor* input, Tensor* output);
111 
112 void transpose2D(Tensor* input, Tensor* output);
113 
114 } // namespace smaug
tensor_utils.h: Utility functions for copying, printing, and tiling tensors.
smaug: The parent namespace of all C++ code in SMAUG (documented at backend.cpp:38).