//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include <armnn/backends/IWorkload.hpp>
#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>
#include <initializer_list>
#include <iterator>
telsoa014fcda012018-03-09 14:13:49 +000015
16namespace armnn
17{
18
telsoa014fcda012018-03-09 14:13:49 +000019// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
20// in the various workload factories.
21// There should never be an instantiation of a NullWorkload.
22class NullWorkload : public IWorkload
23{
24 NullWorkload()=delete;
25};
26
27template <typename QueueDescriptor>
28class BaseWorkload : public IWorkload
29{
30public:
31
32 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000033 : m_Data(descriptor),
Sadik Armagan3184c902020-03-18 10:57:30 +000034 m_Guid(profiling::ProfilingService::GetNextGuid())
telsoa014fcda012018-03-09 14:13:49 +000035 {
36 m_Data.Validate(info);
37 }
38
Derek Lambertif30f7d32019-04-09 10:25:02 +010039 void PostAllocationConfigure() override {}
40
telsoa014fcda012018-03-09 14:13:49 +000041 const QueueDescriptor& GetData() const { return m_Data; }
42
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000043 profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
44
telsoa014fcda012018-03-09 14:13:49 +000045protected:
46 const QueueDescriptor m_Data;
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000047 const profiling::ProfilingGuid m_Guid;
telsoa014fcda012018-03-09 14:13:49 +000048};
49
telsoa01c577f2c2018-08-31 09:22:23 +010050// TypedWorkload used
51template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000052class TypedWorkload : public BaseWorkload<QueueDescriptor>
53{
54public:
55
56 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
57 : BaseWorkload<QueueDescriptor>(descriptor, info)
58 {
telsoa01c577f2c2018-08-31 09:22:23 +010059 std::vector<armnn::DataType> dataTypes = {DataTypes...};
60 armnn::DataType expectedInputType;
61
62 if (!info.m_InputTensorInfos.empty())
63 {
64 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
65
66 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
67 {
68 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
69 }
70 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
71 info.m_InputTensorInfos.end(),
72 [&](auto it){
73 return it.GetDataType() == expectedInputType;
74 }),
75 "Trying to create workload with incorrect type");
76 }
77 armnn::DataType expectedOutputType;
78
79 if (!info.m_OutputTensorInfos.empty())
80 {
81 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
82
83 if (!info.m_InputTensorInfos.empty())
84 {
85 if (expectedOutputType != expectedInputType)
86 {
87 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
88 }
89 }
90 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
91 {
92 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
93 }
94 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
95 info.m_OutputTensorInfos.end(),
96 [&](auto it){
97 return it.GetDataType() == expectedOutputType;
98 }),
99 "Trying to create workload with incorrect type");
100 }
101 }
102};
103
104template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
105class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
106{
107public:
108
109 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
110 : BaseWorkload<QueueDescriptor>(descriptor, info)
111 {
telsoa014fcda012018-03-09 14:13:49 +0000112 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
113 info.m_InputTensorInfos.end(),
114 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100115 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000116 }),
117 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000118
telsoa014fcda012018-03-09 14:13:49 +0000119 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
120 info.m_OutputTensorInfos.end(),
121 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100122 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000123 }),
124 "Trying to create workload with incorrect type");
125 }
telsoa014fcda012018-03-09 14:13:49 +0000126};
127
narpra014951d842019-01-18 16:53:53 +0000128// FirstInputTypedWorkload used to check type of the first input
129template <typename QueueDescriptor, armnn::DataType DataType>
130class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
131{
132public:
133
134 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
135 : BaseWorkload<QueueDescriptor>(descriptor, info)
136 {
137 if (!info.m_InputTensorInfos.empty())
138 {
139 BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
140 "Trying to create workload with incorrect type");
141 }
142
143 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
144 info.m_OutputTensorInfos.end(),
145 [&](auto it){
146 return it.GetDataType() == DataType;
147 }),
148 "Trying to create workload with incorrect type");
149 }
150};
151
telsoa014fcda012018-03-09 14:13:49 +0000152template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100153using FloatWorkload = TypedWorkload<QueueDescriptor,
154 armnn::DataType::Float16,
155 armnn::DataType::Float32>;
156
157template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000158using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
159
160template <typename QueueDescriptor>
Derek Lambertif90c56d2020-01-10 17:14:08 +0000161using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
telsoa014fcda012018-03-09 14:13:49 +0000162
telsoa01c577f2c2018-08-31 09:22:23 +0100163template <typename QueueDescriptor>
narpra01db2b1602019-01-23 15:23:11 +0000164using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
165
166template <typename QueueDescriptor>
kevmay012b4d88e2019-01-24 14:05:09 +0000167using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
168
169template <typename QueueDescriptor>
170using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
171 armnn::DataType::Float32,
172 armnn::DataType::Boolean>;
173
174template <typename QueueDescriptor>
175using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000176 armnn::DataType::QAsymmU8,
kevmay012b4d88e2019-01-24 14:05:09 +0000177 armnn::DataType::Boolean>;
178
179template <typename QueueDescriptor>
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000180using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
181 armnn::DataType::BFloat16,
182 armnn::DataType::Float32>;
183
184template <typename QueueDescriptor>
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000185using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
186 armnn::DataType::Float32,
187 armnn::DataType::BFloat16>;
188
189template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100190using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
191 armnn::DataType::Float16,
192 armnn::DataType::Float32>;
193
194template <typename QueueDescriptor>
195using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
196 armnn::DataType::Float32,
197 armnn::DataType::Float16>;
198
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000199template <typename QueueDescriptor>
200using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000201 armnn::DataType::QAsymmU8,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000202 armnn::DataType::Float32>;
203
telsoa014fcda012018-03-09 14:13:49 +0000204} //namespace armnn