blob: 26bd29cc03c7ad8f0b40673ca1c7e06ffc317803 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
Kevin May1bea6be2023-12-12 11:18:46 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Francis Murtagh9270d9e2022-08-12 13:54:17 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefPreCompiledWorkload.hpp"
7
8namespace armnn
9{
10
11TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12 const WorkloadInfo& info)
13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000014 , m_workloadInfo(info)
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000016 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010022}
Francis Murtagh9270d9e2022-08-12 13:54:17 +010023void TosaRefPreCompiledWorkload::Execute() const
24{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000025 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
26
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010027 std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28 std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
Matthew Sloyan5c54c382022-11-09 16:28:51 +000029
30 TosaReference::IModelRunner runner;
31 GraphStatus status;
32
33 // Initialise the model runner with the TosaSerializationHandler
34 status = runner.initialize(*handler);
35 if(status != GraphStatus::TOSA_VALID)
36 {
37 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
38 }
39
40 // Set the inputs
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000041 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000042 {
43 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
44 switch (dataType)
45 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000046 case DataType::Float16:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000047 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000048 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000049 case DataType::Float32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000050 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan5c54c382022-11-09 16:28:51 +000051 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000052 case DataType::QAsymmU8:
Teresa Charlince655882023-11-21 15:44:13 +000053 SetInput<uint8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
54 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000055 case DataType::QAsymmS8:
56 case DataType::QSymmS8:
Teresa Charlince655882023-11-21 15:44:13 +000057 SetInput<int8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
58 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000059 case DataType::QSymmS16:
Teresa Charlince655882023-11-21 15:44:13 +000060 SetInput<int16_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
61 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000062 case DataType::Signed32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000063 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000064 break;
65 case DataType::Signed64:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000066 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000067 break;
68 case DataType::Boolean:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000069 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000070 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000071 default:
72 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
73 }
74 }
75
76 // Run the TOSA Reference Model
77 status = runner.run();
78 if(status != GraphStatus::TOSA_VALID)
79 {
80 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
81 }
82
83 // Gets the outputs
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000084 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000085 {
86 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
87 switch (dataType)
88 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000089 case DataType::Float16:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000090 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000091 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000092 case DataType::Float32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000093 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan5c54c382022-11-09 16:28:51 +000094 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000095 case DataType::QAsymmU8:
Teresa Charlince655882023-11-21 15:44:13 +000096 GetOutput<uint8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
97 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000098 case DataType::QAsymmS8:
99 case DataType::QSymmS8:
Teresa Charlince655882023-11-21 15:44:13 +0000100 GetOutput<int8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
101 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +0000102 case DataType::QSymmS16:
Teresa Charlince655882023-11-21 15:44:13 +0000103 GetOutput<int16_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
104 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +0000105 case DataType::Signed32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000106 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +0000107 break;
108 case DataType::Signed64:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000109 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +0000110 break;
111 case DataType::Boolean:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000112 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +0000113 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000114 default:
115 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
116 }
117 }
118}
119
120template <typename T>
121void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
122 std::string inputName,
123 uint32_t inputIndex) const
124{
Teresa Charlince655882023-11-21 15:44:13 +0000125 SetInput<T, T>(runner, inputName, inputIndex);
126}
127
128template <typename T, typename Trunner>
129void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
130 std::string inputName,
131 uint32_t inputIndex) const
132{
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000133 std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
Teresa Charlince655882023-11-21 15:44:13 +0000134 std::vector<Trunner> inputDataRunner(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
135
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000136 m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
137
Teresa Charlince655882023-11-21 15:44:13 +0000138 std::transform(inputData.begin(), inputData.end(),
139 inputDataRunner.begin(), [](T x) { return static_cast<Trunner>(x);});
140
141 runner.setInput<Trunner>(inputName, inputDataRunner);
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000142}
143
144template <typename T>
145void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
146 std::string outputName,
147 uint32_t outputIndex) const
148{
Teresa Charlince655882023-11-21 15:44:13 +0000149 GetOutput<T, T>(runner, outputName, outputIndex);
150}
151
152template <typename T, typename Trunner>
153void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
154 std::string outputName,
155 uint32_t outputIndex) const
156{
157 std::vector<Trunner> actualOutputsRunner = runner.getOutput<Trunner>(outputName);
158 std::vector<T> actualOutputs (actualOutputsRunner.size());
159
160 std::transform(actualOutputsRunner.begin(), actualOutputsRunner.end(),
161 actualOutputs.begin(), [](Trunner x) { return static_cast<T>(x);});
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000162
163 m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100164}
165
/// Validation hook for the TOSA reference pre-compiled workload.
/// Always reports success; the string pointer argument is ignored (it may be null).
bool TosaRefPreCompiledWorkloadValidate(std::string* /*unused*/)
{
    constexpr bool isSupported = true;
    return isSupported;
}
170
171} //namespace armnn