blob: ddcc5a8f0a7e713617261441258cc2c49518d312 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "IWorkload.hpp"
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"

#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>
#include <mutex>

17namespace armnn
18{
19
20// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
21// in the various workload factories.
22// There should never be an instantiation of a NullWorkload.
23class NullWorkload : public IWorkload
24{
25 NullWorkload()=delete;
26};
27
28template <typename QueueDescriptor>
29class BaseWorkload : public IWorkload
30{
31public:
32
33 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
34 : m_Data(descriptor),
Cathal Corbett5aa9fd72022-02-25 15:33:28 +000035 m_Guid(arm::pipe::ProfilingService::GetNextGuid())
Colm Donelan0c479742021-12-10 12:43:54 +000036 {
37 m_Data.Validate(info);
38 }
39
40 void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
41 {
42 ARMNN_LOG(info) << "Using default async workload execution, this will network affect performance";
43 std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
44
45 m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
46 m_Data.m_Outputs = workingMemDescriptor.m_Outputs;
47
48 Execute();
49 };
50
51 void PostAllocationConfigure() override {}
52
53 const QueueDescriptor& GetData() const { return m_Data; }
54
Cathal Corbett5aa9fd72022-02-25 15:33:28 +000055 arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }
Colm Donelan0c479742021-12-10 12:43:54 +000056
Finn Williams73c547d2022-02-15 20:47:34 +000057 virtual bool SupportsTensorHandleReplacement() const override
58 {
59 return false;
60 }
61
Teresa Charlin788e2a62022-01-17 21:19:52 +000062 // Replace input tensor handle with the given TensorHandle
63 void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
64 {
Finn Williams73c547d2022-02-15 20:47:34 +000065 armnn::IgnoreUnused(tensorHandle, slot);
66 throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
Teresa Charlin788e2a62022-01-17 21:19:52 +000067 }
68
69 // Replace output tensor handle with the given TensorHandle
70 void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
71 {
Finn Williams73c547d2022-02-15 20:47:34 +000072 armnn::IgnoreUnused(tensorHandle, slot);
73 throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
Teresa Charlin788e2a62022-01-17 21:19:52 +000074 }
75
Colm Donelan0c479742021-12-10 12:43:54 +000076protected:
77 QueueDescriptor m_Data;
Cathal Corbett5aa9fd72022-02-25 15:33:28 +000078 const arm::pipe::ProfilingGuid m_Guid;
Colm Donelan0c479742021-12-10 12:43:54 +000079
80private:
81 std::mutex m_AsyncWorkloadMutex;
82};
83
84// TypedWorkload used
85template <typename QueueDescriptor, armnn::DataType... DataTypes>
86class TypedWorkload : public BaseWorkload<QueueDescriptor>
87{
88public:
89
90 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
91 : BaseWorkload<QueueDescriptor>(descriptor, info)
92 {
93 std::vector<armnn::DataType> dataTypes = {DataTypes...};
94 armnn::DataType expectedInputType;
95
96 if (!info.m_InputTensorInfos.empty())
97 {
98 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
99
100 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
101 {
102 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
103 }
104 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
105 info.m_InputTensorInfos.end(),
106 [&](auto it){
107 return it.GetDataType() == expectedInputType;
108 }),
109 "Trying to create workload with incorrect type");
110 }
111 armnn::DataType expectedOutputType;
112
113 if (!info.m_OutputTensorInfos.empty())
114 {
115 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
116
117 if (!info.m_InputTensorInfos.empty())
118 {
119 if (expectedOutputType != expectedInputType)
120 {
121 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
122 }
123 }
124 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
125 {
126 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
127 }
128 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
129 info.m_OutputTensorInfos.end(),
130 [&](auto it){
131 return it.GetDataType() == expectedOutputType;
132 }),
133 "Trying to create workload with incorrect type");
134 }
135 }
136};
137
138template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
139class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
140{
141public:
142
143 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
144 : BaseWorkload<QueueDescriptor>(descriptor, info)
145 {
146 ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
147 info.m_InputTensorInfos.end(),
148 [&](auto it){
149 return it.GetDataType() == InputDataType;
150 }),
151 "Trying to create workload with incorrect type");
152
153 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
154 info.m_OutputTensorInfos.end(),
155 [&](auto it){
156 return it.GetDataType() == OutputDataType;
157 }),
158 "Trying to create workload with incorrect type");
159 }
160};
161
162// FirstInputTypedWorkload used to check type of the first input
163template <typename QueueDescriptor, armnn::DataType DataType>
164class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
165{
166public:
167
168 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
169 : BaseWorkload<QueueDescriptor>(descriptor, info)
170 {
171 if (!info.m_InputTensorInfos.empty())
172 {
173 ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
174 "Trying to create workload with incorrect type");
175 }
176
177 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
178 info.m_OutputTensorInfos.end(),
179 [&](auto it){
180 return it.GetDataType() == DataType;
181 }),
182 "Trying to create workload with incorrect type");
183 }
184};
185
186template <typename QueueDescriptor>
187using FloatWorkload = TypedWorkload<QueueDescriptor,
188 armnn::DataType::Float16,
189 armnn::DataType::Float32>;
190
191template <typename QueueDescriptor>
192using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
193
194template <typename QueueDescriptor>
195using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
196
197template <typename QueueDescriptor>
198using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
199
200template <typename QueueDescriptor>
201using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
202
203template <typename QueueDescriptor>
204using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
205 armnn::DataType::Float32,
206 armnn::DataType::Boolean>;
207
208template <typename QueueDescriptor>
209using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
210 armnn::DataType::QAsymmU8,
211 armnn::DataType::Boolean>;
212
213template <typename QueueDescriptor>
214using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
215 armnn::DataType::BFloat16,
216 armnn::DataType::Float32>;
217
218template <typename QueueDescriptor>
219using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
220 armnn::DataType::Float32,
221 armnn::DataType::BFloat16>;
222
223template <typename QueueDescriptor>
224using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
225 armnn::DataType::Float16,
226 armnn::DataType::Float32>;
227
228template <typename QueueDescriptor>
229using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
230 armnn::DataType::Float32,
231 armnn::DataType::Float16>;
232
233template <typename QueueDescriptor>
234using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
235 armnn::DataType::QAsymmU8,
236 armnn::DataType::Float32>;
237
238} //namespace armnn