6 #ifndef _CORE_TENSOR_UTILS_H_
7 #define _CORE_TENSOR_UTILS_H_
13 #include "smaug/core/tensor.h"
// Stream-insertion overload for a TensorIndexIterator. Takes the output stream
// and the iterator (by const reference) and returns std::ostream&.
// Declaration only: the output format is defined elsewhere.
20 std::ostream& operator<<(std::ostream& os,
const TensorIndexIterator& iter);
// Stream-insertion overload for a TensorShape. Takes the output stream and the
// shape (by const reference) and returns std::ostream&.
// Declaration only: the output format is defined elsewhere.
21 std::ostream& operator<<(std::ostream& os,
const TensorShape& shape);
// Stream-insertion overload for a Tensor. Takes the output stream and the
// tensor (by const reference) and returns std::ostream&.
// Declaration only: the output format is defined elsewhere.
22 std::ostream& operator<<(std::ostream& os,
const Tensor& tensor);
// Generic per-element print helper, parameterized on the element type DType.
//   os:    output stream.
//   data:  pointer to const DType elements.
//   index: integer position.
// NOTE(review): the body is not visible in this view - presumably it writes
// data[index] to os; confirm against the full file.
24 template <
typename DType>
25 void printTensorElement(std::ostream& os,
const DType* data,
int index) {
// Same helper with DType fixed to float16. NOTE(review): the remaining
// parameters, the body, and any `template <>` header for this variant are not
// visible here - presumably an explicit specialization; confirm.
30 void printTensorElement<float16>(std::ostream& os,
// Templated tensor-printing routine (DType = element type).
// NOTE(review): the signature line and several body lines (branch bodies, the
// declaration/update of `counter`, closing braces) are not visible in this
// view; the comments below describe only the statements that are.
37 template <
typename DType>
// Zero-dimensional shapes are handled in their own branch (body not shown).
40 if (shape.ndims() == 0) {
44 int ndims = shape.ndims();
// Extent of the innermost dimension; used below as the modulus in the
// `counter % newlineAfterElems` check.
45 int newlineAfterElems = shape[ndims - 1];
// With two or more dimensions, a group spans the two innermost dimensions
// (product of their extents). The alternative used for fewer dimensions is on
// a line not shown.
46 int newGroupAfterElems =
47 (shape.ndims() >= 2 ? shape[ndims - 1] * shape[ndims - 2]
// Typed, read-only pointer to the tensor's data.
50 const DType* data = tensor.template data<DType>();
// Header line: the tensor's name followed by ", shape = " and the shape.
51 os << tensor.getName() <<
", shape = " << shape <<
"\n";
// Visit indices from tensor.startIndex() until the iterator reports end().
52 for (
auto idx = tensor.
startIndex(); !idx.end(); ++idx) {
// Print the element addressed by idx. NOTE(review): printTensorElement
// declares `int index`, so this call relies on idx converting to int - confirm.
57 printTensorElement<DType>(os, data, idx);
// The group boundary is tested first, then the row boundary; what each branch
// emits is not visible here.
60 if (counter % newGroupAfterElems == 0) {
63 }
else if (counter % newlineAfterElems == 0) {
// Templated copy of a region between two tensors (DType = element type).
// NOTE(review): the line naming the function and its leading parameters is not
// visible in this view; `dest` and `src` are used below through `->`.
//   destOrigin: passed with regionSize to the region iterator over dest.
//   srcOrigin:  passed with regionSize to the region iterator over src.
//   regionSize: per-dimension extents used to build regionShape.
71 template <
typename DType>
74 const std::vector<int>& destOrigin,
75 const std::vector<int>& srcOrigin,
76 const std::vector<int>& regionSize) {
77 const TensorShape& srcShape = src->getShape();
78 const TensorShape& destShape = dest->getShape();
// Shape of the region: regionSize combined with the source tensor's layout
// and alignment.
79 TensorShape regionShape(
80 regionSize, srcShape.getLayout(), srcShape.getAlignment());
// Dimension count is taken from the source shape.
81 const int ndims = srcShape.ndims();
// One region iterator per tensor, each built from that tensor's shape, its
// origin, and the shared regionSize.
82 auto destIt = TensorRegionIndexIterator(destShape, destOrigin, regionSize);
83 auto srcIt = TensorRegionIndexIterator(srcShape, srcOrigin, regionSize);
// contiguousRegion starts as all ones and contiguousSize as 1. The loop walks
// dimensions from last to first, multiplying contiguousSize by
// regionShape.getStorageDim(i) and recording regionShape[i] in
// contiguousRegion[i].
89 std::vector<int> contiguousRegion(ndims, 1);
90 int contiguousSize = 1;
91 for (
int i = ndims - 1; i >= 0; i--) {
92 contiguousSize *= regionShape.getStorageDim(i);
93 contiguousRegion[i] = regionShape[i];
// True when the region is narrower than either tensor in dimension i. The
// statement it guards is not visible - presumably the loop stops here because
// outer dimensions are no longer contiguous in memory; confirm.
96 if (regionShape[i] < srcShape[i] || regionShape[i] < destShape[i])
// Typed pointers to each tensor's data.
101 DType* destPtr = dest->template data<DType>();
102 DType* srcPtr = src->template data<DType>();
// Continue while both iterators still have positions left.
103 while (!srcIt.end() && !destIt.end()) {
// NOTE(review): a per-element assignment and a memcpy of
// contiguousSize * sizeof(DType) bytes are both visible below; whatever
// selects between them (e.g. a preprocessor conditional) and the memcpy
// source argument are on lines not shown.
105 destPtr[destIt] = srcPtr[srcIt];
109 memcpy(&destPtr[destIt],
111 contiguousSize *
sizeof(DType));
// Advance both iterators by the contiguousRegion extents.
112 destIt += contiguousRegion;
113 srcIt += contiguousRegion;
// Templated raw copy between two tensors (DType = element type).
// NOTE(review): the signature line and the name of the call that receives the
// arguments below are not visible in this view.
118 template <
typename DType>
// Typed pointers to each tensor's data.
124 DType* destPtr = dest->template data<DType>();
125 DType* srcPtr = src->template data<DType>();
// Arguments: address of the dest element at destOffset, address of the src
// element at srcOffset, and a byte count of copySize * sizeof(DType).
// NOTE(review): the (dest, src, bytes) order looks memcpy-style - confirm the
// callee on the preceding line of the full file.
127 &destPtr[destOffset], &srcPtr[srcOffset], copySize *
sizeof(DType));
// Templated element-by-element copy between two tensors (DType = element type).
// NOTE(review): the line naming the function and its leading parameters is not
// visible in this view. destOrigin and srcOrigin are taken by value.
130 template <
typename DType>
133 std::vector<int> destOrigin,
134 std::vector<int> srcOrigin,
// Index iterators positioned at the start of each tensor.
136 TensorIndexIterator destIdx = dest->startIndex();
137 TensorIndexIterator srcIdx = src->startIndex();
// Move the destination iterator by destOrigin. NOTE(review): a matching
// adjustment of srcIdx by srcOrigin is not visible in this view - confirm it
// exists in the full file.
138 destIdx += destOrigin;
// Typed pointers to each tensor's data.
140 DType* destPtr = dest->template data<DType>();
141 DType* srcPtr = src->template data<DType>();
// Copy one element per step, advancing both iterators together. The loop ends
// only when srcIdx reports end(); destIdx.end() is not checked here.
// NOTE(review): confirm callers guarantee dest has room for every source
// element from destOrigin onward.
142 for (; !srcIdx.end(); ++srcIdx, ++destIdx)
143 destPtr[destIdx] = srcPtr[srcIdx];
170 std::vector<int> destOrigin,
171 std::vector<int> srcOrigin,
172 std::vector<int> regionSize);
180 std::vector<int> destOffset,
181 std::vector<int> srcOffset,
197 Tensor* dest, Tensor* src,
int destOffset,
int srcOffset,
int copySize);
212 const TensorShape& tileShape,
214 bool copyData =
true);
235 const TensorShape& tileShape,
241 PaddingType paddingType,
242 bool copyData =
false);
256 const TensorShape& tileShape,
258 bool copyData =
false);
271 Workspace* workspace);