blob: ffdbf6f49b3d79ad4e22a11dd8d54562502b7b76 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefPreCompiledWorkload.hpp"
7
8namespace armnn
9{
10
11TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12 const WorkloadInfo& info)
13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000014 , m_workloadInfo(info)
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000016 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010022}
23
24void TosaRefPreCompiledWorkload::Execute() const
25{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000026 uint32_t numInputBuffers = static_cast<uint32_t>(m_Data.m_Inputs.size());
27 uint32_t numOutputBuffers = static_cast<uint32_t>(m_Data.m_Outputs.size());
28
29 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
30
31 std::vector<std::string> input_names = handler->GetInputs();
32 std::vector<std::string> output_names = handler->GetOutputs();
33
34 TosaReference::IModelRunner runner;
35 GraphStatus status;
36
37 // Initialise the model runner with the TosaSerializationHandler
38 status = runner.initialize(*handler);
39 if(status != GraphStatus::TOSA_VALID)
40 {
41 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
42 }
43
44 // Set the inputs
45 for (uint32_t inputSlotIdx = 0; inputSlotIdx < numInputBuffers; ++inputSlotIdx)
46 {
47 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
48 switch (dataType)
49 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000050 case DataType::Float16:
51 SetInput<half_float::half>(runner, input_names[inputSlotIdx], inputSlotIdx);
52 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000053 case DataType::Float32:
54 SetInput<float>(runner, input_names[inputSlotIdx], inputSlotIdx);
55 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000056 case DataType::QAsymmU8:
57 case DataType::QAsymmS8:
58 case DataType::QSymmS8:
59 case DataType::QSymmS16:
60 case DataType::Signed32:
61 SetInput<int32_t>(runner, input_names[inputSlotIdx], inputSlotIdx);
62 break;
63 case DataType::Signed64:
64 SetInput<int64_t>(runner, input_names[inputSlotIdx], inputSlotIdx);
65 break;
66 case DataType::Boolean:
67 SetInput<unsigned char>(runner, input_names[inputSlotIdx], inputSlotIdx);
68 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000069 default:
70 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
71 }
72 }
73
74 // Run the TOSA Reference Model
75 status = runner.run();
76 if(status != GraphStatus::TOSA_VALID)
77 {
78 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
79 }
80
81 // Gets the outputs
82 for (uint32_t outputSlotIdx = 0; outputSlotIdx < numOutputBuffers; ++outputSlotIdx)
83 {
84 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
85 switch (dataType)
86 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000087 case DataType::Float16:
88 GetOutput<half_float::half>(runner, output_names[outputSlotIdx], outputSlotIdx);
89 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000090 case DataType::Float32:
91 GetOutput<float>(runner, output_names[outputSlotIdx], outputSlotIdx);
92 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000093 case DataType::QAsymmU8:
94 case DataType::QAsymmS8:
95 case DataType::QSymmS8:
96 case DataType::QSymmS16:
97 case DataType::Signed32:
98 GetOutput<int32_t>(runner, output_names[outputSlotIdx], outputSlotIdx);
99 break;
100 case DataType::Signed64:
101 GetOutput<int64_t>(runner, output_names[outputSlotIdx], outputSlotIdx);
102 break;
103 case DataType::Boolean:
104 GetOutput<unsigned char>(runner, output_names[outputSlotIdx], outputSlotIdx);
105 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000106 default:
107 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
108 }
109 }
110}
111
112template <typename T>
113void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
114 std::string inputName,
115 uint32_t inputIndex) const
116{
117 std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
118 m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
119
120 runner.setInput<T>(inputName, inputData);
121}
122
123template <typename T>
124void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
125 std::string outputName,
126 uint32_t outputIndex) const
127{
128 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
129
130 m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100131}
132
/// Validation hook for the TOSA reference pre-compiled workload.
/// No conditions are checked: it always reports support, and the string argument is neither read nor written.
bool TosaRefPreCompiledWorkloadValidate(std::string* /*unused*/)
{
    constexpr bool kAlwaysSupported = true;
    return kAlwaysSupported;
}
137
138} //namespace armnn