blob: dbc7574d0e0f914272af41bc380453ed69fad963 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include "WorkloadData.hpp"
8#include "WorkloadInfo.hpp"
9#include <algorithm>
10#include "Profiling.hpp"
11
12namespace armnn
13{
14
/// Workload interface used to enqueue a layer computation.
///
/// Concrete workloads implement Execute(). The destructor is virtual so that a
/// workload can be destroyed through an IWorkload pointer.
class IWorkload
{
public:
    // Defaulted instead of `{}` followed by a stray `;` (flagged by -Wextra-semi).
    // Still virtual and implicitly noexcept, so behaviour is unchanged.
    virtual ~IWorkload() = default;

    /// Runs the workload. Declared const, so implementations cannot modify the
    /// workload object's non-mutable members from here.
    virtual void Execute() const = 0;
};
23
24// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
25// in the various workload factories.
26// There should never be an instantiation of a NullWorkload.
27class NullWorkload : public IWorkload
28{
29 NullWorkload()=delete;
30};
31
32template <typename QueueDescriptor>
33class BaseWorkload : public IWorkload
34{
35public:
36
37 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
38 : m_Data(descriptor)
39 {
40 m_Data.Validate(info);
41 }
42
43 const QueueDescriptor& GetData() const { return m_Data; }
44
45protected:
46 const QueueDescriptor m_Data;
47};
48
49template <typename QueueDescriptor, armnn::DataType DataType>
50class TypedWorkload : public BaseWorkload<QueueDescriptor>
51{
52public:
53
54 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
55 : BaseWorkload<QueueDescriptor>(descriptor, info)
56 {
57 BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
58 info.m_InputTensorInfos.end(),
59 [&](auto it){
60 return it.GetDataType() == DataType;
61 }),
62 "Trying to create workload with incorrect type");
63 BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
64 info.m_OutputTensorInfos.end(),
65 [&](auto it){
66 return it.GetDataType() == DataType;
67 }),
68 "Trying to create workload with incorrect type");
69 }
70
71 static constexpr armnn::DataType ms_DataType = DataType;
72};
73
74template <typename QueueDescriptor>
75using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
76
77template <typename QueueDescriptor>
78using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
79
80} //namespace armnn