blob: be89a3222f30226a9b9e91473ab7eb5058481d32 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <armnn/Types.hpp>
#include <backends/StringMapping.hpp>
#include <backends/Workload.hpp>
#include <backends/WorkloadData.hpp>

#include <functional>
David Beck279f8722018-09-12 13:50:03 +010012
13namespace armnn
14{
15
16template <typename Functor,
17 typename armnn::DataType DataType,
18 typename ParentDescriptor,
19 typename armnn::StringMapping::Id DebugString>
20class RefArithmeticWorkload
21{
22 // Needs specialization. The default is empty on purpose.
23};
24
25template <typename ParentDescriptor, typename Functor>
26class BaseFloat32ArithmeticWorkload : public Float32Workload<ParentDescriptor>
27{
28public:
29 using Float32Workload<ParentDescriptor>::Float32Workload;
30 void ExecuteImpl(const char * debugString) const;
31};
32
33template <typename Functor,
34 typename ParentDescriptor,
35 typename armnn::StringMapping::Id DebugString>
36class RefArithmeticWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString>
37 : public BaseFloat32ArithmeticWorkload<ParentDescriptor, Functor>
38{
39public:
40 using BaseFloat32ArithmeticWorkload<ParentDescriptor, Functor>::BaseFloat32ArithmeticWorkload;
41
42 virtual void Execute() const override
43 {
44 using Parent = BaseFloat32ArithmeticWorkload<ParentDescriptor, Functor>;
45 Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
46 }
47};
48
49template <typename ParentDescriptor, typename Functor>
50class BaseUint8ArithmeticWorkload : public Uint8Workload<ParentDescriptor>
51{
52public:
53 using Uint8Workload<ParentDescriptor>::Uint8Workload;
54 void ExecuteImpl(const char * debugString) const;
55};
56
57template <typename Functor,
58 typename ParentDescriptor,
59 typename armnn::StringMapping::Id DebugString>
60class RefArithmeticWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString>
61 : public BaseUint8ArithmeticWorkload<ParentDescriptor, Functor>
62{
63public:
64 using BaseUint8ArithmeticWorkload<ParentDescriptor, Functor>::BaseUint8ArithmeticWorkload;
65
66 virtual void Execute() const override
67 {
68 using Parent = BaseUint8ArithmeticWorkload<ParentDescriptor, Functor>;
69 Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
70 }
71};
72
73using RefAdditionFloat32Workload =
74 RefArithmeticWorkload<std::plus<float>,
75 DataType::Float32,
76 AdditionQueueDescriptor,
77 StringMapping::RefAdditionWorkload_Execute>;
78
79using RefAdditionUint8Workload =
80 RefArithmeticWorkload<std::plus<float>,
81 DataType::QuantisedAsymm8,
82 AdditionQueueDescriptor,
83 StringMapping::RefAdditionWorkload_Execute>;
84
85
86using RefSubtractionFloat32Workload =
87 RefArithmeticWorkload<std::minus<float>,
88 DataType::Float32,
89 SubtractionQueueDescriptor,
90 StringMapping::RefSubtractionWorkload_Execute>;
91
92using RefSubtractionUint8Workload =
93 RefArithmeticWorkload<std::minus<float>,
94 DataType::QuantisedAsymm8,
95 SubtractionQueueDescriptor,
96 StringMapping::RefSubtractionWorkload_Execute>;
97
98using RefMultiplicationFloat32Workload =
99 RefArithmeticWorkload<std::multiplies<float>,
100 DataType::Float32,
101 MultiplicationQueueDescriptor,
102 StringMapping::RefMultiplicationWorkload_Execute>;
103
104using RefMultiplicationUint8Workload =
105 RefArithmeticWorkload<std::multiplies<float>,
106 DataType::QuantisedAsymm8,
107 MultiplicationQueueDescriptor,
108 StringMapping::RefMultiplicationWorkload_Execute>;
109
110using RefDivisionFloat32Workload =
111 RefArithmeticWorkload<std::divides<float>,
112 DataType::Float32,
113 DivisionQueueDescriptor,
114 StringMapping::RefDivisionWorkload_Execute>;
115
116using RefDivisionUint8Workload =
117 RefArithmeticWorkload<std::divides<float>,
118 DataType::QuantisedAsymm8,
119 DivisionQueueDescriptor,
120 StringMapping::RefDivisionWorkload_Execute>;
121
122} // armnn