blob: 8b08f01b23b26029faa8cc55da96d126e0ff6c8b [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "TosaRefPreCompiledWorkload.hpp"
7
8namespace armnn
9{
10
11TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12 const WorkloadInfo& info)
13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000014 , m_workloadInfo(info)
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000016 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010022}
Francis Murtagh9270d9e2022-08-12 13:54:17 +010023void TosaRefPreCompiledWorkload::Execute() const
24{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000025 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
26
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010027 std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28 std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
Matthew Sloyan5c54c382022-11-09 16:28:51 +000029
30 TosaReference::IModelRunner runner;
31 GraphStatus status;
32
33 // Initialise the model runner with the TosaSerializationHandler
34 status = runner.initialize(*handler);
35 if(status != GraphStatus::TOSA_VALID)
36 {
37 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
38 }
39
40 // Set the inputs
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000041 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000042 {
43 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
44 switch (dataType)
45 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000046 case DataType::Float16:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000047 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000048 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000049 case DataType::Float32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000050 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan5c54c382022-11-09 16:28:51 +000051 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000052 case DataType::QAsymmU8:
53 case DataType::QAsymmS8:
54 case DataType::QSymmS8:
55 case DataType::QSymmS16:
56 case DataType::Signed32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000057 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000058 break;
59 case DataType::Signed64:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000060 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000061 break;
62 case DataType::Boolean:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000063 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000064 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000065 default:
66 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
67 }
68 }
69
70 // Run the TOSA Reference Model
71 status = runner.run();
72 if(status != GraphStatus::TOSA_VALID)
73 {
74 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
75 }
76
77 // Gets the outputs
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000078 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000079 {
80 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
81 switch (dataType)
82 {
Matthew Sloyan2523b792022-11-14 10:18:01 +000083 case DataType::Float16:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000084 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000085 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000086 case DataType::Float32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000087 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan5c54c382022-11-09 16:28:51 +000088 break;
Matthew Sloyan2523b792022-11-14 10:18:01 +000089 case DataType::QAsymmU8:
90 case DataType::QAsymmS8:
91 case DataType::QSymmS8:
92 case DataType::QSymmS16:
93 case DataType::Signed32:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000094 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000095 break;
96 case DataType::Signed64:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000097 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +000098 break;
99 case DataType::Boolean:
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000100 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
Matthew Sloyan2523b792022-11-14 10:18:01 +0000101 break;
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000102 default:
103 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
104 }
105 }
106}
107
108template <typename T>
109void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
110 std::string inputName,
111 uint32_t inputIndex) const
112{
113 std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
114 m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
115
116 runner.setInput<T>(inputName, inputData);
117}
118
119template <typename T>
120void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
121 std::string outputName,
122 uint32_t outputIndex) const
123{
124 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
125
126 m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100127}
128
/// Validation hook for the TOSA reference pre-compiled workload. It unconditionally reports success
/// and never writes to the optional reason string it is given.
bool TosaRefPreCompiledWorkloadValidate(std::string* /*reasonIfUnsupported*/)
{
    constexpr bool kAlwaysSupported = true;
    return kAlwaysSupported;
}
133
134} //namespace armnn