1 #ifndef _OPERATORS_CONTROL_FLOW_OPS_H_
2 #define _OPERATORS_CONTROL_FLOW_OPS_H_
4 #include "smaug/core/backend.h"
5 #include "smaug/core/operator.h"
22 template <
typename Backend>
23 class SwitchOp :
public Operator {
42 :
Operator(name, OpType::Switch, workspace) {
43 inputs.resize(2,
nullptr);
44 outputs.resize(2,
nullptr);
47 bool validate()
override {
48 if (getInput(
Pred)->getShape().size() != 1)
53 void createAllTensors()
override {
54 Tensor* input = getInput(
Input);
55 TensorShape shape = inputs.at(
Input)->getShape();
56 Tensor* outputFalse =
new Tensor(name +
"_false", shape);
57 Tensor* outputTrue =
new Tensor(name +
"_true", shape);
58 workspace->addTensor(outputFalse);
59 workspace->addTensor(outputTrue);
65 Tensor* input = getInput(
Input);
68 const TensorShape& inputShape = input->getShape();
69 Tensor* predTensor = getInput(
Pred);
70 bool* pred = predTensor->data<
bool>();
72 outputFalse->setDead();
74 outputTrue, input, 0, 0, inputShape.storageSize());
76 outputTrue->setDead();
78 outputFalse, input, 0, 0, inputShape.storageSize());
90 template <
typename Backend>
91 class MergeOp :
public Operator {
93 MergeOp(
const std::string& name, Workspace* workspace)
94 : Operator(name, OpType::Merge, workspace) {
95 outputs.resize(1,
nullptr);
98 void setNumInputs(
int num) { inputs.resize(num); }
100 void createAllTensors()
override {
102 workspace->addTensor(
new Tensor(name, getInput(0)->getShape()));
103 outputs.at(0) = output;
108 for (
auto input : inputs) {
109 if (!input->isDead())
115 void run()
override {
116 Tensor* output = getOutput(0);
117 bool forwarded =
false;
118 for (
int i = 0; i < getInputs().size(); i++) {
119 Tensor* input = getInput(i);
120 if (!input->isDead()) {
122 output, input, 0, 0, input->getShape().storageSize());
128 std::cerr <<
"All inputs to the merge operator are dead!\n";