blob: 0ef44f9d9d9f044615854e6b49fddcce81bd6222 [file] [log] [blame]
Matthew Sloyanba5fad32022-09-26 13:31:43 +01001
// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "model_runner_impl.h"
17
18using namespace TosaReference;
19
// Global instantiation of configuration and debug objects.
// These are the definitions of the two process-wide globals; other translation
// units presumably reach them through extern declarations in a project header
// (NOTE(review): confirm). Their relative order here fixes their dynamic
// initialization order within this file, so do not reorder them casually.
func_config_t g_func_config;
func_debug_t g_func_debug;
23
Jerry Ge9c9c8da2023-07-19 23:08:16 +000024IModelRunner::IModelRunner()
25 : model_runner_impl(new ModelRunnerImpl())
Matthew Sloyanba5fad32022-09-26 13:31:43 +010026{}
27
Jerry Ge9c9c8da2023-07-19 23:08:16 +000028IModelRunner::IModelRunner(const func_config_t& func_config, const func_debug_t& func_debug)
Matthew Sloyanba5fad32022-09-26 13:31:43 +010029 : model_runner_impl(new ModelRunnerImpl(func_config, func_debug))
30{}
31
32IModelRunner::~IModelRunner()
33{}
34
35void IModelRunner::setFuncConfig(func_config_t& func_config)
36{
37 model_runner_impl->setFuncConfig(func_config);
38}
39
40void IModelRunner::setFuncDebug(func_debug_t& func_debug)
41{
42 model_runner_impl->setFuncDebug(func_debug);
43}
44
45GraphStatus IModelRunner::initialize(tosa::TosaSerializationHandler& serialization_handler)
46{
47 return model_runner_impl->initialize(serialization_handler);
48}
49
50GraphStatus IModelRunner::run()
51{
52 return model_runner_impl->run();
53}
54
55template <typename T>
Matthew Sloyan2e4d8892022-10-18 18:02:48 +010056int IModelRunner::setInput(std::string input_name, std::vector<T>& vals)
Matthew Sloyanba5fad32022-09-26 13:31:43 +010057{
Grant Watson64285a12022-11-16 15:32:39 +000058 return model_runner_impl->setInput<T>(input_name, ArrayProxy(vals.size(), vals.data()));
59}
60
61int IModelRunner::setInput(std::string input_name, uint8_t* raw_ptr, size_t size)
62{
63 return model_runner_impl->setInput(input_name, raw_ptr, size);
Matthew Sloyanba5fad32022-09-26 13:31:43 +010064}
65
66template <typename T>
67std::vector<T> IModelRunner::getOutput(std::string output_name)
68{
69 return model_runner_impl->getOutput<T>(output_name);
70}
71
Grant Watson64285a12022-11-16 15:32:39 +000072int IModelRunner::getOutput(std::string output_name, uint8_t* raw_ptr, size_t size)
73{
74 return model_runner_impl->getOutput(output_name, raw_ptr, size);
75}
76
Matthew Sloyanba5fad32022-09-26 13:31:43 +010077// Template explicit specialization
Matthew Sloyan2e4d8892022-10-18 18:02:48 +010078template int IModelRunner::setInput<float>(std::string input_name, std::vector<float>& vals);
79template int IModelRunner::setInput<half_float::half>(std::string input_name, std::vector<half_float::half>& vals);
80template int IModelRunner::setInput<int32_t>(std::string input_name, std::vector<int32_t>& vals);
81template int IModelRunner::setInput<int64_t>(std::string input_name, std::vector<int64_t>& vals);
82template int IModelRunner::setInput<unsigned char>(std::string input_name, std::vector<unsigned char>& vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +010083
84template std::vector<float> IModelRunner::getOutput<float>(std::string output_name);
Matthew Sloyan2e4d8892022-10-18 18:02:48 +010085template std::vector<half_float::half> IModelRunner::getOutput<half_float::half>(std::string output_name);
Matthew Sloyanba5fad32022-09-26 13:31:43 +010086template std::vector<int32_t> IModelRunner::getOutput<int32_t>(std::string output_name);
87template std::vector<int64_t> IModelRunner::getOutput<int64_t>(std::string output_name);
88template std::vector<unsigned char> IModelRunner::getOutput<unsigned char>(std::string output_name);