blob: 65392194a2ea0625ae4f6f810ff4738013e04268 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "WorkloadData.hpp"
8#include "WorkloadInfo.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Profiling.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace armnn
15{
16
/// Workload interface to enqueue a layer computation.
///
/// Implementations provide the computation in Execute(). The destructor is virtual so that a
/// concrete workload can be safely destroyed through an IWorkload pointer.
class IWorkload
{
public:
    virtual ~IWorkload() = default;

    /// Runs the workload's computation.
    virtual void Execute() const = 0;
};
25
26// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
27// in the various workload factories.
28// There should never be an instantiation of a NullWorkload.
29class NullWorkload : public IWorkload
30{
31 NullWorkload()=delete;
32};
33
34template <typename QueueDescriptor>
35class BaseWorkload : public IWorkload
36{
37public:
38
39 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
40 : m_Data(descriptor)
41 {
42 m_Data.Validate(info);
43 }
44
45 const QueueDescriptor& GetData() const { return m_Data; }
46
47protected:
48 const QueueDescriptor m_Data;
49};
50
telsoa01c577f2c2018-08-31 09:22:23 +010051// TypedWorkload used
52template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000053class TypedWorkload : public BaseWorkload<QueueDescriptor>
54{
55public:
56
57 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
58 : BaseWorkload<QueueDescriptor>(descriptor, info)
59 {
telsoa01c577f2c2018-08-31 09:22:23 +010060 std::vector<armnn::DataType> dataTypes = {DataTypes...};
61 armnn::DataType expectedInputType;
62
63 if (!info.m_InputTensorInfos.empty())
64 {
65 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
66
67 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
68 {
69 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
70 }
71 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
72 info.m_InputTensorInfos.end(),
73 [&](auto it){
74 return it.GetDataType() == expectedInputType;
75 }),
76 "Trying to create workload with incorrect type");
77 }
78 armnn::DataType expectedOutputType;
79
80 if (!info.m_OutputTensorInfos.empty())
81 {
82 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
83
84 if (!info.m_InputTensorInfos.empty())
85 {
86 if (expectedOutputType != expectedInputType)
87 {
88 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
89 }
90 }
91 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
92 {
93 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
94 }
95 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
96 info.m_OutputTensorInfos.end(),
97 [&](auto it){
98 return it.GetDataType() == expectedOutputType;
99 }),
100 "Trying to create workload with incorrect type");
101 }
102 }
103};
104
105template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
106class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
107{
108public:
109
110 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
111 : BaseWorkload<QueueDescriptor>(descriptor, info)
112 {
telsoa014fcda012018-03-09 14:13:49 +0000113 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
114 info.m_InputTensorInfos.end(),
115 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100116 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000117 }),
118 "Trying to create workload with incorrect type");
narpra014951d842019-01-18 16:53:53 +0000119
telsoa014fcda012018-03-09 14:13:49 +0000120 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
121 info.m_OutputTensorInfos.end(),
122 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100123 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000124 }),
125 "Trying to create workload with incorrect type");
126 }
telsoa014fcda012018-03-09 14:13:49 +0000127};
128
narpra014951d842019-01-18 16:53:53 +0000129// FirstInputTypedWorkload used to check type of the first input
130template <typename QueueDescriptor, armnn::DataType DataType>
131class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
132{
133public:
134
135 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
136 : BaseWorkload<QueueDescriptor>(descriptor, info)
137 {
138 if (!info.m_InputTensorInfos.empty())
139 {
140 BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
141 "Trying to create workload with incorrect type");
142 }
143
144 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
145 info.m_OutputTensorInfos.end(),
146 [&](auto it){
147 return it.GetDataType() == DataType;
148 }),
149 "Trying to create workload with incorrect type");
150 }
151};
152
telsoa014fcda012018-03-09 14:13:49 +0000153template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100154using FloatWorkload = TypedWorkload<QueueDescriptor,
155 armnn::DataType::Float16,
156 armnn::DataType::Float32>;
157
158template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000159using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
160
161template <typename QueueDescriptor>
162using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
163
telsoa01c577f2c2018-08-31 09:22:23 +0100164template <typename QueueDescriptor>
165using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
166 armnn::DataType::Float16,
167 armnn::DataType::Float32>;
168
169template <typename QueueDescriptor>
170using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
171 armnn::DataType::Float32,
172 armnn::DataType::Float16>;
173
telsoa014fcda012018-03-09 14:13:49 +0000174} //namespace armnn