blob: 87869c984105c35166d2323e0b5916b8e5780c0e [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly386ff1a2021-03-29 15:04:50 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"

#include <armnn/backends/IWorkload.hpp>
#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>
#include <iterator>
#include <mutex>
#include <vector>
telsoa014fcda012018-03-09 14:13:49 +000016
17namespace armnn
18{
19
telsoa014fcda012018-03-09 14:13:49 +000020// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
21// in the various workload factories.
22// There should never be an instantiation of a NullWorkload.
23class NullWorkload : public IWorkload
24{
25 NullWorkload()=delete;
26};
27
28template <typename QueueDescriptor>
29class BaseWorkload : public IWorkload
30{
31public:
32
33 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000034 : m_Data(descriptor),
Sadik Armagan3184c902020-03-18 10:57:30 +000035 m_Guid(profiling::ProfilingService::GetNextGuid())
telsoa014fcda012018-03-09 14:13:49 +000036 {
37 m_Data.Validate(info);
38 }
39
Finn Williamsb76eaed2021-03-31 16:22:40 +010040 void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
41 {
42 ARMNN_LOG(info) << "Using default async workload execution, this will network affect performance";
43 std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
44
45 m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
46 m_Data.m_Outputs = workingMemDescriptor.m_Outputs;
47
48 Execute();
49 };
Mike Kelly386ff1a2021-03-29 15:04:50 +010050
Derek Lambertif30f7d32019-04-09 10:25:02 +010051 void PostAllocationConfigure() override {}
52
telsoa014fcda012018-03-09 14:13:49 +000053 const QueueDescriptor& GetData() const { return m_Data; }
54
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000055 profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
56
telsoa014fcda012018-03-09 14:13:49 +000057protected:
Finn Williamsb76eaed2021-03-31 16:22:40 +010058 QueueDescriptor m_Data;
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000059 const profiling::ProfilingGuid m_Guid;
Finn Williamsb76eaed2021-03-31 16:22:40 +010060
61private:
62 std::mutex m_AsyncWorkloadMutex;
telsoa014fcda012018-03-09 14:13:49 +000063};
64
telsoa01c577f2c2018-08-31 09:22:23 +010065// TypedWorkload used
66template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000067class TypedWorkload : public BaseWorkload<QueueDescriptor>
68{
69public:
70
71 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
72 : BaseWorkload<QueueDescriptor>(descriptor, info)
73 {
telsoa01c577f2c2018-08-31 09:22:23 +010074 std::vector<armnn::DataType> dataTypes = {DataTypes...};
75 armnn::DataType expectedInputType;
76
77 if (!info.m_InputTensorInfos.empty())
78 {
79 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
80
81 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
82 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010083 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
telsoa01c577f2c2018-08-31 09:22:23 +010084 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010085 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
telsoa01c577f2c2018-08-31 09:22:23 +010086 info.m_InputTensorInfos.end(),
87 [&](auto it){
88 return it.GetDataType() == expectedInputType;
89 }),
90 "Trying to create workload with incorrect type");
91 }
92 armnn::DataType expectedOutputType;
93
94 if (!info.m_OutputTensorInfos.empty())
95 {
96 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
97
98 if (!info.m_InputTensorInfos.empty())
99 {
100 if (expectedOutputType != expectedInputType)
101 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100102 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
telsoa01c577f2c2018-08-31 09:22:23 +0100103 }
104 }
105 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
106 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100107 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
telsoa01c577f2c2018-08-31 09:22:23 +0100108 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100109 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
telsoa01c577f2c2018-08-31 09:22:23 +0100110 info.m_OutputTensorInfos.end(),
111 [&](auto it){
112 return it.GetDataType() == expectedOutputType;
113 }),
114 "Trying to create workload with incorrect type");
115 }
116 }
117};
118
119template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
120class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
121{
122public:
123
124 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
125 : BaseWorkload<QueueDescriptor>(descriptor, info)
126 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100127 ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
telsoa014fcda012018-03-09 14:13:49 +0000128 info.m_InputTensorInfos.end(),
129 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100130 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000131 }),
132 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000133
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100134 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
telsoa014fcda012018-03-09 14:13:49 +0000135 info.m_OutputTensorInfos.end(),
136 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100137 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000138 }),
139 "Trying to create workload with incorrect type");
140 }
telsoa014fcda012018-03-09 14:13:49 +0000141};
142
narpra014951d842019-01-18 16:53:53 +0000143// FirstInputTypedWorkload used to check type of the first input
144template <typename QueueDescriptor, armnn::DataType DataType>
145class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
146{
147public:
148
149 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
150 : BaseWorkload<QueueDescriptor>(descriptor, info)
151 {
152 if (!info.m_InputTensorInfos.empty())
153 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100154 ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
narpra014951d842019-01-18 16:53:53 +0000155 "Trying to create workload with incorrect type");
156 }
157
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100158 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
narpra014951d842019-01-18 16:53:53 +0000159 info.m_OutputTensorInfos.end(),
160 [&](auto it){
161 return it.GetDataType() == DataType;
162 }),
163 "Trying to create workload with incorrect type");
164 }
165};
166
telsoa014fcda012018-03-09 14:13:49 +0000167template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100168using FloatWorkload = TypedWorkload<QueueDescriptor,
169 armnn::DataType::Float16,
170 armnn::DataType::Float32>;
171
172template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000173using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
174
175template <typename QueueDescriptor>
Derek Lambertif90c56d2020-01-10 17:14:08 +0000176using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
telsoa014fcda012018-03-09 14:13:49 +0000177
telsoa01c577f2c2018-08-31 09:22:23 +0100178template <typename QueueDescriptor>
narpra01db2b1602019-01-23 15:23:11 +0000179using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
180
181template <typename QueueDescriptor>
kevmay012b4d88e2019-01-24 14:05:09 +0000182using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
183
184template <typename QueueDescriptor>
185using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
186 armnn::DataType::Float32,
187 armnn::DataType::Boolean>;
188
189template <typename QueueDescriptor>
190using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000191 armnn::DataType::QAsymmU8,
kevmay012b4d88e2019-01-24 14:05:09 +0000192 armnn::DataType::Boolean>;
193
194template <typename QueueDescriptor>
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000195using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
196 armnn::DataType::BFloat16,
197 armnn::DataType::Float32>;
198
199template <typename QueueDescriptor>
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000200using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
201 armnn::DataType::Float32,
202 armnn::DataType::BFloat16>;
203
204template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100205using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
206 armnn::DataType::Float16,
207 armnn::DataType::Float32>;
208
209template <typename QueueDescriptor>
210using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
211 armnn::DataType::Float32,
212 armnn::DataType::Float16>;
213
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000214template <typename QueueDescriptor>
215using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000216 armnn::DataType::QAsymmU8,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000217 armnn::DataType::Float32>;
218
telsoa014fcda012018-03-09 14:13:49 +0000219} //namespace armnn