SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
control_flow_ops.h
1 #ifndef _OPERATORS_CONTROL_FLOW_OPS_H_
2 #define _OPERATORS_CONTROL_FLOW_OPS_H_
3 
#include <cstdlib>
#include <iostream>
#include <string>

#include "smaug/core/backend.h"
#include "smaug/core/operator.h"
#include "smaug/core/tensor_utils.h"

8 namespace smaug {
9 
22 template <typename Backend>
23 class SwitchOp : public Operator {
24  public:
25  enum {
31  kNumInputs
32  };
33  enum {
38  kNumOutputs
39  };
40 
41  SwitchOp(const std::string& name, Workspace* workspace)
42  : Operator(name, OpType::Switch, workspace) {
43  inputs.resize(2, nullptr);
44  outputs.resize(2, nullptr);
45  }
46 
47  bool validate() override {
48  if (getInput(Pred)->getShape().size() != 1)
49  return false;
50  return Operator::validate();
51  }
52 
53  void createAllTensors() override {
54  Tensor* input = getInput(Input);
55  TensorShape shape = inputs.at(Input)->getShape();
56  Tensor* outputFalse = new Tensor(name + "_false", shape);
57  Tensor* outputTrue = new Tensor(name + "_true", shape);
58  workspace->addTensor(outputFalse);
59  workspace->addTensor(outputTrue);
60  outputs.at(OutputFalse) = outputFalse;
61  outputs.at(OutputTrue) = outputTrue;
62  }
63 
64  void run() override {
65  Tensor* input = getInput(Input);
66  Tensor* outputFalse = getOutput(OutputFalse);
67  Tensor* outputTrue = getOutput(OutputTrue);
68  const TensorShape& inputShape = input->getShape();
69  Tensor* predTensor = getInput(Pred);
70  bool* pred = predTensor->data<bool>();
71  if (pred[0]) {
72  outputFalse->setDead();
74  outputTrue, input, 0, 0, inputShape.storageSize());
75  } else {
76  outputTrue->setDead();
78  outputFalse, input, 0, 0, inputShape.storageSize());
79  }
80  }
81 };
82 
90 template <typename Backend>
91 class MergeOp : public Operator {
92  public:
93  MergeOp(const std::string& name, Workspace* workspace)
94  : Operator(name, OpType::Merge, workspace) {
95  outputs.resize(1, nullptr);
96  }
97 
98  void setNumInputs(int num) { inputs.resize(num); }
99 
100  void createAllTensors() override {
101  Tensor* output =
102  workspace->addTensor(new Tensor(name, getInput(0)->getShape()));
103  outputs.at(0) = output;
104  }
105 
107  bool isDead() override {
108  for (auto input : inputs) {
109  if (!input->isDead())
110  return false;
111  }
112  return true;
113  }
114 
115  void run() override {
116  Tensor* output = getOutput(0);
117  bool forwarded = false;
118  for (int i = 0; i < getInputs().size(); i++) {
119  Tensor* input = getInput(i);
120  if (!input->isDead()) {
122  output, input, 0, 0, input->getShape().storageSize());
123  forwarded = true;
124  break;
125  }
126  }
127  if (!forwarded) {
128  std::cerr << "All inputs to the merge operator are dead!\n";
129  exit(1);
130  }
131  }
132 };
133 
134 } // namespace smaug
135 
136 #endif
smaug::Tensor
Tensor represents a single multi-dimensional array of data.
Definition: tensor.h:344
smaug::SwitchOp::OutputFalse
@ OutputFalse
The output tensor on the false branch.
Definition: control_flow_ops.h:35
smaug::Workspace
Workspace is the container and owner of all Tensors and Operators in the Network.
Definition: workspace.h:17
smaug::SwitchOp::OutputTrue
@ OutputTrue
The output tensor on the true branch.
Definition: control_flow_ops.h:37
tensor_utils.h
Utility functions for copying/printing/tiling tensors.
smaug::copyRawTensorData
void copyRawTensorData(Tensor *dest, Tensor *src, int destOffset, int srcOffset, int copySize)
Directly copies a linear region of memory from src to dest, without taking dimensions/padding into account.
Definition: tensor_utils.cpp:138
smaug::SwitchOp
Conditionally forwards an input to one of two outputs.
Definition: backend.h:53
smaug::Operator
Operator is the base class for all graph operators supported by SMAUG.
Definition: operator.h:28
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
smaug::MergeOp::isDead
bool isDead() override
A merge operator is dead only when all its inputs are dead.
Definition: control_flow_ops.h:107
smaug::Operator::validate
virtual bool validate()
Returns true if the parameters/tensors of this operator are all valid.
Definition: operator.h:47
smaug::SwitchOp::Input
@ Input
The input Tensor to pass through.
Definition: control_flow_ops.h:27
smaug::SwitchOp::Pred
@ Pred
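A scalar Tensor (a Tensor with just one value). If its value is true, the input is forwarded to the true-branch output; otherwise it is forwarded to the false-branch output.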
Definition: control_flow_ops.h:30