SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
unary_op.h
1 #ifndef _OPERATORS_UNARY_OP_H_
2 #define _OPERATORS_UNARY_OP_H_
3 
4 #include <string>
5 
6 #include "smaug/core/operator.h"
7 #include "smaug/core/workspace.h"
8 
9 namespace smaug {
10 
19 template <typename Backend>
20 class UnaryOp : public Operator {
21  public:
22  UnaryOp(const std::string& name, OpType opType, Workspace* workspace)
23  : Operator(name, opType, workspace) {
24  inputs.resize(kNumInputs, nullptr);
25  outputs.resize(kNumOutputs, nullptr);
26  }
27 
28  bool validate() override { return Operator::validate(); }
29 
30  void createAllTensors() override {
31  createOutputTensors();
32  }
33 
34  void createOutputTensors() {
35  if (outputs[Outputs])
36  return;
37  TensorShape shape = inputs.at(Inputs)->getShape();
38  Tensor* output = new Tensor(name, shape);
39  workspace->addTensor(output);
40  outputs[Outputs] = output;
41  }
42 
43  enum { Inputs, kNumInputs };
44  enum { Outputs, kNumOutputs };
45 };
46 
47 } // namespace smaug
48 
49 #endif
smaug::Tensor
Tensor represents a single multi-dimensional array of data.
Definition: tensor.h:344
smaug::Workspace
Workspace is the container and owner of all Tensors and Operators in the Network.
Definition: workspace.h:17
smaug::UnaryOp::validate
bool validate() override
Returns true if the parameters/tensors of this operator are all valid.
Definition: unary_op.h:28
smaug::TensorShape
TensorShape describes the shape of a Tensor.
Definition: tensor.h:35
smaug::Operator::outputs
std::vector< TensorBase * > outputs
An ordered list of output tensors produced by this operator.
Definition: operator.h:141
smaug::UnaryOp::createAllTensors
void createAllTensors() override
For tests: creates all input and output tensors for this operator.
Definition: unary_op.h:30
smaug::Operator
Operator is the base class for all graph operators supported by SMAUG.
Definition: operator.h:28
smaug
The smaug namespace is the parent namespace of all C++ code in SMAUG.
Definition: backend.cpp:38
smaug::Operator::validate
virtual bool validate()
Returns true if the parameters/tensors of this operator are all valid.
Definition: operator.h:47
smaug::UnaryOp
Base class for all operators with one input.
Definition: unary_op.h:20
smaug::Operator::inputs
std::vector< TensorBase * > inputs
An ordered list of input tensors consumed by this operator.
Definition: operator.h:134