1 #include "smaug/core/tensor.h"
6 template <
typename DType>
7 void convertNchwToNhwcImpl(Tensor* input, Tensor* output) {
8 TensorIndexIterator inputIdx = input->startIndex();
9 TensorIndexIterator outputIdx = output->startIndex();
10 const TensorShape& inputShape = input->getShape();
11 DType* inputData = input->template data<DType>();
12 DType* outputData = output->template data<DType>();
13 for (
int n = 0; n < inputShape[0]; n++) {
14 for (
int c = 0; c < inputShape[1]; c++) {
15 for (
int h = 0; h < inputShape[2]; h++) {
16 for (
int w = 0; w < inputShape[3]; w++) {
17 outputData[outputIdx(n, h, w, c)] =
18 inputData[inputIdx(n, c, h, w)];
25 template <
typename DType>
26 void convertNhwcToNchwImpl(Tensor* input, Tensor* output) {
27 TensorIndexIterator inputIdx = input->startIndex();
28 TensorIndexIterator outputIdx = output->startIndex();
29 const TensorShape& inputShape = input->getShape();
30 DType* inputData = input->template data<DType>();
31 DType* outputData = output->template data<DType>();
32 for (
int n = 0; n < inputShape[0]; n++) {
33 for (
int h = 0; h < inputShape[1]; h++) {
34 for (
int w = 0; w < inputShape[2]; w++) {
35 for (
int c = 0; c < inputShape[3]; c++) {
36 outputData[outputIdx(n, c, h, w)] =
37 inputData[inputIdx(n, h, w, c)];
44 template <
typename DType>
45 void flattenImpl(Tensor* input, Tensor* output) {
46 TensorIndexIterator inputIdx = input->startIndex();
47 TensorIndexIterator outputIdx = output->startIndex();
48 const TensorShape& inputShape = input->getShape();
49 const TensorShape& outputShape = output->getShape();
50 DType* inputData = input->template data<DType>();
51 DType* outputData = output->template data<DType>();
52 bool targetNC = outputShape.getLayout() == NC;
53 for (
int n = 0; n < inputShape[0]; n++) {
58 for (
int i = 0; i < inputShape[1]; i++) {
59 for (
int j = 0; j < inputShape[2]; j++) {
60 for (
int k = 0; k < inputShape[3]; k++) {
62 outputData[outputIdx(n, out_i++)] =
63 inputData[inputIdx(n, i, j, k)];
65 outputData[outputIdx(out_i++, n)] =
66 inputData[inputIdx(n, i, j, k)];
74 template <
typename DType>
75 void transpose3DImpl(Tensor* input, Tensor* output) {
76 TensorIndexIterator inputIdx = input->startIndex();
77 TensorIndexIterator outputIdx = output->startIndex();
78 const TensorShape& inputShape = input->getShape();
79 auto inputData = input->template data<DType>();
80 auto outputData = output->template data<DType>();
81 for (
int i = 0; i < inputShape[0]; i++) {
82 for (
int j = 0; j < inputShape[1]; j++) {
83 for (
int k = 0; k < inputShape[2]; k++) {
84 outputData[outputIdx(i, k, j)] = inputData[inputIdx(i, j, k)];
90 template <
typename DType>
91 void transpose2DImpl(Tensor* input, Tensor* output) {
92 TensorIndexIterator inputIdx = input->startIndex();
93 TensorIndexIterator outputIdx = output->startIndex();
94 const TensorShape& inputShape = input->getShape();
95 auto inputData = input->template data<DType>();
96 auto outputData = output->template data<DType>();
97 for (
int n = 0; n < inputShape[0]; n++) {
98 for (
int c = 0; c < inputShape[1]; c++) {
99 outputData[outputIdx(c, n)] = inputData[inputIdx(n, c)];
// Non-template entry points. Only the declarations are visible here; each
// takes the same (input, output) tensor pair as the matching *Impl<DType>
// template above. Presumably each one picks DType from the tensor's data type
// and forwards to that template -- confirm in the implementation file.

// Reorders a tensor from NCHW to NHWC; see convertNchwToNhwcImpl.
104 void convertNchwToNhwc(Tensor* input, Tensor* output);
// Reorders a tensor from NHWC to NCHW; see convertNhwcToNchwImpl.
106 void convertNhwcToNchw(Tensor* input, Tensor* output);
// Flattens the trailing dimensions of a tensor; see flattenImpl.
108 void flatten(Tensor* input, Tensor* output);
// Transposes the last two dimensions of a 3D tensor; see transpose3DImpl.
110 void transpose3D(Tensor* input, Tensor* output);
// Transposes a 2D tensor; see transpose2DImpl.
112 void transpose2D(Tensor* input, Tensor* output);