blob: 55a484882508087bf4e87bd356f892d27e636446 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include <stdio.h>
17
Eric Kunzee5e26762020-10-13 16:11:07 -070018#include "model_common.h"
19#include "ops/op_factory.h"
20#include "subgraph_traverser.h"
21#include "tosa_serialization_handler.h"
22#include <Eigen/CXX11/Tensor>
23#include <iostream>
24
Kevin Chengcd79f0e2021-06-03 15:00:34 -070025#include <fstream>
26#include <nlohmann/json.hpp>
27
Eric Kunzee5e26762020-10-13 16:11:07 -070028using namespace TosaReference;
29using namespace tosa;
Kevin Chengcd79f0e2021-06-03 15:00:34 -070030using json = nlohmann::json;
Eric Kunzee5e26762020-10-13 16:11:07 -070031
32// Global instantiation of configuration and debug objects
33func_config_t g_func_config;
34func_debug_t g_func_debug;
35
Kevin Chengcd79f0e2021-06-03 15:00:34 -070036int initTestDesc(json& test_desc);
37int readInputTensors(SubgraphTraverser& gt, json test_desc);
38int writeFinalTensors(SubgraphTraverser& gt, json test_desc);
39int loadGraph(TosaSerializationHandler& tsh, json test_desc);
Eric Kunzee5e26762020-10-13 16:11:07 -070040
41int main(int argc, const char** argv)
42{
43 // Initialize configuration and debug subsystems
44 func_model_init_config();
45 func_model_set_default_config(&g_func_config);
46 func_init_debug(&g_func_debug, 0);
47 TosaSerializationHandler tsh;
48
49 if (func_model_parse_cmd_line(&g_func_config, &g_func_debug, argc, argv))
50 {
51 return 1;
52 }
53
Kevin Chengcd79f0e2021-06-03 15:00:34 -070054 json test_desc;
55
56 // Initialize test descriptor
57 if (initTestDesc(test_desc))
58 {
59 SIMPLE_FATAL_ERROR("Unable to load test json");
60 }
61
62 if (loadGraph(tsh, test_desc))
Eric Kunzee5e26762020-10-13 16:11:07 -070063 {
64 SIMPLE_FATAL_ERROR("Unable to load graph");
65 }
66
Eric Kunzee5e26762020-10-13 16:11:07 -070067 SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
68
69 if (main_gt.initializeGraph())
70 {
Kevin Chengacb550f2021-06-29 15:32:19 -070071 WARNING("Unable to initialize main graph traverser.");
72 goto done;
Eric Kunzee5e26762020-10-13 16:11:07 -070073 }
74
75 if (main_gt.linkTensorsAndNodes())
76 {
77 SIMPLE_FATAL_ERROR("Failed to link tensors and nodes");
78 }
79
80 if (main_gt.validateGraph())
81 {
82 SIMPLE_FATAL_ERROR("Failed to validate graph");
83 }
84
85 if (g_func_config.validate_only)
86 {
87 goto done;
88 }
89
Kevin Chengcd79f0e2021-06-03 15:00:34 -070090 if (readInputTensors(main_gt, test_desc))
Eric Kunzee5e26762020-10-13 16:11:07 -070091 {
92 SIMPLE_FATAL_ERROR("Unable to read input tensors");
93 }
94
95 if (g_func_config.eval)
96 {
97
Kevin Chengacb550f2021-06-29 15:32:19 -070098 // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
Eric Kunzee5e26762020-10-13 16:11:07 -070099 if (main_gt.evaluateAll())
100 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700101 ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID,
102 "Upon evaluateAll() returning 1, graph can not be VALID.");
103 }
104 else
105 {
106 ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID ||
107 main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
108 "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
Eric Kunzee5e26762020-10-13 16:11:07 -0700109 }
110
Kevin Chengacb550f2021-06-29 15:32:19 -0700111 // Only generate output tensor if graph is valid.
112 if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID)
Eric Kunzee5e26762020-10-13 16:11:07 -0700113 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700114 // make sure output tensor is evaluated and show its value
115 int num_output_tensors = main_gt.getNumOutputTensors();
116 bool all_output_valid = true;
117 for (int i = 0; i < num_output_tensors; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700118 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700119 const Tensor* ct = main_gt.getOutputTensor(i);
120 ASSERT_MEM(ct);
121 if (!ct->getIsValid())
Eric Kunzee5e26762020-10-13 16:11:07 -0700122 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700123 ct->dumpTensorParams(g_func_debug.func_debug_file);
124 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
125 {
126 ct->dumpTensor(g_func_debug.func_debug_file);
127 }
128 all_output_valid = false;
Eric Kunzee5e26762020-10-13 16:11:07 -0700129 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700130 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700131 if (!all_output_valid)
Eric Kunzee5e26762020-10-13 16:11:07 -0700132 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700133 main_gt.dumpGraph(g_func_debug.func_debug_file);
134 SIMPLE_FATAL_ERROR(
135 "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
136 }
137
138 if (g_func_config.output_tensors)
139 {
140 if (writeFinalTensors(main_gt, test_desc))
141 {
142 WARNING("Errors encountered in saving output tensors");
143 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 }
145 }
146 }
147
148done:
Kevin Chengacb550f2021-06-29 15:32:19 -0700149 switch (main_gt.getGraphStatus())
150 {
151 case GraphStatus::TOSA_VALID:
152 // Result is valid.
153 break;
154 case GraphStatus::TOSA_UNPREDICTABLE:
155 fprintf(stderr, "Graph result: UNPREDICTABLE.\n");
156 break;
157 case GraphStatus::TOSA_ERROR:
158 fprintf(stderr, "Graph result: ERROR.\n");
159 break;
160 default:
161 fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus());
162 }
163
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 func_fini_debug(&g_func_debug);
165 func_model_config_cleanup();
166
Kevin Chengacb550f2021-06-29 15:32:19 -0700167 return (int)main_gt.getGraphStatus();
Eric Kunzee5e26762020-10-13 16:11:07 -0700168}
169
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700170int loadGraph(TosaSerializationHandler& tsh, json test_desc)
Eric Kunzee5e26762020-10-13 16:11:07 -0700171{
172 char graph_fullname[1024];
173
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700174 snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.flatbuffer_dir,
175 test_desc["tosa_file"].get<std::string>().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700176
177 if (strlen(graph_fullname) <= 2)
178 {
179 func_model_print_help(stderr);
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700180 SIMPLE_FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc=");
Eric Kunzee5e26762020-10-13 16:11:07 -0700181 }
182
183 const char JSON_EXT[] = ".json";
184 int is_json = 0;
185 {
186 // look for JSON file extension
187 size_t suffix_len = strlen(JSON_EXT);
188 size_t str_len = strlen(graph_fullname);
189
190 if (str_len > suffix_len && strncasecmp(graph_fullname + (str_len - suffix_len), JSON_EXT, suffix_len) == 0)
191 {
192 is_json = 1;
193 }
194 }
195
196 if (is_json)
197 {
198 if (tsh.LoadFileSchema(g_func_config.operator_fbs))
199 {
200 SIMPLE_FATAL_ERROR(
201 "\nJSON file detected. Unable to load TOSA flatbuffer schema from: %s\nCheck -Coperator_fbs=",
202 g_func_config.operator_fbs);
203 }
204
205 if (tsh.LoadFileJson(graph_fullname))
206 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700207 SIMPLE_FATAL_ERROR(
208 "\nError loading JSON graph file: %s\nCheck -Ctest_desc=, -Ctosa_file= and -Cflatbuffer_dir=",
209 graph_fullname);
Eric Kunzee5e26762020-10-13 16:11:07 -0700210 }
211 }
212 else
213 {
214 if (tsh.LoadFileTosaFlatbuffer(graph_fullname))
215 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700216 SIMPLE_FATAL_ERROR(
217 "\nError loading TOSA flatbuffer file: %s\nCheck -Ctest_desc=, -Ctosa_file= and -Cflatbuffer_dir=",
218 graph_fullname);
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 }
220 }
221
222 return 0;
223}
224
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700225int readInputTensors(SubgraphTraverser& gt, json test_desc)
Eric Kunzee5e26762020-10-13 16:11:07 -0700226{
227 int tensorCount = gt.getNumInputTensors();
228 Tensor* tensor;
229 char filename[1024];
230
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700231 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700232 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700233 if ((tensorCount != (int)test_desc["ifm_name"].size()) || (tensorCount != (int)test_desc["ifm_file"].size()))
Eric Kunzee5e26762020-10-13 16:11:07 -0700234 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700235 WARNING("Number of input tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount,
236 test_desc["ifm_name"].size(), test_desc["ifm_file"].size());
Eric Kunzee5e26762020-10-13 16:11:07 -0700237 return 1;
238 }
239
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700240 for (int i = 0; i < tensorCount; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700242 tensor = gt.getInputTensorByName(test_desc["ifm_name"][i].get<std::string>());
243 if (!tensor)
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700245 WARNING("Unable to find input tensor %s", test_desc["ifm_name"][i].get<std::string>().c_str());
246 return 1;
247 }
248
249 snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir,
250 test_desc["ifm_file"][i].get<std::string>().c_str());
251
252 DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename);
253
254 if (tensor->allocate())
255 {
256 WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
257 return 1;
258 }
259
260 if (tensor->readFromNpyFile(filename))
261 {
262 WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename);
263 tensor->dumpTensorParams(g_func_debug.func_debug_file);
264 return 1;
265 }
266
267 // Push ready consumers to the next node list
268 for (auto gn : tensor->getConsumers())
269 {
270 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
271 {
272 gt.addToNextNodeList(gn);
273 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700274 }
275 }
276 }
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700277 catch (nlohmann::json::type_error& e)
278 {
279 WARNING("Fail accessing test descriptor: %s", e.what());
280 return 1;
281 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700282
283 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
284 {
285 gt.dumpNextNodeList(g_func_debug.func_debug_file);
286 }
287
288 return 0;
289}
290
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700291int writeFinalTensors(SubgraphTraverser& gt, json test_desc)
Eric Kunzee5e26762020-10-13 16:11:07 -0700292{
293 int tensorCount = gt.getNumOutputTensors();
294 const Tensor* tensor;
295 char filename[1024];
296
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700297 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700298 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700299 if ((tensorCount != (int)test_desc["ofm_name"].size()) || (tensorCount != (int)test_desc["ofm_file"].size()))
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700301 WARNING("Number of output tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount,
302 test_desc["ofm_name"].size(), test_desc["ofm_file"].size());
Eric Kunzee5e26762020-10-13 16:11:07 -0700303 return 1;
304 }
305
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700306 for (int i = 0; i < tensorCount; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700307 {
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700308 tensor = gt.getOutputTensorByName(test_desc["ofm_name"][i].get<std::string>());
309 if (!tensor)
310 {
311 WARNING("Unable to find output tensor %s", test_desc["ofm_name"][i].get<std::string>().c_str());
312 return 1;
313 }
314
Kevin Chengd5934142021-06-28 16:23:24 -0700315 snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir,
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700316 test_desc["ofm_file"][i].get<std::string>().c_str());
317
318 DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
319
320 if (tensor->writeToNpyFile(filename))
321 {
322 WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
323 return 1;
324 }
325 }
326 }
327 catch (nlohmann::json::type_error& e)
328 {
329 WARNING("Fail accessing test descriptor: %s", e.what());
330 return 1;
331 }
332
333 return 0;
334}
335
// Split "foo,bar,..." on commas and return std::vector({foo, bar, ...}).
// Empty fields are kept: "" -> {""}, "a," -> {"a", ""}.
std::vector<std::string> parseFromString(std::string raw_str)
{
    std::vector<std::string> result;
    std::string::size_type start = 0;

    while (true)
    {
        const std::string::size_type end = raw_str.find(',', start);
        if (end == std::string::npos)
        {
            // Last field: everything after the final comma (possibly empty).
            result.push_back(raw_str.substr(start));
            break;
        }

        // substr() takes (pos, count), not (pos, end); passing `end` as the count would swallow
        // the following fields into every field after the first.
        result.push_back(raw_str.substr(start, end - start));
        start = end + 1;    // skip comma
    }

    return result;
}
358}
359
360int initTestDesc(json& test_desc)
361{
362 std::ifstream ifs(g_func_config.test_desc);
363
364 if (ifs.good())
365 {
366 try
367 {
368 test_desc = nlohmann::json::parse(ifs);
369 }
370 catch (nlohmann::json::parse_error& e)
371 {
372 WARNING("Error parsing test descriptor json: %s", e.what());
Eric Kunzee5e26762020-10-13 16:11:07 -0700373 return 1;
374 }
375 }
376
Kevin Chengd5934142021-06-28 16:23:24 -0700377 // Overwrite flatbuffer_dir/output_dir with dirname(g_func_config.test_desc) if it's not specified.
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700378 std::string flatbuffer_dir_str(g_func_config.flatbuffer_dir);
Kevin Chengd5934142021-06-28 16:23:24 -0700379 std::string output_dir_str(g_func_config.output_dir);
380 if (flatbuffer_dir_str.empty() || output_dir_str.empty())
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700381 {
382 std::string test_path(g_func_config.test_desc);
383 std::string test_dir = test_path.substr(0, test_path.find_last_of("/\\"));
Kevin Chengd5934142021-06-28 16:23:24 -0700384 if (flatbuffer_dir_str.empty())
385 {
386 strncpy(g_func_config.flatbuffer_dir, test_dir.c_str(), FOF_STR_LEN);
387 }
388 if (output_dir_str.empty())
389 {
390 strncpy(g_func_config.output_dir, test_dir.c_str(), FOF_STR_LEN);
391 }
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700392 }
393
394 // Overwrite test_desc["tosa_file"] if -Ctosa_file= specified.
395 std::string tosa_file_str(g_func_config.tosa_file);
396 if (!tosa_file_str.empty())
397 {
398 test_desc["tosa_file"] = tosa_file_str;
399 }
400
401 // Overwrite test_desc["ifm_name"] if -Cifm_name= specified.
402 std::string ifm_name_str(g_func_config.ifm_name);
403 if (!ifm_name_str.empty())
404 {
405 std::vector<std::string> ifm_name_vec = parseFromString(ifm_name_str);
406 test_desc["ifm_name"] = ifm_name_vec;
407 }
408
409 // Overwrite test_desc["ifm_file"] if -Cifm_file= specified.
410 std::string ifm_file_str(g_func_config.ifm_file);
411 if (!ifm_file_str.empty())
412 {
413 std::vector<std::string> ifm_file_vec = parseFromString(ifm_file_str);
414 test_desc["ifm_file"] = ifm_file_vec;
415 }
416
417 // Overwrite test_desc["ofm_name"] if -Cofm_name= specified.
418 std::string ofm_name_str(g_func_config.ofm_name);
419 if (!ofm_name_str.empty())
420 {
421 std::vector<std::string> ofm_name_vec = parseFromString(ofm_name_str);
422 test_desc["ofm_name"] = ofm_name_vec;
423 }
424
425 // Overwrite test_desc["ofm_file"] if -Cofm_file= specified.
426 std::string ofm_file_str(g_func_config.ofm_file);
427 if (!ofm_file_str.empty())
428 {
429 std::vector<std::string> ofm_file_vec = parseFromString(ofm_file_str);
430 test_desc["ofm_file"] = ofm_file_vec;
431 }
432
Eric Kunzee5e26762020-10-13 16:11:07 -0700433 return 0;
434}