blob: 940b878d2fbcd4a832f78ff1bc97e176d11723ac [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly386ff1a2021-03-29 15:04:50 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include "WorkloadData.hpp"
8#include "WorkloadInfo.hpp"
Mike Kelly386ff1a2021-03-29 15:04:50 +01009#include "WorkingMemDescriptor.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010
Narumol Prangnawarat867eba52020-02-03 12:29:56 +000011#include <armnn/backends/IWorkload.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <Profiling.hpp>
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000013#include <ProfilingService.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
telsoa014fcda012018-03-09 14:13:49 +000016
17namespace armnn
18{
19
telsoa014fcda012018-03-09 14:13:49 +000020// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
21// in the various workload factories.
22// There should never be an instantiation of a NullWorkload.
23class NullWorkload : public IWorkload
24{
25 NullWorkload()=delete;
26};
27
28template <typename QueueDescriptor>
29class BaseWorkload : public IWorkload
30{
31public:
32
33 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000034 : m_Data(descriptor),
Sadik Armagan3184c902020-03-18 10:57:30 +000035 m_Guid(profiling::ProfilingService::GetNextGuid())
telsoa014fcda012018-03-09 14:13:49 +000036 {
37 m_Data.Validate(info);
38 }
39
Mike Kelly386ff1a2021-03-29 15:04:50 +010040 void ExecuteAsync(WorkingMemDescriptor&) override {};
41
Derek Lambertif30f7d32019-04-09 10:25:02 +010042 void PostAllocationConfigure() override {}
43
telsoa014fcda012018-03-09 14:13:49 +000044 const QueueDescriptor& GetData() const { return m_Data; }
45
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000046 profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
47
telsoa014fcda012018-03-09 14:13:49 +000048protected:
49 const QueueDescriptor m_Data;
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000050 const profiling::ProfilingGuid m_Guid;
telsoa014fcda012018-03-09 14:13:49 +000051};
52
telsoa01c577f2c2018-08-31 09:22:23 +010053// TypedWorkload used
54template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000055class TypedWorkload : public BaseWorkload<QueueDescriptor>
56{
57public:
58
59 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
60 : BaseWorkload<QueueDescriptor>(descriptor, info)
61 {
telsoa01c577f2c2018-08-31 09:22:23 +010062 std::vector<armnn::DataType> dataTypes = {DataTypes...};
63 armnn::DataType expectedInputType;
64
65 if (!info.m_InputTensorInfos.empty())
66 {
67 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
68
69 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
70 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010071 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
telsoa01c577f2c2018-08-31 09:22:23 +010072 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010073 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
telsoa01c577f2c2018-08-31 09:22:23 +010074 info.m_InputTensorInfos.end(),
75 [&](auto it){
76 return it.GetDataType() == expectedInputType;
77 }),
78 "Trying to create workload with incorrect type");
79 }
80 armnn::DataType expectedOutputType;
81
82 if (!info.m_OutputTensorInfos.empty())
83 {
84 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
85
86 if (!info.m_InputTensorInfos.empty())
87 {
88 if (expectedOutputType != expectedInputType)
89 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010090 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
telsoa01c577f2c2018-08-31 09:22:23 +010091 }
92 }
93 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
94 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010095 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
telsoa01c577f2c2018-08-31 09:22:23 +010096 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010097 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
telsoa01c577f2c2018-08-31 09:22:23 +010098 info.m_OutputTensorInfos.end(),
99 [&](auto it){
100 return it.GetDataType() == expectedOutputType;
101 }),
102 "Trying to create workload with incorrect type");
103 }
104 }
105};
106
107template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
108class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
109{
110public:
111
112 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
113 : BaseWorkload<QueueDescriptor>(descriptor, info)
114 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100115 ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
telsoa014fcda012018-03-09 14:13:49 +0000116 info.m_InputTensorInfos.end(),
117 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100118 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000119 }),
120 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000121
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100122 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
telsoa014fcda012018-03-09 14:13:49 +0000123 info.m_OutputTensorInfos.end(),
124 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100125 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000126 }),
127 "Trying to create workload with incorrect type");
128 }
telsoa014fcda012018-03-09 14:13:49 +0000129};
130
narpra014951d842019-01-18 16:53:53 +0000131// FirstInputTypedWorkload used to check type of the first input
132template <typename QueueDescriptor, armnn::DataType DataType>
133class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
134{
135public:
136
137 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
138 : BaseWorkload<QueueDescriptor>(descriptor, info)
139 {
140 if (!info.m_InputTensorInfos.empty())
141 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100142 ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
narpra014951d842019-01-18 16:53:53 +0000143 "Trying to create workload with incorrect type");
144 }
145
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100146 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
narpra014951d842019-01-18 16:53:53 +0000147 info.m_OutputTensorInfos.end(),
148 [&](auto it){
149 return it.GetDataType() == DataType;
150 }),
151 "Trying to create workload with incorrect type");
152 }
153};
154
telsoa014fcda012018-03-09 14:13:49 +0000155template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100156using FloatWorkload = TypedWorkload<QueueDescriptor,
157 armnn::DataType::Float16,
158 armnn::DataType::Float32>;
159
160template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000161using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
162
163template <typename QueueDescriptor>
Derek Lambertif90c56d2020-01-10 17:14:08 +0000164using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
telsoa014fcda012018-03-09 14:13:49 +0000165
telsoa01c577f2c2018-08-31 09:22:23 +0100166template <typename QueueDescriptor>
narpra01db2b1602019-01-23 15:23:11 +0000167using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
168
169template <typename QueueDescriptor>
kevmay012b4d88e2019-01-24 14:05:09 +0000170using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
171
172template <typename QueueDescriptor>
173using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
174 armnn::DataType::Float32,
175 armnn::DataType::Boolean>;
176
177template <typename QueueDescriptor>
178using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000179 armnn::DataType::QAsymmU8,
kevmay012b4d88e2019-01-24 14:05:09 +0000180 armnn::DataType::Boolean>;
181
182template <typename QueueDescriptor>
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000183using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
184 armnn::DataType::BFloat16,
185 armnn::DataType::Float32>;
186
187template <typename QueueDescriptor>
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000188using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
189 armnn::DataType::Float32,
190 armnn::DataType::BFloat16>;
191
192template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100193using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
194 armnn::DataType::Float16,
195 armnn::DataType::Float32>;
196
197template <typename QueueDescriptor>
198using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
199 armnn::DataType::Float32,
200 armnn::DataType::Float16>;
201
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000202template <typename QueueDescriptor>
203using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000204 armnn::DataType::QAsymmU8,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000205 armnn::DataType::Float32>;
206
telsoa014fcda012018-03-09 14:13:49 +0000207} //namespace armnn