blob: ba353a32c7af0dc7dfab1c53bac6c02e7e0b6342 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefPreCompiledWorkload.hpp"
7
8namespace armnn
9{
10
11TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12 const WorkloadInfo& info)
13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000014 , m_workloadInfo(info)
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000016 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010022}
23
24void TosaRefPreCompiledWorkload::Execute() const
25{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000026 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
27
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000028 std::vector<std::string> inputNames = handler->GetInputs();
29 std::vector<std::string> outputNames = handler->GetOutputs();
Matthew Sloyan5c54c382022-11-09 16:28:51 +000030
31 TosaReference::IModelRunner runner;
32 GraphStatus status;
33
34 // Initialise the model runner with the TosaSerializationHandler
35 status = runner.initialize(*handler);
36 if(status != GraphStatus::TOSA_VALID)
37 {
38 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
39 }
40
41 // Set the inputs
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000042 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000043 {
44 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
45 switch (dataType)
46 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000047 case DataType::Float16:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000048 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000049 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000050 case DataType::Float32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000051 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan5c54c382022-11-09 16:28:51 +000052 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000053 case DataType::QAsymmU8:
54 case DataType::QAsymmS8:
55 case DataType::QSymmS8:
56 case DataType::QSymmS16:
57 case DataType::Signed32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000058 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000059 break;
60 case DataType::Signed64:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000061 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000062 break;
63 case DataType::Boolean:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000064 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000065 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000066 default:
67 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
68 }
69 }
70
71 // Run the TOSA Reference Model
72 status = runner.run();
73 if(status != GraphStatus::TOSA_VALID)
74 {
75 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
76 }
77
78 // Gets the outputs
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000079 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000080 {
81 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
82 switch (dataType)
83 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000084 case DataType::Float16:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000085 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000086 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000087 case DataType::Float32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000088 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan5c54c382022-11-09 16:28:51 +000089 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000090 case DataType::QAsymmU8:
91 case DataType::QAsymmS8:
92 case DataType::QSymmS8:
93 case DataType::QSymmS16:
94 case DataType::Signed32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000095 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000096 break;
97 case DataType::Signed64:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000098 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000099 break;
100 case DataType::Boolean:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000101 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +0000102 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000103 default:
104 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
105 }
106 }
107}
108
109template <typename T>
110void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
111 std::string inputName,
112 uint32_t inputIndex) const
113{
114 std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
115 m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
116
117 runner.setInput<T>(inputName, inputData);
118}
119
120template <typename T>
121void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
122 std::string outputName,
123 uint32_t outputIndex) const
124{
125 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
126
127 m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100128}
129
// Validation hook for the TOSA reference pre-compiled workload.
// It performs no checks: it always reports success and never reads or writes the std::string* argument.
bool TosaRefPreCompiledWorkloadValidate(std::string* /*unused*/)
{
    const bool isSupported = true;
    return isSupported;
}
134
135} //namespace armnn