blob: e3905804a84a4a779c25c269cfaafbb9ebf43845 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include <armnn/Types.hpp>
#include <armnn/IProfilingGuidGenerator.hpp>
#include <Profiling.hpp>

#include <algorithm>
#include <iterator>
#include <vector>
telsoa014fcda012018-03-09 14:13:49 +000015
16namespace armnn
17{
18
David Beckdcb751f2018-10-03 11:42:42 +010019/// Workload interface to enqueue a layer computation.
telsoa014fcda012018-03-09 14:13:49 +000020class IWorkload
21{
22public:
telsoa01c577f2c2018-08-31 09:22:23 +010023 virtual ~IWorkload() {}
telsoa014fcda012018-03-09 14:13:49 +000024
Derek Lambertif30f7d32019-04-09 10:25:02 +010025 virtual void PostAllocationConfigure() = 0;
telsoa014fcda012018-03-09 14:13:49 +000026 virtual void Execute() const = 0;
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +000027
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000028 virtual profiling::ProfilingGuid GetGuid() const = 0;
29
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +000030 virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
telsoa014fcda012018-03-09 14:13:49 +000031};
32
33// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
34// in the various workload factories.
35// There should never be an instantiation of a NullWorkload.
36class NullWorkload : public IWorkload
37{
38 NullWorkload()=delete;
39};
40
41template <typename QueueDescriptor>
42class BaseWorkload : public IWorkload
43{
44public:
45
46 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000047 : m_Data(descriptor),
Jim Flynn6b1bf1a2020-01-22 15:18:49 +000048 m_Guid(profiling::IProfilingGuidGenerator::Instance().NextGuid())
telsoa014fcda012018-03-09 14:13:49 +000049 {
50 m_Data.Validate(info);
51 }
52
Derek Lambertif30f7d32019-04-09 10:25:02 +010053 void PostAllocationConfigure() override {}
54
telsoa014fcda012018-03-09 14:13:49 +000055 const QueueDescriptor& GetData() const { return m_Data; }
56
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000057 profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
58
telsoa014fcda012018-03-09 14:13:49 +000059protected:
60 const QueueDescriptor m_Data;
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000061 const profiling::ProfilingGuid m_Guid;
telsoa014fcda012018-03-09 14:13:49 +000062};
63
telsoa01c577f2c2018-08-31 09:22:23 +010064// TypedWorkload used
65template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000066class TypedWorkload : public BaseWorkload<QueueDescriptor>
67{
68public:
69
70 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
71 : BaseWorkload<QueueDescriptor>(descriptor, info)
72 {
telsoa01c577f2c2018-08-31 09:22:23 +010073 std::vector<armnn::DataType> dataTypes = {DataTypes...};
74 armnn::DataType expectedInputType;
75
76 if (!info.m_InputTensorInfos.empty())
77 {
78 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
79
80 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
81 {
82 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
83 }
84 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
85 info.m_InputTensorInfos.end(),
86 [&](auto it){
87 return it.GetDataType() == expectedInputType;
88 }),
89 "Trying to create workload with incorrect type");
90 }
91 armnn::DataType expectedOutputType;
92
93 if (!info.m_OutputTensorInfos.empty())
94 {
95 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
96
97 if (!info.m_InputTensorInfos.empty())
98 {
99 if (expectedOutputType != expectedInputType)
100 {
101 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
102 }
103 }
104 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
105 {
106 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
107 }
108 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
109 info.m_OutputTensorInfos.end(),
110 [&](auto it){
111 return it.GetDataType() == expectedOutputType;
112 }),
113 "Trying to create workload with incorrect type");
114 }
115 }
116};
117
118template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
119class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
120{
121public:
122
123 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
124 : BaseWorkload<QueueDescriptor>(descriptor, info)
125 {
telsoa014fcda012018-03-09 14:13:49 +0000126 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
127 info.m_InputTensorInfos.end(),
128 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100129 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000130 }),
131 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000132
telsoa014fcda012018-03-09 14:13:49 +0000133 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
134 info.m_OutputTensorInfos.end(),
135 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100136 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000137 }),
138 "Trying to create workload with incorrect type");
139 }
telsoa014fcda012018-03-09 14:13:49 +0000140};
141
narpra014951d842019-01-18 16:53:53 +0000142// FirstInputTypedWorkload used to check type of the first input
143template <typename QueueDescriptor, armnn::DataType DataType>
144class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
145{
146public:
147
148 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
149 : BaseWorkload<QueueDescriptor>(descriptor, info)
150 {
151 if (!info.m_InputTensorInfos.empty())
152 {
153 BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
154 "Trying to create workload with incorrect type");
155 }
156
157 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
158 info.m_OutputTensorInfos.end(),
159 [&](auto it){
160 return it.GetDataType() == DataType;
161 }),
162 "Trying to create workload with incorrect type");
163 }
164};
165
telsoa014fcda012018-03-09 14:13:49 +0000166template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100167using FloatWorkload = TypedWorkload<QueueDescriptor,
168 armnn::DataType::Float16,
169 armnn::DataType::Float32>;
170
171template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000172using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
173
174template <typename QueueDescriptor>
175using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
176
telsoa01c577f2c2018-08-31 09:22:23 +0100177template <typename QueueDescriptor>
narpra01db2b1602019-01-23 15:23:11 +0000178using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
179
180template <typename QueueDescriptor>
kevmay012b4d88e2019-01-24 14:05:09 +0000181using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
182
183template <typename QueueDescriptor>
184using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
185 armnn::DataType::Float32,
186 armnn::DataType::Boolean>;
187
188template <typename QueueDescriptor>
189using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
190 armnn::DataType::QuantisedAsymm8,
191 armnn::DataType::Boolean>;
192
193template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100194using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
195 armnn::DataType::Float16,
196 armnn::DataType::Float32>;
197
198template <typename QueueDescriptor>
199using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
200 armnn::DataType::Float32,
201 armnn::DataType::Float16>;
202
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000203template <typename QueueDescriptor>
204using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
205 armnn::DataType::QuantisedAsymm8,
206 armnn::DataType::Float32>;
207
telsoa014fcda012018-03-09 14:13:49 +0000208} //namespace armnn