1 #ifndef _OPERATORS_RESHAPE_OP_H_
2 #define _OPERATORS_RESHAPE_OP_H_
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
20 template <
typename Backend>
21 class ReshapeOp :
public Operator {
23 ReshapeOp(
const std::string& name, Workspace* workspace)
24 : Operator(name, OpType::Reshape, workspace) {
25 inputs.resize(1,
nullptr);
26 outputs.resize(1,
nullptr);
29 ReshapeOp(
const std::string& name,
31 const std::vector<int>& _shape,
33 : Operator(name, OpType::Reshape, workspace), shape(_shape),
35 inputs.resize(1,
nullptr);
36 outputs.resize(1,
nullptr);
40 void setShape(
const std::vector<int>& _shape, DataLayout _layout) {
45 void setShape(
const std::initializer_list<int>& _shape,
51 void createAllTensors()
override {
52 Tensor* input = getInput(0);
54 name,
TensorShape(shape, layout, Backend::Alignment));
55 workspace->addTensor(output);
56 outputs.at(0) = output;
61 Tensor* input = getInput(0);
62 Tensor* output = getOutput(0);
63 const TensorShape& inputShape = input->getShape();
64 const TensorShape& outputShape = output->getShape();
65 int inputNumDims = input->ndims();
66 int outputNumDims = output->ndims();
67 int inputPadding = inputShape.getPadding(inputNumDims - 1);
68 int outputPadding = outputShape.getPadding(outputNumDims - 1);
69 if (inputPadding == outputPadding) {
78 std::vector<int>(outputNumDims, 0),
79 std::vector<int>(inputNumDims, 0),
85 std::vector<int> shape;