SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
reorder_op_impl.cpp
1 #include "smaug/core/tensor.h"
2 #include "smaug/operators/reorder_op_impl.h"
3 
4 namespace smaug {
5 
6 void convertNchwToNhwc(Tensor* input, Tensor* output) {
7  DataType datatype = input->getDataType();
8  assert(input->ndims() == output->ndims() && input->ndims() == 4);
9  switch (datatype) {
10  case Float16:
11  convertNchwToNhwcImpl<float16>(input, output);
12  return;
13  case Float32:
14  convertNchwToNhwcImpl<float>(input, output);
15  return;
16  case Float64:
17  convertNchwToNhwcImpl<double>(input, output);
18  return;
19  case Int32:
20  convertNchwToNhwcImpl<int>(input, output);
21  return;
22  case Int64:
23  convertNchwToNhwcImpl<int64_t>(input, output);
24  return;
25  default:
26  assert(false && "Unknown data format!");
27  }
28 }
29 
30 void convertNhwcToNchw(Tensor* input, Tensor* output) {
31  DataType datatype = input->getDataType();
32  assert(input->ndims() == output->ndims() && input->ndims() == 4);
33  switch (datatype) {
34  case Float16:
35  convertNhwcToNchwImpl<float16>(input, output);
36  return;
37  case Float32:
38  convertNhwcToNchwImpl<float>(input, output);
39  return;
40  case Float64:
41  convertNhwcToNchwImpl<double>(input, output);
42  return;
43  case Int32:
44  convertNhwcToNchwImpl<int>(input, output);
45  return;
46  case Int64:
47  convertNhwcToNchwImpl<int64_t>(input, output);
48  return;
49  default:
50  assert(false && "Unknown data format!");
51  }
52 }
53 
54 void flatten(Tensor* input, Tensor* output) {
55  DataType datatype = input->getDataType();
56  assert(input->ndims() == 4 && output->ndims() == 2);
57  switch (datatype) {
58  case Float16:
59  flattenImpl<float16>(input, output);
60  return;
61  case Float32:
62  flattenImpl<float>(input, output);
63  return;
64  case Float64:
65  flattenImpl<double>(input, output);
66  return;
67  case Int32:
68  flattenImpl<int>(input, output);
69  return;
70  case Int64:
71  flattenImpl<int64_t>(input, output);
72  return;
73  default:
74  assert(false && "Unknown data format!");
75  }
76 }
77 
78 void transpose3D(Tensor* input, Tensor* output) {
79  DataType datatype = input->getDataType();
80  assert(input->ndims() == 3 && output->ndims() == 3);
81  switch (datatype) {
82  case Float16:
83  transpose3DImpl<float16>(input, output);
84  return;
85  case Float32:
86  transpose3DImpl<float>(input, output);
87  return;
88  case Float64:
89  transpose3DImpl<double>(input, output);
90  return;
91  case Int32:
92  transpose3DImpl<int>(input, output);
93  return;
94  case Int64:
95  transpose3DImpl<int64_t>(input, output);
96  return;
97  default:
98  assert(false && "Unknown data format!");
99  }
100 }
101 
102 void transpose2D(Tensor* input, Tensor* output) {
103  DataType datatype = input->getDataType();
104  assert(input->ndims() == 2 && output->ndims() == 2);
105  switch (datatype) {
106  case Float16:
107  transpose2DImpl<float16>(input, output);
108  return;
109  case Float32:
110  transpose2DImpl<float>(input, output);
111  return;
112  case Float64:
113  transpose2DImpl<double>(input, output);
114  return;
115  case Int32:
116  transpose2DImpl<int>(input, output);
117  return;
118  case Int64:
119  transpose2DImpl<int64_t>(input, output);
120  return;
121  default:
122  assert(false && "Unknown data format!");
123  }
124 }
125 
126 } // namespace smaug
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38