1 #ifndef _OPERATORS_SPLIT_OP_H_
2 #define _OPERATORS_SPLIT_OP_H_
5 #include <initializer_list>
7 #include "smaug/core/backend.h"
8 #include "smaug/core/operator.h"
20 template <
typename Backend>
21 class SplitOp :
public Operator {
23 SplitOp(
const std::string& name, Workspace* workspace)
24 : Operator(name, OpType::Split, workspace) {
25 inputs.resize(1,
nullptr);
28 SplitOp(
const std::string& name,
30 const std::vector<int>& _splits,
32 : Operator(name, OpType::Split, workspace), splits(_splits),
34 inputs.resize(1,
nullptr);
35 outputs.resize(splits.size());
41 outputs.resize(splits.size());
43 void setSplits(
const std::initializer_list<int>& _splits) {
45 outputs.resize(splits.size());
51 const std::vector<int>& getSplits()
const {
return splits; }
52 int getSplitAxis()
const {
return splitAxis; }
54 bool validate()
override {
56 for (
int i = 0; i < splits.size(); i++)
57 splitSum += splits[i];
58 return (splitSum == inputs.at(0)->dim(splitAxis) &&
62 void createAllTensors()
override {
63 std::vector<int> dims = getInput(0)->getShape().dims();
64 DataLayout layout = getInput(0)->getShape().getLayout();
65 for (
int i = 0; i < splits.size(); i++) {
66 dims[splitAxis] = splits[i];
67 TensorShape shape(dims, layout, Backend::Alignment);
68 Tensor* output =
new Tensor(name + std::to_string(i), shape);
69 workspace->addTensor(output);
70 outputs.at(i) = output;
75 Tensor* input = getInput(0);
76 int ndims = input->ndims();
77 std::vector<int> srcOrigin(ndims, 0);
78 for (
int i = 0; i < getOutputs().size(); i++) {
79 Tensor* output = getOutput(i);
82 std::vector<int>(ndims, 0),
84 output->getShape().dims());
85 srcOrigin[splitAxis] += output->dim(splitAxis);
91 std::vector<int> splits;