1 #ifndef _OPERATORS_CONCAT_OP_H_
2 #define _OPERATORS_CONCAT_OP_H_
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
17 template <
typename Backend>
18 class ConcatOp :
public Operator {
20 ConcatOp(
const std::string& name, Workspace* workspace)
21 : Operator(name, OpType::Concat, workspace) {
22 outputs.resize(1,
nullptr);
37 :
Operator(name, OpType::Concat, workspace), concatAxis(axis) {
39 outputs.resize(1,
nullptr);
48 assert(getInputs().size() > 0 &&
"Unable to get inputs for concat op!");
49 std::vector<int> dims = getInput(0)->getShape().dims();
50 DataLayout layout = getInput(0)->getShape().getLayout();
52 for (
int i = 0; i < getInputs().size(); i++) {
53 dim += getInput(i)->dim(concatAxis);
55 dims[concatAxis] = dim;
56 return TensorShape(dims, layout, Backend::Alignment);
59 void createOutputTensor() {
60 TensorShape shape = inferOutputShape();
61 Tensor* output =
new Tensor(name, shape);
62 workspace->addTensor(output);
63 outputs.at(0) = output;
66 void createAllTensors()
override{
71 Tensor* output = getOutput(0);
72 int ndims = output->ndims();
73 std::vector<int> dstOrigin(ndims, 0);
74 for (
int i = 0; i < getInputs().size(); i++) {
75 Tensor* input = getInput(i);
79 std::vector<int>(ndims, 0),
80 input->getShape().dims());
81 dstOrigin[concatAxis] += input->dim(concatAxis);
85 int getConcatAxis()
const {
return concatAxis; }