blob: 447ec1b4d6371e8f650b20a099da8d291f47ce2c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include "WorkloadData.hpp"
8#include "WorkloadInfo.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Profiling.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace armnn
15{
16
David Beckdcb751f2018-10-03 11:42:42 +010017/// Workload interface to enqueue a layer computation.
telsoa014fcda012018-03-09 14:13:49 +000018class IWorkload
19{
20public:
telsoa01c577f2c2018-08-31 09:22:23 +010021 virtual ~IWorkload() {}
telsoa014fcda012018-03-09 14:13:49 +000022
23 virtual void Execute() const = 0;
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +000024
25 virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
telsoa014fcda012018-03-09 14:13:49 +000026};
27
28// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
29// in the various workload factories.
30// There should never be an instantiation of a NullWorkload.
31class NullWorkload : public IWorkload
32{
33 NullWorkload()=delete;
34};
35
36template <typename QueueDescriptor>
37class BaseWorkload : public IWorkload
38{
39public:
40
41 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
42 : m_Data(descriptor)
43 {
44 m_Data.Validate(info);
45 }
46
47 const QueueDescriptor& GetData() const { return m_Data; }
48
49protected:
50 const QueueDescriptor m_Data;
51};
52
telsoa01c577f2c2018-08-31 09:22:23 +010053// TypedWorkload used
54template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000055class TypedWorkload : public BaseWorkload<QueueDescriptor>
56{
57public:
58
59 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
60 : BaseWorkload<QueueDescriptor>(descriptor, info)
61 {
telsoa01c577f2c2018-08-31 09:22:23 +010062 std::vector<armnn::DataType> dataTypes = {DataTypes...};
63 armnn::DataType expectedInputType;
64
65 if (!info.m_InputTensorInfos.empty())
66 {
67 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
68
69 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
70 {
71 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
72 }
73 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
74 info.m_InputTensorInfos.end(),
75 [&](auto it){
76 return it.GetDataType() == expectedInputType;
77 }),
78 "Trying to create workload with incorrect type");
79 }
80 armnn::DataType expectedOutputType;
81
82 if (!info.m_OutputTensorInfos.empty())
83 {
84 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
85
86 if (!info.m_InputTensorInfos.empty())
87 {
88 if (expectedOutputType != expectedInputType)
89 {
90 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
91 }
92 }
93 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
94 {
95 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
96 }
97 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
98 info.m_OutputTensorInfos.end(),
99 [&](auto it){
100 return it.GetDataType() == expectedOutputType;
101 }),
102 "Trying to create workload with incorrect type");
103 }
104 }
105};
106
107template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
108class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
109{
110public:
111
112 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
113 : BaseWorkload<QueueDescriptor>(descriptor, info)
114 {
telsoa014fcda012018-03-09 14:13:49 +0000115 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
116 info.m_InputTensorInfos.end(),
117 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100118 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000119 }),
120 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000121
telsoa014fcda012018-03-09 14:13:49 +0000122 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
123 info.m_OutputTensorInfos.end(),
124 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100125 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000126 }),
127 "Trying to create workload with incorrect type");
128 }
telsoa014fcda012018-03-09 14:13:49 +0000129};
130
narpra014951d842019-01-18 16:53:53 +0000131// FirstInputTypedWorkload used to check type of the first input
132template <typename QueueDescriptor, armnn::DataType DataType>
133class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
134{
135public:
136
137 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
138 : BaseWorkload<QueueDescriptor>(descriptor, info)
139 {
140 if (!info.m_InputTensorInfos.empty())
141 {
142 BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
143 "Trying to create workload with incorrect type");
144 }
145
146 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
147 info.m_OutputTensorInfos.end(),
148 [&](auto it){
149 return it.GetDataType() == DataType;
150 }),
151 "Trying to create workload with incorrect type");
152 }
153};
154
telsoa014fcda012018-03-09 14:13:49 +0000155template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100156using FloatWorkload = TypedWorkload<QueueDescriptor,
157 armnn::DataType::Float16,
158 armnn::DataType::Float32>;
159
160template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000161using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
162
163template <typename QueueDescriptor>
164using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
165
telsoa01c577f2c2018-08-31 09:22:23 +0100166template <typename QueueDescriptor>
narpra01db2b1602019-01-23 15:23:11 +0000167using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
168
169template <typename QueueDescriptor>
kevmay012b4d88e2019-01-24 14:05:09 +0000170using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
171
172template <typename QueueDescriptor>
173using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
174 armnn::DataType::Float32,
175 armnn::DataType::Boolean>;
176
177template <typename QueueDescriptor>
178using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
179 armnn::DataType::QuantisedAsymm8,
180 armnn::DataType::Boolean>;
181
182template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100183using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
184 armnn::DataType::Float16,
185 armnn::DataType::Float32>;
186
187template <typename QueueDescriptor>
188using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
189 armnn::DataType::Float32,
190 armnn::DataType::Float16>;
191
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000192template <typename QueueDescriptor>
193using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
194 armnn::DataType::QuantisedAsymm8,
195 armnn::DataType::Float32>;
196
telsoa014fcda012018-03-09 14:13:49 +0000197} //namespace armnn