#include <fcntl.h>

#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>

#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>

#include "smaug/core/backend.h"
#include "smaug/core/graph.pb.h"
#include "smaug/core/network.h"
#include "smaug/core/network_builder.h"
#include "smaug/core/node.pb.h"
#include "smaug/core/tensor.h"
#include "smaug/core/tensor.pb.h"
#include "smaug/core/types.pb.h"
#include "smaug/core/workspace.h"
#include "smaug/operators/batch_norm_op.h"
#include "smaug/operators/concat_op.h"
#include "smaug/operators/control_flow_ops.h"
#include "smaug/operators/convolution_op.h"
#include "smaug/operators/data_op.h"
#include "smaug/operators/depthwise_convolution_op.h"
#include "smaug/operators/eltwise_add_op.h"
#include "smaug/operators/eltwise_mul_op.h"
#include "smaug/operators/elu_op.h"
#include "smaug/operators/greater_op.h"
#include "smaug/operators/inner_product_op.h"
#include "smaug/operators/less_op.h"
#include "smaug/operators/padding_op.h"
#include "smaug/operators/pooling_op.h"
#include "smaug/operators/relu_op.h"
#include "smaug/operators/reorder_op.h"
#include "smaug/operators/repeat_op.h"
#include "smaug/operators/reshape_op.h"
#include "smaug/operators/sigmoid_op.h"
#include "smaug/operators/smv/smv_batch_norm_op.h"
#include "smaug/operators/smv/smv_convolution_op.h"
#include "smaug/operators/smv/smv_eltwise_add_op.h"
#include "smaug/operators/smv/smv_eltwise_mul_op.h"
#include "smaug/operators/smv/smv_elu_op.h"
#include "smaug/operators/smv/smv_greater_op.h"
#include "smaug/operators/smv/smv_inner_product_op.h"
#include "smaug/operators/smv/smv_less_op.h"
#include "smaug/operators/smv/smv_pooling_op.h"
#include "smaug/operators/smv/smv_relu_op.h"
#include "smaug/operators/smv/smv_sigmoid_op.h"
#include "smaug/operators/smv/smv_softmax_op.h"
#include "smaug/operators/smv/smv_tanh_op.h"
#include "smaug/operators/softmax_op.h"
#include "smaug/operators/split_op.h"
#include "smaug/operators/tanh_op.h"
#include "smaug/utility/debug_stream.h"
#include "smaug/utility/utils.h"
using namespace smaug;
using namespace std;

// Translates the activation parameters in the proto into SMAUG's
// ActivationInfo representation.
ActivationInfo getActivationInfo(const ActivationParams& params) {
    ActivationInfo actInfo;
    OpType opType = params.activation();
    switch (opType) {
        case OpType::ReLU:
            actInfo.function = activation_type::RELU;
            break;
        case OpType::LReLU:
            actInfo.function = activation_type::LRELU;
            actInfo.params.slope = params.lrelu_params().slope();
            break;
        case OpType::ELU:
            actInfo.function = activation_type::ELU;
            actInfo.params.alpha = params.elu_params().alpha();
            break;
        case OpType::SELU:
            actInfo.function = activation_type::SELU;
            actInfo.params.alpha = params.elu_params().alpha();
            actInfo.params.lambda = params.elu_params().lambda_param();
            break;
        case OpType::Tanh:
            actInfo.function = activation_type::TANH;
            break;
        case OpType::HardTanh:
            actInfo.function = activation_type::HARD_TANH;
            actInfo.params.min = params.hard_tanh_params().min();
            actInfo.params.max = params.hard_tanh_params().max();
            break;
        case OpType::Sigmoid:
            actInfo.function = activation_type::SIGMOID;
            break;
        case OpType::Softmax:
            actInfo.function = activation_type::SOFTMAX;
            break;
        default:
            actInfo.function = activation_type::NO_ACTIVATION;
    }
    return actInfo;
}
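
// Creates the SMAUG operator for a single NodeProto and adds it to the
// network, configuring it from the node's parameters.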
template <typename Backend>
static void createAndAddOperator(const NodeProto& node,
                                 const TensorDataArray& tensorDataArray,
                                 HostMemoryAccessPolicy memPolicy,
                                 Network* network,
                                 Workspace* workspace) {
    const std::string& name = node.name();
    OpType type = node.op();

    dout(0) << "Adding " << name << " (" << OpType_Name(type) << ").\n";
    if (type == OpType::Data) {
        // Find the data for this input tensor in the tensor data array.
        TensorData tensorData;
        for (int i = 0; i < tensorDataArray.data_array_size(); i++) {
            if (tensorDataArray.data_array(i).name() ==
                node.input_tensors(0).name()) {
                tensorData = tensorDataArray.data_array(i);
                break;
            }
        }
        auto inputTensor = workspace->addTensor(
                new Tensor(node.input_tensors(0), tensorData));
        auto inputTensorOp = Backend::createDataOp(name, workspace);
        inputTensorOp->setData(inputTensor);
        network->addOperator(inputTensorOp);
    } else if (type == OpType::Convolution3d ||
               type == OpType::ConvolutionDepthwise) {
        ConvolutionOp<Backend>* op;
        if (type == OpType::Convolution3d)
            op = Backend::createConvolutionOp(name, workspace);
        else
            op = Backend::createDepthwiseConvolutionOp(name, workspace);
        assert(node.input_tensors_size() == 2);
        const TensorProto& filterTensorProto = node.input_tensors(1);
        const TensorShapeProto& shapeProto = filterTensorProto.shape();
        assert(shapeProto.dims_size() == 4);
        // The filter dims are stored in the order of the proto's layout;
        // pick out the rows, columns, and number of output filters.
        if (shapeProto.layout() == NCHW) {
            op->setWeightDims(
                    shapeProto.dims(2), shapeProto.dims(3), shapeProto.dims(0));
        } else {
            op->setWeightDims(
                    shapeProto.dims(1), shapeProto.dims(2), shapeProto.dims(0));
        }
        const ConvParams& convParams = node.params().conv_params();
        assert(convParams.stride_size() == 2);
        op->setStride(convParams.stride(0), convParams.stride(1));
        op->setPadding(convParams.padding());
        op->setActivation(getActivationInfo(node.params().act_params()));
        network->addOperator(op);
    } else if (type == OpType::MaxPooling || type == OpType::AveragePooling) {
        PoolingOp<Backend>* op;
        if (type == OpType::MaxPooling)
            op = Backend::createMaxPoolingOp(name, workspace);
        else
            op = Backend::createAvgPoolingOp(name, workspace);
        const PoolParams& poolParams = node.params().pool_params();
        assert(poolParams.stride_size() == 2);
        assert(poolParams.pool_size_size() == 2);
        op->setPoolingSize(poolParams.pool_size(0), poolParams.pool_size(1));
        op->setPoolingStride(poolParams.stride(0), poolParams.stride(1));
        network->addOperator(op);
    } else if (type == OpType::InnerProduct) {
        auto op = Backend::createInnerProductOp(name, workspace);
        assert(node.input_tensors_size() == 2);
        const TensorProto& weightTensorProto = node.input_tensors(1);
        // The number of outputs is the N dimension of the weight matrix.
        if (weightTensorProto.shape().layout() == NC)
            op->setNumOutputs(weightTensorProto.shape().dims(0));
        else
            op->setNumOutputs(weightTensorProto.shape().dims(1));
        op->setActivation(getActivationInfo(node.params().act_params()));
        network->addOperator(op);
    } else if (type == OpType::Reorder) {
        DataLayout srcLayout = node.input_tensors(0).shape().layout();
        DataLayout targetLayout = node.output_tensors(0).shape().layout();
        ReorderOp<Backend>* op;
        // Reordering a 4D tensor into a 2D layout is a flatten.
        if (node.input_tensors(0).shape().dims_size() == 4 &&
            (targetLayout == NC || targetLayout == CN)) {
            op = Backend::createFlattenOp(name, workspace);
        } else {
            op = Backend::createReorderOp(name, workspace);
            op->setTargetLayout(targetLayout);
        }
        network->addOperator(op);
    } else if (type == OpType::Concat) {
        auto op = Backend::createConcatOp(name, workspace);
        op->setNumInputs(node.input_tensors_size());
        op->setConcatAxis(node.params().concat_params().concat_axis());
        network->addOperator(op);
    } else if (type == OpType::Split) {
        auto op = Backend::createSplitOp(name, workspace);
        int axis = node.params().split_params().split_axis();
        std::vector<int> splits;
        for (const auto& tensor : node.output_tensors())
            splits.push_back(tensor.shape().dims(axis));
        op->setSplits(splits);
        op->setSplitAxis(axis);
        network->addOperator(op);
    } else if (type == OpType::Reshape) {
        auto op = Backend::createReshapeOp(name, workspace);
        const TensorShapeProto& shapeProto = node.output_tensors(0).shape();
        std::vector<int> shape(
                shapeProto.dims().begin(), shapeProto.dims().end());
        DataLayout layout = shapeProto.layout();
        op->setShape(shape, layout);
        network->addOperator(op);
    } else if (type == OpType::Repeat) {
        auto op = Backend::createRepeatOp(name, workspace);
        const TensorShapeProto& inputShape = node.input_tensors(0).shape();
        const TensorShapeProto& outputShape = node.output_tensors(0).shape();
        // The multiple along each dim is the ratio of output to input size.
        std::vector<int> multiples;
        for (int i = 0; i < inputShape.dims_size(); i++)
            multiples.push_back(outputShape.dims(i) / inputShape.dims(i));
        op->setMultiples(multiples);
        network->addOperator(op);
    } else if (type == OpType::BatchNorm) {
        auto op = Backend::createBatchNormOp(name, workspace);
        op->setActivation(getActivationInfo(node.params().act_params()));
        network->addOperator(op);
    } else if (type == OpType::EltwiseAdd) {
        auto op = Backend::createEltwiseAddOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::EltwiseMul) {
        auto op = Backend::createEltwiseMulOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Less) {
        auto op = Backend::createLessOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::LessEqual) {
        auto op = Backend::createLessEqualOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Greater) {
        auto op = Backend::createGreaterOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::GreaterEqual) {
        auto op = Backend::createGreaterEqualOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Switch) {
        auto op = Backend::createSwitchOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Merge) {
        auto op = Backend::createMergeOp(name, workspace);
        op->setNumInputs(node.input_tensors_size());
        network->addOperator(op);
    } else if (type == OpType::ReLU) {
        auto op = Backend::createReluOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::LReLU) {
        auto op = Backend::createReluOp(name, workspace);
        // LReLU is a ReLU with a nonzero negative slope; use 0.1 by default.
        op->setSlope(0.1);
        network->addOperator(op);
    } else if (type == OpType::ELU) {
        auto op = Backend::createEluOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::SELU) {
        auto op = Backend::createSeluOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Sigmoid) {
        auto op = Backend::createSigmoidOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Softmax) {
        auto op = Backend::createSoftmaxOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Tanh) {
        auto op = Backend::createTanhOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::Padding) {
        auto op = Backend::createPaddingOp(name, workspace);
        op->setPaddingSize(node.params().padding_params().padding_size());
        network->addOperator(op);
    } else if (type == OpType::HardTanh) {
        auto op = Backend::createHardTanhOp(name, workspace);
        network->addOperator(op);
    } else if (type == OpType::UnknownOp) {
        assert(false && "Invalid operator type!");
    }
    Operator* op = network->getOperator(name);
    if (op->isSamplingSupported())
        op->setSamplingInfo(network->getSamplingInfo());
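
    // Set how the operator accesses host memory, per the model's policy.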
    if (memPolicy == HostMemoryAccessPolicy::AllDma) {
        op->setInputsMemType(MemoryType::dma);
        op->setWeightsMemType(MemoryType::dma);
        op->setOutputsMemType(MemoryType::dma);
    } else if (memPolicy == HostMemoryAccessPolicy::AllAcp) {
        op->setInputsMemType(MemoryType::acp);
        op->setWeightsMemType(MemoryType::acp);
        op->setOutputsMemType(MemoryType::acp);
    } else if (memPolicy == HostMemoryAccessPolicy::AllAcpWithDmaForWeights) {
        op->setInputsMemType(MemoryType::acp);
        op->setWeightsMemType(MemoryType::dma);
        op->setOutputsMemType(MemoryType::acp);
    } else if (memPolicy == HostMemoryAccessPolicy::UnknownMemoryPolicy) {
        assert(false && "Invalid host memory access policy!");
    }
    for (int i = 0; i < op->getOutputs().size(); i++) {
        if (!op->getOutput(i)) {
            const TensorProto& tensorProto = node.output_tensors(i);
            Tensor* output = workspace->addTensor(
                    new Tensor(tensorProto.name(), tensorProto.shape()));
            output->allocateStorage(tensorProto.data_type());
            op->setOutput(output, i);
        }
    }
}
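
// Builds a Network from the parsed GraphProto: first creates all the
// operators, then connects them according to the graph's edges.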
template <typename Backend>
static Network* createNetworkFromProto(const GraphProto& graphProto,
                                       const TensorDataArray& tensorDataArray,
                                       SamplingInfo& sampling,
                                       Workspace* workspace) {
    Network* network = new Network(graphProto.name());
    network->setSamplingInfo(sampling);
    for (int i = 0; i < graphProto.nodes_size(); i++) {
        const NodeProto& node = graphProto.nodes(i);
        createAndAddOperator<Backend>(node,
                                      tensorDataArray,
                                      graphProto.mem_policy(),
                                      network,
                                      workspace);
    }
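
    // Add the edges between operators. This is a second pass over the graph
    // proto so that every operator named as a parent already exists.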
    for (int i = 0; i < graphProto.nodes_size(); i++) {
        const NodeProto& node = graphProto.nodes(i);
        Operator* op = network->getOperator(node.name());
        for (int j = 0; j < node.parents_size(); j++) {
            std::string inputOpName = node.parents(j);
            int srcTensorIdx = node.src_tensors_indices(j);
            Operator* inputOp = network->getOperator(inputOpName);
            network->addEdge(inputOp, op, { srcTensorIdx, j });
        }
    }
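
    // Walk the graph in topological order and connect each operator's output
    // tensors to its children's inputs.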
    const Graph& graph = network->getGraph();
    EdgeNameMap edges = get(boost::edge_name, graph);
    std::list<Vertex> vertices;
    boost::topological_sort(graph, std::front_inserter(vertices));
    for (auto v : vertices) {
        Operator* op = get(boost::vertex_op, graph, v);
        const std::vector<TensorBase*>& outputs = op->getOutputs();
        out_edge_iter outEdgeIt, outEdgeEnd;
        for (boost::tie(outEdgeIt, outEdgeEnd) = out_edges(v, graph);
             outEdgeIt != outEdgeEnd;
             ++outEdgeIt) {
            Vertex childVertex = target(*outEdgeIt, graph);
            Operator* child = get(boost::vertex_op, graph, childVertex);
            auto indices = edges[*outEdgeIt];
            child->setInput(op->getOutput(indices.srcIdx), indices.destIdx);
        }
    }
    return network;
}
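
// Reads the network topology (protobuf text format) and the network
// parameters (protobuf binary format) from the given files and builds the
// network for the backend named in the model.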
Network* smaug::buildNetwork(const std::string& modelTopo,
                             const std::string& modelParams,
                             SamplingInfo& sampling,
                             Workspace* workspace) {
    // Parse the network topology from the protobuf text file.
    GraphProto graph;
    int modelTopoDescriptor = open(modelTopo.c_str(), O_RDONLY);
    if (modelTopoDescriptor < 0) {
        cout << modelTopo << ": network topology file not found." << endl;
        exit(1);
    }
    google::protobuf::io::FileInputStream modelTopoInput(modelTopoDescriptor);
    if (!google::protobuf::TextFormat::Parse(&modelTopoInput, &graph)) {
        cout << "Failed to parse the network topology file!" << endl;
        exit(1);
    }
    // Parse the network parameters from the protobuf binary file.
    TensorDataArray tensorDataArray;
    fstream modelParamsFile(modelParams, ios::in | ios::binary);
    if (!modelParamsFile) {
        cout << modelParams << ": network parameters file not found." << endl;
        exit(1);
    } else if (!tensorDataArray.ParseFromIstream(&modelParamsFile)) {
        cout << "Failed to parse the network parameters file.\n";
        exit(1);
    }
    cout << "======================================================\n";
    cout << "      Loading the network model...\n";
    cout << "======================================================\n";
    Network* network = nullptr;
    if (graph.backend() == ReferenceBackend::Name) {
        network = createNetworkFromProto<ReferenceBackend>(
                graph, tensorDataArray, sampling, workspace);
    } else if (graph.backend() == SmvBackend::Name) {
        network = createNetworkFromProto<SmvBackend>(
                graph, tensorDataArray, sampling, workspace);
    } else {
        assert(false && "Unknown backend!");
    }
    cout << "======================================================\n";
    cout << "      Summary of the network.\n";
    cout << "======================================================\n";
    network->printSummary();
    return network;
}
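
// Example usage (a hypothetical sketch, not part of this file): building a
// network from a topology/parameters file pair. The file names and the
// SamplingInfo field values below are illustrative assumptions, not
// confirmed API details.
//
//     #include "smaug/core/network_builder.h"
//
//     smaug::SamplingInfo sampling;
//     sampling.level = smaug::NoSampling;   // assumed enumerator
//     sampling.num_sample_iterations = 1;   // assumed field name
//     auto* workspace = new smaug::Workspace();
//     smaug::Network* network = smaug::buildNetwork(
//             "model_topo.pbtxt", "model_params.pb", sampling, workspace);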