1 #include "smaug/core/tensor.h" 
    6 template <
typename DType>
 
    7 void convertNchwToNhwcImpl(Tensor* input, Tensor* output) {
 
    8     TensorIndexIterator inputIdx = input->startIndex();
 
    9     TensorIndexIterator outputIdx = output->startIndex();
 
   10     const TensorShape& inputShape = input->getShape();
 
   11     DType* inputData = input->template data<DType>();
 
   12     DType* outputData = output->template data<DType>();
 
   13     for (
int n = 0; n < inputShape[0]; n++) {
 
   14         for (
int c = 0; c < inputShape[1]; c++) {
 
   15             for (
int h = 0; h < inputShape[2]; h++) {
 
   16                 for (
int w = 0; w < inputShape[3]; w++) {
 
   17                     outputData[outputIdx(n, h, w, c)] =
 
   18                             inputData[inputIdx(n, c, h, w)];
 
   25 template <
typename DType>
 
   26 void convertNhwcToNchwImpl(Tensor* input, Tensor* output) {
 
   27     TensorIndexIterator inputIdx = input->startIndex();
 
   28     TensorIndexIterator outputIdx = output->startIndex();
 
   29     const TensorShape& inputShape = input->getShape();
 
   30     DType* inputData = input->template data<DType>();
 
   31     DType* outputData = output->template data<DType>();
 
   32     for (
int n = 0; n < inputShape[0]; n++) {
 
   33         for (
int h = 0; h < inputShape[1]; h++) {
 
   34             for (
int w = 0; w < inputShape[2]; w++) {
 
   35                 for (
int c = 0; c < inputShape[3]; c++) {
 
   36                     outputData[outputIdx(n, c, h, w)] =
 
   37                             inputData[inputIdx(n, h, w, c)];
 
   44 template <
typename DType>
 
   45 void flattenImpl(Tensor* input, Tensor* output) {
 
   46     TensorIndexIterator inputIdx = input->startIndex();
 
   47     TensorIndexIterator outputIdx = output->startIndex();
 
   48     const TensorShape& inputShape = input->getShape();
 
   49     const TensorShape& outputShape = output->getShape();
 
   50     DType* inputData = input->template data<DType>();
 
   51     DType* outputData = output->template data<DType>();
 
   52     bool targetNC = outputShape.getLayout() == NC;
 
   53     for (
int n = 0; n < inputShape[0]; n++) {
 
   58         for (
int i = 0; i < inputShape[1]; i++) {
 
   59             for (
int j = 0; j < inputShape[2]; j++) {
 
   60                 for (
int k = 0; k < inputShape[3]; k++) {
 
   62                         outputData[outputIdx(n, out_i++)] =
 
   63                                 inputData[inputIdx(n, i, j, k)];
 
   65                         outputData[outputIdx(out_i++, n)] =
 
   66                                 inputData[inputIdx(n, i, j, k)];
 
   74 template <
typename DType>
 
   75 void transpose3DImpl(Tensor* input, Tensor* output) {
 
   76     TensorIndexIterator inputIdx = input->startIndex();
 
   77     TensorIndexIterator outputIdx = output->startIndex();
 
   78     const TensorShape& inputShape = input->getShape();
 
   79     auto inputData = input->template data<DType>();
 
   80     auto outputData = output->template data<DType>();
 
   81     for (
int i = 0; i < inputShape[0]; i++) {
 
   82         for (
int j = 0; j < inputShape[1]; j++) {
 
   83             for (
int k = 0; k < inputShape[2]; k++) {
 
   84                 outputData[outputIdx(i, k, j)] = inputData[inputIdx(i, j, k)];
 
   90 template <
typename DType>
 
   91 void transpose2DImpl(Tensor* input, Tensor* output) {
 
   92     TensorIndexIterator inputIdx = input->startIndex();
 
   93     TensorIndexIterator outputIdx = output->startIndex();
 
   94     const TensorShape& inputShape = input->getShape();
 
   95     auto inputData = input->template data<DType>();
 
   96     auto outputData = output->template data<DType>();
 
   97     for (
int n = 0; n < inputShape[0]; n++) {
 
   98         for (
int c = 0; c < inputShape[1]; c++) {
 
   99             outputData[outputIdx(c, n)] = inputData[inputIdx(n, c)];
 
  104 void convertNchwToNhwc(Tensor* input, Tensor* output);
 
  106 void convertNhwcToNchw(Tensor* input, Tensor* output);
 
  108 void flatten(Tensor* input, Tensor* output);
 
  110 void transpose3D(Tensor* input, Tensor* output);
 
  112 void transpose2D(Tensor* input, Tensor* output);