1 #include "smaug/core/tensor.h"
2 #include "smaug/operators/reorder_op_impl.h"
6 void convertNchwToNhwc(Tensor* input, Tensor* output) {
7 DataType datatype = input->getDataType();
8 assert(input->ndims() == output->ndims() && input->ndims() == 4);
11 convertNchwToNhwcImpl<float16>(input, output);
14 convertNchwToNhwcImpl<float>(input, output);
17 convertNchwToNhwcImpl<double>(input, output);
20 convertNchwToNhwcImpl<int>(input, output);
23 convertNchwToNhwcImpl<int64_t>(input, output);
26 assert(
false &&
"Unknown data format!");
30 void convertNhwcToNchw(Tensor* input, Tensor* output) {
31 DataType datatype = input->getDataType();
32 assert(input->ndims() == output->ndims() && input->ndims() == 4);
35 convertNhwcToNchwImpl<float16>(input, output);
38 convertNhwcToNchwImpl<float>(input, output);
41 convertNhwcToNchwImpl<double>(input, output);
44 convertNhwcToNchwImpl<int>(input, output);
47 convertNhwcToNchwImpl<int64_t>(input, output);
50 assert(
false &&
"Unknown data format!");
54 void flatten(Tensor* input, Tensor* output) {
55 DataType datatype = input->getDataType();
56 assert(input->ndims() == 4 && output->ndims() == 2);
59 flattenImpl<float16>(input, output);
62 flattenImpl<float>(input, output);
65 flattenImpl<double>(input, output);
68 flattenImpl<int>(input, output);
71 flattenImpl<int64_t>(input, output);
74 assert(
false &&
"Unknown data format!");
78 void transpose3D(Tensor* input, Tensor* output) {
79 DataType datatype = input->getDataType();
80 assert(input->ndims() == 3 && output->ndims() == 3);
83 transpose3DImpl<float16>(input, output);
86 transpose3DImpl<float>(input, output);
89 transpose3DImpl<double>(input, output);
92 transpose3DImpl<int>(input, output);
95 transpose3DImpl<int64_t>(input, output);
98 assert(
false &&
"Unknown data format!");
102 void transpose2D(Tensor* input, Tensor* output) {
103 DataType datatype = input->getDataType();
104 assert(input->ndims() == 2 && output->ndims() == 2);
107 transpose2DImpl<float16>(input, output);
110 transpose2DImpl<float>(input, output);
113 transpose2DImpl<double>(input, output);
116 transpose2DImpl<int>(input, output);
119 transpose2DImpl<int64_t>(input, output);
122 assert(
false &&
"Unknown data format!");