blob: 309d53f48e6a8a62d7480c51af18c8d35c824d53 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include "WorkloadData.hpp"
8#include "WorkloadInfo.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
10#include <Profiling.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace armnn
15{
16
/// Workload interface to enqueue a layer computation.
///
/// Concrete workloads derive from this class and implement Execute().
class IWorkload
{
public:
    /// Virtual so that a workload can be destroyed through an IWorkload pointer.
    virtual ~IWorkload() = default;

    /// Runs the workload. Declared const, so implementations cannot modify
    /// non-mutable members of the workload object.
    virtual void Execute() const = 0;
};
25
26// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
27// in the various workload factories.
28// There should never be an instantiation of a NullWorkload.
29class NullWorkload : public IWorkload
30{
31 NullWorkload()=delete;
32};
33
34template <typename QueueDescriptor>
35class BaseWorkload : public IWorkload
36{
37public:
38
39 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
40 : m_Data(descriptor)
41 {
42 m_Data.Validate(info);
43 }
44
45 const QueueDescriptor& GetData() const { return m_Data; }
46
47protected:
48 const QueueDescriptor m_Data;
49};
50
telsoa01c577f2c2018-08-31 09:22:23 +010051// TypedWorkload used
52template <typename QueueDescriptor, armnn::DataType... DataTypes>
telsoa014fcda012018-03-09 14:13:49 +000053class TypedWorkload : public BaseWorkload<QueueDescriptor>
54{
55public:
56
57 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
58 : BaseWorkload<QueueDescriptor>(descriptor, info)
59 {
telsoa01c577f2c2018-08-31 09:22:23 +010060 std::vector<armnn::DataType> dataTypes = {DataTypes...};
61 armnn::DataType expectedInputType;
62
63 if (!info.m_InputTensorInfos.empty())
64 {
65 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
66
67 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
68 {
69 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
70 }
71 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
72 info.m_InputTensorInfos.end(),
73 [&](auto it){
74 return it.GetDataType() == expectedInputType;
75 }),
76 "Trying to create workload with incorrect type");
77 }
78 armnn::DataType expectedOutputType;
79
80 if (!info.m_OutputTensorInfos.empty())
81 {
82 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
83
84 if (!info.m_InputTensorInfos.empty())
85 {
86 if (expectedOutputType != expectedInputType)
87 {
88 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
89 }
90 }
91 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
92 {
93 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
94 }
95 BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
96 info.m_OutputTensorInfos.end(),
97 [&](auto it){
98 return it.GetDataType() == expectedOutputType;
99 }),
100 "Trying to create workload with incorrect type");
101 }
102 }
103};
104
105template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
106class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
107{
108public:
109
110 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
111 : BaseWorkload<QueueDescriptor>(descriptor, info)
112 {
telsoa014fcda012018-03-09 14:13:49 +0000113 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
114 info.m_InputTensorInfos.end(),
115 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100116 return it.GetDataType() == InputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000117 }),
118 "Trying to create workload with incorrect type");
119 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
120 info.m_OutputTensorInfos.end(),
121 [&](auto it){
telsoa01c577f2c2018-08-31 09:22:23 +0100122 return it.GetDataType() == OutputDataType;
telsoa014fcda012018-03-09 14:13:49 +0000123 }),
124 "Trying to create workload with incorrect type");
125 }
telsoa014fcda012018-03-09 14:13:49 +0000126};
127
128template <typename QueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +0100129using FloatWorkload = TypedWorkload<QueueDescriptor,
130 armnn::DataType::Float16,
131 armnn::DataType::Float32>;
132
133template <typename QueueDescriptor>
telsoa014fcda012018-03-09 14:13:49 +0000134using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
135
136template <typename QueueDescriptor>
137using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
138
telsoa01c577f2c2018-08-31 09:22:23 +0100139template <typename QueueDescriptor>
140using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
141 armnn::DataType::Float16,
142 armnn::DataType::Float32>;
143
144template <typename QueueDescriptor>
145using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
146 armnn::DataType::Float32,
147 armnn::DataType::Float16>;
148
telsoa014fcda012018-03-09 14:13:49 +0000149} //namespace armnn