1 #ifndef _OPERATORS_CONTROL_FLOW_OPS_H_ 
    2 #define _OPERATORS_CONTROL_FLOW_OPS_H_ 
    4 #include "smaug/core/backend.h" 
    5 #include "smaug/core/operator.h" 
   22 template <
typename Backend>
 
   23 class SwitchOp : 
public Operator {
 
   42             : 
Operator(name, OpType::Switch, workspace) {
 
   43         inputs.resize(2, 
nullptr);
 
   44         outputs.resize(2, 
nullptr);
 
   47     bool validate()
 override {
 
   48         if (getInput(
Pred)->getShape().size() != 1)
 
   53     void createAllTensors()
 override {
 
   54         Tensor* input = getInput(
Input);
 
   55         TensorShape shape = inputs.at(
Input)->getShape();
 
   56         Tensor* outputFalse = 
new Tensor(name + 
"_false", shape);
 
   57         Tensor* outputTrue = 
new Tensor(name + 
"_true", shape);
 
   58         workspace->addTensor(outputFalse);
 
   59         workspace->addTensor(outputTrue);
 
   65         Tensor* input = getInput(
Input);
 
   68         const TensorShape& inputShape = input->getShape();
 
   69         Tensor* predTensor = getInput(
Pred);
 
   70         bool* pred = predTensor->data<
bool>();
 
   72             outputFalse->setDead();
 
   74                     outputTrue, input, 0, 0, inputShape.storageSize());
 
   76             outputTrue->setDead();
 
   78                     outputFalse, input, 0, 0, inputShape.storageSize());
 
   90 template <
typename Backend>
 
   91 class MergeOp : 
public Operator {
 
   93     MergeOp(
const std::string& name, Workspace* workspace)
 
   94             : Operator(name, OpType::Merge, workspace) {
 
   95         outputs.resize(1, 
nullptr);
 
   98     void setNumInputs(
int num) { inputs.resize(num); }
 
  100     void createAllTensors()
 override {
 
  102                 workspace->addTensor(
new Tensor(name, getInput(0)->getShape()));
 
  103         outputs.at(0) = output;
 
  108         for (
auto input : inputs) {
 
  109             if (!input->isDead())
 
  115     void run()
 override {
 
  116         Tensor* output = getOutput(0);
 
  117         bool forwarded = 
false;
 
  118         for (
int i = 0; i < getInputs().size(); i++) {
 
  119             Tensor* input = getInput(i);
 
  120             if (!input->isDead()) {
 
  122                         output, input, 0, 0, input->getShape().storageSize());
 
  128             std::cerr << 
"All inputs to the merge operator are dead!\n";