4 #include <boost/program_options.hpp>
6 #include "core/backend.h"
8 #include "core/scheduler.h"
9 #include "core/network_builder.h"
11 #include "utility/debug_stream.h"
12 #include "utility/utils.h"
13 #include "utility/thread_pool.h"
15 namespace po = boost::program_options;
17 using namespace smaug;
19 int main(
int argc,
char* argv[]) {
20 std::string modelTopo;
21 std::string modelParams;
23 std::string lastOutputFile;
24 bool dumpGraph =
false;
27 std::string samplingLevel =
"no";
32 po::options_description options(
33 "SMAUG Usage: ./smaug model_topo.pbtxt model_params.pb [options]");
36 (
"help,h",
"Display this help message")
37 (
"debug-level", po::value(&debugLevel)->implicit_value(0),
38 "Set the debugging output level. If omitted, all debugging output "
39 "is ignored. If specified without a value, the debug level is set "
41 (
"dump-graph", po::value(&dumpGraph)->implicit_value(
true),
42 "Dump the network in GraphViz format.")
44 "Run the network in gem5 simulation.")
45 (
"print-last-output,p",
46 po::value(&lastOutputFile)->implicit_value(
"stdout"),
47 "Dump the output of the last layer to this file. If specified with "
48 "'proto', the output tensor is serialized to a output.pb file. By "
49 "default, it is printed to stdout.")
51 po::value(&samplingLevel)->implicit_value(
"no"),
52 "Set the sampling level. By default, SMAUG doesn't do any sampling. "
53 "There are five options of sampling: no, low, medium, high and "
54 "very_high. With more sampling, the simulation speed can be greatly "
55 "improved at the expense of accuracy loss.")
58 "Set the number of sample iterations used by every sampling enabled "
59 "entity. By default, the global sample number is set to 1. Larger "
60 "sample number means less sampling.")
63 "The number of accelerators that the backend has. As far as "
64 "simulation goes, if there are multiple accelerators available, "
65 "SMAUG requires the accelerator IDs (configured in the gem5 "
66 "configuration file) to be monotonically incremented by 1.")
68 po::value(&numThreads)->implicit_value(1),
69 "Number of threads in the thread pool.")
70 (
"use-systolic-array",
72 "If the backend contains a systolic array, use it whenever possible.");
75 po::options_description hidden;
76 hidden.add_options()(
"model-topo-file", po::value(&modelTopo),
77 "Model topology protobuf file");
78 hidden.add_options()(
"model-params-file", po::value(&modelParams),
79 "Model parameters protobuf file");
80 po::options_description all, visible;
81 all.add(options).add(hidden);
84 po::positional_options_description p;
85 p.add(
"model-topo-file", 1);
86 p.add(
"model-params-file", 1);
88 po::store(po::command_line_parser(argc, argv)
95 }
catch (po::error& e) {
96 std::cout <<
"ERROR: " << e.what() <<
"\n";
100 if (vm.count(
"help")) {
101 std::cout << visible <<
"\n";
104 if (modelTopo.empty() || modelParams.empty()) {
105 std::cout <<
"The model protobuf files must be specified!\n";
110 std::cout <<
"Model topology file: " << modelTopo <<
"\n";
111 std::cout <<
"Model parameters file: " << modelParams <<
"\n";
113 if (samplingLevel ==
"no") {
114 sampling.
level = NoSampling;
115 }
else if (samplingLevel ==
"low") {
116 sampling.
level = Low;
117 }
else if (samplingLevel ==
"medium") {
118 sampling.
level = Medium;
119 }
else if (samplingLevel ==
"high") {
120 sampling.
level = High;
121 }
else if (samplingLevel ==
"very_high") {
122 sampling.
level = VeryHigh;
124 std::cout <<
"Doesn't support the specified sampling option: "
125 << samplingLevel <<
"\n";
128 if (sampling.
level > NoSampling) {
129 std::cout <<
"Sampling level: " << samplingLevel
130 <<
", number of sample iterations: "
135 std::cout <<
"The number of accelerators exceeds the max number!\n";
140 std::cout <<
"SMAUG requires the accelerator IDs (configured in the "
141 "gem5 configuration file) to be monotonically incremented "
145 if (numThreads != -1) {
146 std::cout <<
"Using a thread pool, size: " << numThreads <<
".\n";
152 buildNetwork(modelTopo, modelParams, sampling, workspace);
153 ReferenceBackend::initGlobals();
154 SmvBackend::initGlobals();
157 network->dumpDataflowGraph();
159 if (!network->validate())
163 Tensor* output = scheduler.runNetwork();
165 if (!lastOutputFile.empty()) {
166 if (lastOutputFile ==
"stdout") {
167 std::cout <<
"Final network output:\n" << *output <<
"\n";
168 }
else if (lastOutputFile ==
"proto") {
170 std::fstream outfile(
"output.pb", std::ios::out | std::ios::trunc |
173 if (!tensorProto->SerializeToOstream(&outfile)) {
174 std::cerr <<
"Failed to serialize the output tensor and write "
175 "it to the given C++ ostream! Did you run out of "
181 std::ofstream outfile(lastOutputFile);
182 outfile <<
"Final network output:\n" << *output <<
"\n";
191 ReferenceBackend::freeGlobals();
192 SmvBackend::freeGlobals();