blob: eb7180ab5fa13f1de15df15c0f31cddc5934a66e [file] [log] [blame]

// Copyright (c) 2022-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#include "model_runner_impl.h"
17
18using namespace TosaReference;
19
20ModelRunnerImpl::ModelRunnerImpl()
21{}
22
Jerry Ge9c9c8da2023-07-19 23:08:16 +000023ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config, const func_debug_t& func_debug)
Matthew Sloyanba5fad32022-09-26 13:31:43 +010024{
25 g_func_config = func_config;
Jerry Ge9c9c8da2023-07-19 23:08:16 +000026 g_func_debug = func_debug;
Matthew Sloyanba5fad32022-09-26 13:31:43 +010027}
28
29ModelRunnerImpl::~ModelRunnerImpl()
30{
31 g_func_debug.fini_debug();
32 delete _main_gt;
33};
34
35void ModelRunnerImpl::setFuncConfig(func_config_t& func_config)
36{
37 g_func_config = func_config;
38}
39void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug)
40{
41 g_func_debug = func_debug;
42}
43
44GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler)
45{
46 validateTosaVersion(serialization_handler);
Jerry Ge9e94af82022-10-27 09:57:00 -070047 return initialize(serialization_handler.GetMainRegion()->GetBlocks()[0], &serialization_handler);
Grant Watson64285a12022-11-16 15:32:39 +000048}
Matthew Sloyanba5fad32022-09-26 13:31:43 +010049
Grant Watson64285a12022-11-16 15:32:39 +000050GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb)
51{
52 return initialize(&bb, nullptr);
Matthew Sloyanba5fad32022-09-26 13:31:43 +010053}
54
55GraphStatus ModelRunnerImpl::run()
56{
57 if (_main_gt == nullptr)
58 {
59 FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()");
60 }
61
62 if (g_func_config.validate_only)
63 {
64 goto done;
65 }
66
67 // Validate the number of inputs matches the
68 if (static_cast<uint32_t>(_main_gt->getNumInputTensors()) != n_input_tensors)
69 {
70 FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). "
71 "setInput() must be called for each input.",
72 n_input_tensors, _main_gt->getNumInputTensors());
73 }
74
75 if (g_func_config.eval)
76 {
77 // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
78 if (_main_gt->evaluateAll())
79 {
80 ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID,
81 "Upon evaluateAll() returning 1, graph can not be VALID.");
82 }
83 else
84 {
85 ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID ||
86 _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
87 "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
88 }
89
90 // Only generate output tensor if graph is valid.
91 if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID)
92 {
93 // Make sure output tensor is evaluated and show its value
94 int num_output_tensors = _main_gt->getNumOutputTensors();
95 bool all_output_valid = true;
96 for (int i = 0; i < num_output_tensors; i++)
97 {
98 const Tensor* ct = _main_gt->getOutputTensor(i);
99 ASSERT_MEM(ct);
100 if (!ct->getIsValid())
101 {
102 ct->dumpTensorParams(g_func_debug.func_debug_file);
103 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
104 {
105 ct->dumpTensor(g_func_debug.func_debug_file);
106 }
107 all_output_valid = false;
108 }
109 }
110 if (!all_output_valid)
111 {
112 _main_gt->dumpGraph(g_func_debug.func_debug_file);
113 FATAL_ERROR(
114 "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
115 }
116 }
117 }
118
119done:
120 // Print status if not valid and do cleanup.
121 checkGraphStatus(*_main_gt);
122 g_func_debug.fini_debug();
123
124 return _main_gt->getGraphStatus();
125}
126
127template <typename T>
Grant Watson64285a12022-11-16 15:32:39 +0000128int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<T> vals)
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100129{
130 if (_main_gt == nullptr)
131 {
132 FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()");
133 }
134
135 Tensor* tensor;
136 tensor = _main_gt->getInputTensorByName(input_name);
137
138 if (!tensor)
139 {
140 WARNING("Unable to find input tensor %s", input_name.c_str());
141 return 1;
142 }
143
144 if (!tensor->is_allocated())
145 {
146 WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str());
147 return 1;
148 }
149
150 if (tensor->readfromVector(vals))
151 {
152 WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str());
153 return 1;
154 }
155
156 // Push ready consumers to the next node list
157 for (auto gn : tensor->getConsumers())
158 {
159 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
160 {
161 _main_gt->addToNextNodeList(gn);
162 }
163 }
164
165 n_input_tensors++;
166 return 0;
167}
168
Fabrizio Indirli72038352023-12-11 11:15:32 +0000169int ModelRunnerImpl::setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size)
170{
171 ASSERT_MSG(tensor, "Tensor not provided!");
172 if (!g_func_config.precise_mode)
173 {
174 WARNING("Cannot set input tensor %s using precise mode setters when not running in precise mode!",
175 input_name.c_str());
176 return 1;
177 }
178
179 DType ser_dtype = tensor->getSerializationDtype();
180 int status;
181
182 switch (ser_dtype)
183 {
184 case DType::DType_FP16: {
185 auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
186 const int elements = size / sizeof(half_float::half);
187 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
188 break;
189 }
190 case DType::DType_FP32: {
191 auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
192 const int elements = size / sizeof(float);
193 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
194 break;
195 }
196 default:
197 status = 1;
198 }
199
200 return status;
201}
202
Grant Watson64285a12022-11-16 15:32:39 +0000203int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size)
204{
205 if (_main_gt == nullptr)
206 {
207 FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()");
208 }
209
210 Tensor* tensor;
211 tensor = _main_gt->getInputTensorByName(input_name);
212
213 if (!tensor)
214 {
215 WARNING("Unable to find input tensor %s", input_name.c_str());
216 return 1;
217 }
218
219 int status = 0;
220 switch (tensor->getDtype())
221 {
Tai Lya4d748b2023-03-28 22:06:56 +0000222 case TOSA_REF_TYPE_FP16: {
Grant Watson64285a12022-11-16 15:32:39 +0000223 auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
224 const int elements = size / sizeof(half_float::half);
225 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
226 break;
227 }
Tai Lya4d748b2023-03-28 22:06:56 +0000228 case TOSA_REF_TYPE_FP32: {
Grant Watson64285a12022-11-16 15:32:39 +0000229 auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
230 const int elements = size / sizeof(float);
231 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
232 break;
233 }
Fabrizio Indirli72038352023-12-11 11:15:32 +0000234 case TOSA_REF_TYPE_FP64:
235 if (g_func_config.precise_mode)
236 {
237 status = setInputForPrecMode(tensor, input_name, raw_ptr, size);
238 }
239 else
240 {
241 auto typed_ptr = reinterpret_cast<double*>(raw_ptr);
242 const int elements = size / sizeof(double);
243 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
244 }
245 break;
Jerry Gec5291692024-01-02 22:29:08 +0000246 case TOSA_REF_TYPE_INT8: {
247 auto typed_ptr = reinterpret_cast<int8_t*>(raw_ptr);
248 const int elements = size / sizeof(int8_t);
249 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
250 break;
251 }
Georgios Pinitase9059772023-12-06 18:52:30 +0000252 case TOSA_REF_TYPE_INT16: {
253 auto typed_ptr = reinterpret_cast<int16_t*>(raw_ptr);
254 const int elements = size / sizeof(int16_t);
255 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
256 break;
257 }
Jiacheng Liange7c7cab2023-07-14 12:43:46 +0100258 case TOSA_REF_TYPE_INT32: {
259 auto typed_ptr = reinterpret_cast<int*>(raw_ptr);
260 const int elements = size / sizeof(int);
261 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
262 break;
263 }
Jack Franklandc48590e2023-10-17 17:01:07 +0100264 case TOSA_REF_TYPE_BOOL: {
265 auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr);
266 const int elements = size / sizeof(unsigned char);
267 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
268 break;
269 }
Dmitrii Agibov455e8702024-01-29 15:39:52 +0000270 case TOSA_REF_TYPE_SHAPE: {
271 auto typed_ptr = reinterpret_cast<int64_t*>(raw_ptr);
272 const int elements = size / sizeof(int64_t);
273 status = setInput(input_name, ArrayProxy(elements, typed_ptr));
274 break;
275 }
Grant Watson64285a12022-11-16 15:32:39 +0000276 default:
277 status = 1;
278 }
279
280 return status;
281}
282
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100283template <typename T>
284std::vector<T> ModelRunnerImpl::getOutput(std::string output_name)
285{
286 if (_main_gt == nullptr)
287 {
288 FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()");
289 }
290
291 Tensor* tensor;
292 tensor = _main_gt->getOutputTensorByName(output_name);
293
294 if (!tensor)
295 {
296 WARNING("Unable to find output tensor %s", output_name.c_str());
297 return std::vector<T>();
298 }
299
Matthew Sloyan2e4d8892022-10-18 18:02:48 +0100300 std::vector<T> outputs(tensor->getElementCount());
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100301
Grant Watson64285a12022-11-16 15:32:39 +0000302 if (tensor->writeToVector(ArrayProxy<T>(outputs)))
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100303 {
304 WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str());
305 return std::vector<T>();
306 }
307
308 return outputs;
309}
310
Grant Watson64285a12022-11-16 15:32:39 +0000311int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t size)
312{
313 if (_main_gt == nullptr)
314 {
315 FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()");
316 }
317
318 Tensor* tensor;
319 tensor = _main_gt->getOutputTensorByName(output_name);
320
321 if (!tensor)
322 {
323 WARNING("Unable to find output tensor %s", output_name.c_str());
324 return 1;
325 }
326
327 int status = 0;
328 switch (tensor->getDtype())
329 {
Tai Lya4d748b2023-03-28 22:06:56 +0000330 case TOSA_REF_TYPE_FP16: {
Grant Watson64285a12022-11-16 15:32:39 +0000331 auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
332 const int elements = size / sizeof(half_float::half);
333 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
334 break;
335 }
Tai Lya4d748b2023-03-28 22:06:56 +0000336 case TOSA_REF_TYPE_FP32: {
Grant Watson64285a12022-11-16 15:32:39 +0000337 auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
338 const int elements = size / sizeof(float);
339 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
340 break;
341 }
Fabrizio Indirli72038352023-12-11 11:15:32 +0000342 case TOSA_REF_TYPE_FP64: {
343 auto typed_ptr = reinterpret_cast<double*>(raw_ptr);
344 const int elements = size / sizeof(double);
345 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
346 break;
347 }
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100348 case TOSA_REF_TYPE_BOOL: {
349 auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr);
350 const int elements = size / sizeof(unsigned char);
351 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
352 break;
353 }
Jerry Gec5291692024-01-02 22:29:08 +0000354 case TOSA_REF_TYPE_INT8: {
355 auto typed_ptr = reinterpret_cast<int8_t*>(raw_ptr);
356 const int elements = size / sizeof(int8_t);
357 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
358 break;
359 }
Georgios Pinitase9059772023-12-06 18:52:30 +0000360 case TOSA_REF_TYPE_INT16: {
361 auto typed_ptr = reinterpret_cast<int16_t*>(raw_ptr);
362 const int elements = size / sizeof(int16_t);
363 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
364 break;
365 }
Jiacheng Liange7c7cab2023-07-14 12:43:46 +0100366 case TOSA_REF_TYPE_INT32: {
367 auto typed_ptr = reinterpret_cast<int*>(raw_ptr);
368 const int elements = size / sizeof(int);
369 status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
370 break;
371 }
Grant Watson64285a12022-11-16 15:32:39 +0000372 default:
373 status = 1;
374 }
375 if (status)
376 {
377 WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str());
378 return 1;
379 }
380
381 return 0;
382}
383
384GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb,
385 TosaSerializationHandler* serialization_handler)
386{
387 if (serialization_handler != nullptr)
388 validateTosaVersion(*serialization_handler);
389
390 // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
391 _main_gt = nullptr;
Jerry Ge9e94af82022-10-27 09:57:00 -0700392 _main_gt = new SubgraphTraverser(bb, serialization_handler, nullptr);
Grant Watson64285a12022-11-16 15:32:39 +0000393
394 if (_main_gt == nullptr)
395 {
396 WARNING("An error occurred when generating main graph traverser.");
397 return GraphStatus::TOSA_ERROR;
398 }
399
400 if (_main_gt->initializeGraph())
401 {
402 WARNING("Unable to initialize main graph traverser.");
403 return _main_gt->getGraphStatus();
404 }
405
406 if (_main_gt->linkTensorsAndNodes())
407 {
408 WARNING("Failed to link tensors and nodes");
409 return _main_gt->getGraphStatus();
410 }
411
412 if (_main_gt->validateGraph())
413 {
414 WARNING("Failed to validate graph.");
415 return _main_gt->getGraphStatus();
416 }
417
Jerry Gee5cabbf2023-07-17 21:33:17 +0000418 if (_main_gt->allocateInputTensors())
Grant Watson64285a12022-11-16 15:32:39 +0000419 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000420 WARNING("Failed to allocate input tensors.");
Grant Watson64285a12022-11-16 15:32:39 +0000421 return _main_gt->getGraphStatus();
422 }
423
424 return _main_gt->getGraphStatus();
425}
426
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100427void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler)
428{
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000429 TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, TOSA_REFERENCE_MODEL_VERSION_MINOR,
430 TOSA_REFERENCE_MODEL_VERSION_PATCH, TOSA_REFERENCE_MODEL_VERSION_DRAFT);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100431
Jerry Ge391cc5e2023-08-05 00:23:28 +0000432 TosaVersion::compat_t is_compat = TosaVersion::is_compatible(model_version, serialization_handler.GetVersion());
433
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100434 switch (is_compat)
435 {
436 case TosaVersion::compat_t::COMPLETELY_COMPATIBLE:
437 break;
Jerry Ge391cc5e2023-08-05 00:23:28 +0000438 case TosaVersion::compat_t::BACKWARD_COMPATIBLE:
439 WARNING("Reference model version %s is backward compatible with serializer version %s.",
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100440 model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str());
441 break;
442 case TosaVersion::compat_t::NOT_COMPATIBLE:
443 FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.",
444 model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str());
445 }
446}
447
448void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt)
449{
450 switch (main_gt.getGraphStatus())
451 {
452 case GraphStatus::TOSA_VALID:
453 // Result is valid.
454 break;
455 case GraphStatus::TOSA_UNPREDICTABLE:
456 WARNING("Graph result: UNPREDICTABLE.");
457 break;
458 case GraphStatus::TOSA_ERROR:
459 WARNING("Graph result: ERROR.");
460 break;
461 default:
462 WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus());
463 }
464}
465
466// Template explicit specialization
Fabrizio Indirli72038352023-12-11 11:15:32 +0000467template int ModelRunnerImpl::setInput<double>(std::string input_name, ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000468template int ModelRunnerImpl::setInput<float>(std::string input_name, ArrayProxy<float> vals);
469template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, ArrayProxy<half_float::half> vals);
Jerry Gec5291692024-01-02 22:29:08 +0000470template int ModelRunnerImpl::setInput<int8_t>(std::string input_name, ArrayProxy<int8_t> vals);
471template int ModelRunnerImpl::setInput<int16_t>(std::string input_name, ArrayProxy<int16_t> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000472template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, ArrayProxy<int32_t> vals);
473template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, ArrayProxy<int64_t> vals);
474template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100475
Fabrizio Indirli72038352023-12-11 11:15:32 +0000476template std::vector<double> ModelRunnerImpl::getOutput<double>(std::string output_name);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100477template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name);
Matthew Sloyan2e4d8892022-10-18 18:02:48 +0100478template std::vector<half_float::half> ModelRunnerImpl::getOutput<half_float::half>(std::string output_name);
Jerry Gec5291692024-01-02 22:29:08 +0000479template std::vector<int8_t> ModelRunnerImpl::getOutput<int8_t>(std::string output_name);
480template std::vector<int16_t> ModelRunnerImpl::getOutput<int16_t>(std::string output_name);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100481template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name);
482template std::vector<int64_t> ModelRunnerImpl::getOutput<int64_t>(std::string output_name);
Jack Franklandc48590e2023-10-17 17:01:07 +0100483template std::vector<unsigned char> ModelRunnerImpl::getOutput<unsigned char>(std::string output_name);