blob: 3efd7dbfd413cf370a571bb9a1649e5d15776575 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include "WorkloadData.hpp"
8#include "WorkloadInfo.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Profiling.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace armnn
15{
16
David Beckdcb751f2018-10-03 11:42:42 +010017/// Workload interface to enqueue a layer computation.
telsoa014fcda012018-03-09 14:13:49 +000018class IWorkload
19{
20public:
telsoa01c577f2c2018-08-31 09:22:23 +010021 virtual ~IWorkload() {}
telsoa014fcda012018-03-09 14:13:49 +000022
Derek Lambertif30f7d32019-04-09 10:25:02 +010023 virtual void PostAllocationConfigure() = 0;
telsoa014fcda012018-03-09 14:13:49 +000024 virtual void Execute() const = 0;
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +000025
26 virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
telsoa014fcda012018-03-09 14:13:49 +000027};
28
29// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
30// in the various workload factories.
31// There should never be an instantiation of a NullWorkload.
32class NullWorkload : public IWorkload
33{
34 NullWorkload()=delete;
35};
36
37template <typename QueueDescriptor>
38class BaseWorkload : public IWorkload
39{
40public:
41
42 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
43 : m_Data(descriptor)
44 {
45 m_Data.Validate(info);
46 }
47
Derek Lambertif30f7d32019-04-09 10:25:02 +010048 void PostAllocationConfigure() override {}
49
telsoa014fcda012018-03-09 14:13:49 +000050 const QueueDescriptor& GetData() const { return m_Data; }
51
52protected:
53 const QueueDescriptor m_Data;
54};
55
telsoa01c577f2c2018-08-31 09:22:23 +010056// TypedWorkload used
57template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000058class TypedWorkload : public BaseWorkload<QueueDescriptor>
59{
60public:
61
62 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
63 : BaseWorkload<QueueDescriptor>(descriptor, info)
64 {
telsoa01c577f2c2018-08-31 09:22:23 +010065 std::vector<armnn::DataType> dataTypes = {DataTypes...};
66 armnn::DataType expectedInputType;
67
68 if (!info.m_InputTensorInfos.empty())
69 {
70 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
71
72 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
73 {
74 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
75 }
76 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
77 info.m_InputTensorInfos.end(),
78 [&](auto it){
79 return it.GetDataType() == expectedInputType;
80 }),
81 "Trying to create workload with incorrect type");
82 }
83 armnn::DataType expectedOutputType;
84
85 if (!info.m_OutputTensorInfos.empty())
86 {
87 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
88
89 if (!info.m_InputTensorInfos.empty())
90 {
91 if (expectedOutputType != expectedInputType)
92 {
93 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
94 }
95 }
96 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
97 {
98 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
99 }
100 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
101 info.m_OutputTensorInfos.end(),
102 [&](auto it){
103 return it.GetDataType() == expectedOutputType;
104 }),
105 "Trying to create workload with incorrect type");
106 }
107 }
108};
109
110template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
111class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
112{
113public:
114
115 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
116 : BaseWorkload<QueueDescriptor>(descriptor, info)
117 {
telsoa014fcda012018-03-09 14:13:49 +0000118 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
119 info.m_InputTensorInfos.end(),
120 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100121 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000122 }),
123 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000124
telsoa014fcda012018-03-09 14:13:49 +0000125 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
126 info.m_OutputTensorInfos.end(),
127 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100128 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000129 }),
130 "Trying to create workload with incorrect type");
131 }
telsoa014fcda012018-03-09 14:13:49 +0000132};
133
narpra014951d842019-01-18 16:53:53 +0000134// FirstInputTypedWorkload used to check type of the first input
135template <typename QueueDescriptor, armnn::DataType DataType>
136class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
137{
138public:
139
140 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
141 : BaseWorkload<QueueDescriptor>(descriptor, info)
142 {
143 if (!info.m_InputTensorInfos.empty())
144 {
145 BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
146 "Trying to create workload with incorrect type");
147 }
148
149 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
150 info.m_OutputTensorInfos.end(),
151 [&](auto it){
152 return it.GetDataType() == DataType;
153 }),
154 "Trying to create workload with incorrect type");
155 }
156};
157
telsoa014fcda012018-03-09 14:13:49 +0000158template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100159using FloatWorkload = TypedWorkload<QueueDescriptor,
160 armnn::DataType::Float16,
161 armnn::DataType::Float32>;
162
163template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000164using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
165
166template <typename QueueDescriptor>
167using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
168
telsoa01c577f2c2018-08-31 09:22:23 +0100169template <typename QueueDescriptor>
narpra01db2b1602019-01-23 15:23:11 +0000170using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
171
172template <typename QueueDescriptor>
kevmay012b4d88e2019-01-24 14:05:09 +0000173using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
174
175template <typename QueueDescriptor>
176using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
177 armnn::DataType::Float32,
178 armnn::DataType::Boolean>;
179
180template <typename QueueDescriptor>
181using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
182 armnn::DataType::QuantisedAsymm8,
183 armnn::DataType::Boolean>;
184
185template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100186using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
187 armnn::DataType::Float16,
188 armnn::DataType::Float32>;
189
190template <typename QueueDescriptor>
191using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
192 armnn::DataType::Float32,
193 armnn::DataType::Float16>;
194
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000195template <typename QueueDescriptor>
196using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
197 armnn::DataType::QuantisedAsymm8,
198 armnn::DataType::Float32>;
199
telsoa014fcda012018-03-09 14:13:49 +0000200} //namespace armnn