blob: 18d2900efff9dbee148a824b49c8d41ef5a59df2 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefPreCompiledWorkload.hpp"
7
8namespace armnn
9{
10
11TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12 const WorkloadInfo& info)
13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
Matthew Sloyan5c54c382022-11-09 16:28:51 +000014 , m_workloadInfo(info)
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000016 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010022}
23
24void TosaRefPreCompiledWorkload::Execute() const
25{
Matthew Sloyan5c54c382022-11-09 16:28:51 +000026 uint32_t numInputBuffers = static_cast<uint32_t>(m_Data.m_Inputs.size());
27 uint32_t numOutputBuffers = static_cast<uint32_t>(m_Data.m_Outputs.size());
28
29 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
30
31 std::vector<std::string> input_names = handler->GetInputs();
32 std::vector<std::string> output_names = handler->GetOutputs();
33
34 TosaReference::IModelRunner runner;
35 GraphStatus status;
36
37 // Initialise the model runner with the TosaSerializationHandler
38 status = runner.initialize(*handler);
39 if(status != GraphStatus::TOSA_VALID)
40 {
41 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
42 }
43
44 // Set the inputs
45 for (uint32_t inputSlotIdx = 0; inputSlotIdx < numInputBuffers; ++inputSlotIdx)
46 {
47 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
48 switch (dataType)
49 {
50 case DataType::Float32:
51 SetInput<float>(runner, input_names[inputSlotIdx], inputSlotIdx);
52 break;
53 default:
54 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
55 }
56 }
57
58 // Run the TOSA Reference Model
59 status = runner.run();
60 if(status != GraphStatus::TOSA_VALID)
61 {
62 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
63 }
64
65 // Gets the outputs
66 for (uint32_t outputSlotIdx = 0; outputSlotIdx < numOutputBuffers; ++outputSlotIdx)
67 {
68 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
69 switch (dataType)
70 {
71 case DataType::Float32:
72 GetOutput<float>(runner, output_names[outputSlotIdx], outputSlotIdx);
73 break;
74 default:
75 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
76 }
77 }
78}
79
80template <typename T>
81void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
82 std::string inputName,
83 uint32_t inputIndex) const
84{
85 std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
86 m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
87
88 runner.setInput<T>(inputName, inputData);
89}
90
91template <typename T>
92void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
93 std::string outputName,
94 uint32_t outputIndex) const
95{
96 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
97
98 m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
Francis Murtagh9270d9e2022-08-12 13:54:17 +010099}
100
/// Validation hook for the TOSA Reference pre-compiled workload.
///
/// No constraints are checked: every configuration is accepted and the string argument is
/// never read or written.
///
/// @return Always true.
bool TosaRefPreCompiledWorkloadValidate(std::string* /* unused */)
{
    constexpr bool isSupported = true;
    return isSupported;
}
105
106} //namespace armnn