blob: eef5b24df73fc27153fc83e78010040bc7d426e8 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005#include <backendsCommon/CpuTensorHandle.hpp>
6#include <backendsCommon/MemCopyWorkload.hpp>
7#include <backendsCommon/MakeWorkloadHelper.hpp>
telsoa014fcda012018-03-09 14:13:49 +00008#include "RefWorkloadFactory.hpp"
David Beck79141b92018-10-23 16:09:36 +01009#include "RefBackendId.hpp"
David Beckb4540be2018-09-24 13:18:27 +010010#include "workloads/RefWorkloads.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011#include "Layer.hpp"
telsoa014fcda012018-03-09 14:13:49 +000012
13#include <boost/log/trivial.hpp>
14
15namespace armnn
16{
17
David Beck79141b92018-10-23 16:09:36 +010018namespace
19{
20static const BackendId s_Id{RefBackendId()};
21}
22
telsoa014fcda012018-03-09 14:13:49 +000023template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
24std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
25 const WorkloadInfo& info) const
26{
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +010027 return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000028}
29
telsoa01c577f2c2018-08-31 09:22:23 +010030RefWorkloadFactory::RefWorkloadFactory()
telsoa014fcda012018-03-09 14:13:49 +000031{
32}
33
David Beck79141b92018-10-23 16:09:36 +010034const BackendId& RefWorkloadFactory::GetBackendId() const
35{
36 return s_Id;
37}
38
David Beck29c75de2018-10-23 13:35:58 +010039bool RefWorkloadFactory::IsLayerSupported(const Layer& layer,
40 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +010041 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000042{
David Beck79141b92018-10-23 16:09:36 +010043 return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +000044}
45
46std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
47{
48 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
49}
50
Francis Murtagh351d13d2018-09-24 15:01:18 +010051std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
52 DataLayout dataLayout) const
53{
54 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
55}
56
telsoa014fcda012018-03-09 14:13:49 +000057std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
58 const WorkloadInfo& info) const
59{
60 if (info.m_InputTensorInfos.empty() )
61 {
62 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
63 }
64 if (info.m_OutputTensorInfos.empty())
65 {
66 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
67 }
68
69 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
70 {
71 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
72 }
73
telsoa01c577f2c2018-08-31 09:22:23 +010074 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000075}
76
77std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
78 const WorkloadInfo& info) const
79{
80 if (info.m_InputTensorInfos.empty() )
81 {
82 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
83 }
84 if (info.m_OutputTensorInfos.empty())
85 {
86 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
87 }
88 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
89 {
90 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
91 }
92
telsoa01c577f2c2018-08-31 09:22:23 +010093 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000094}
95
96std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
97 const WorkloadInfo& info) const
98{
99 return MakeWorkload<RefActivationFloat32Workload, RefActivationUint8Workload>(descriptor, info);
100}
101
102std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
103 const WorkloadInfo& info) const
104{
105 return MakeWorkload<RefSoftmaxFloat32Workload, RefSoftmaxUint8Workload>(descriptor, info);
106}
107
108std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
109 const WorkloadInfo& info) const
110{
111 return MakeWorkload<RefSplitterFloat32Workload, RefSplitterUint8Workload>(descriptor, info);
112}
113
114std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
115 const WorkloadInfo& info) const
116{
117 return MakeWorkload<RefMergerFloat32Workload, RefMergerUint8Workload>(descriptor, info);
118}
119
120std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
121 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
122{
123 return MakeWorkload<RefFullyConnectedFloat32Workload, RefFullyConnectedUint8Workload>(descriptor, info);
124}
125
126std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
127 const WorkloadInfo& info) const
128{
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +0100129 return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload>
arovir01616e7752018-10-01 17:08:59 +0100130 (descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000131}
132
133std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
134 const WorkloadInfo& info) const
135{
136 return MakeWorkload<RefPooling2dFloat32Workload, RefPooling2dUint8Workload>(descriptor, info);
137}
138
139std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConvolution2d(
140 const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
141{
142 return MakeWorkload<RefConvolution2dFloat32Workload, RefConvolution2dUint8Workload>(descriptor, info);
143}
144
145std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthwiseConvolution2d(
146 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
147{
148 return MakeWorkload<RefDepthwiseConvolution2dFloat32Workload,
149 RefDepthwiseConvolution2dUint8Workload>(descriptor, info);
150}
151
152std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization(
153 const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
154{
155 return MakeWorkload<RefNormalizationFloat32Workload, NullWorkload>(descriptor, info);
156}
157
158std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
159 const WorkloadInfo& info) const
160{
161 return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info);
162}
163
164std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication(
165 const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
166{
167 return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info);
168}
169
170std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization(
171 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
172{
173 return MakeWorkload<RefBatchNormalizationFloat32Workload, RefBatchNormalizationUint8Workload>(descriptor, info);
174}
175
176std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
177 const WorkloadInfo& info) const
178{
179 if (descriptor.m_Inputs.empty())
180 {
181 throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
182 }
telsoa01c577f2c2018-08-31 09:22:23 +0100183 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000184}
185
186std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
187 const WorkloadInfo& info) const
188{
189 return MakeWorkload<RefResizeBilinearFloat32Workload, RefResizeBilinearUint8Workload>(descriptor, info);
190}
191
192std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
193 const FakeQuantizationQueueDescriptor& descriptor,
194 const WorkloadInfo& info) const
195{
196 return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
197}
198
199std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
200 const WorkloadInfo& info) const
201{
202 return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
203}
204
205std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
206 const WorkloadInfo& info) const
207{
208 return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info);
209}
210
211std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
212 const WorkloadInfo& info) const
213{
214 return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
215}
216
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000217std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
218 const WorkloadInfo& info) const
219{
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000220 return MakeWorkload<RefSpaceToBatchNdFloat32Workload, RefSpaceToBatchNdUint8Workload>(descriptor, info);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000221}
222
telsoa014fcda012018-03-09 14:13:49 +0000223std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100224 const WorkloadInfo& info) const
telsoa014fcda012018-03-09 14:13:49 +0000225{
226 return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info);
227}
228
telsoa01c577f2c2018-08-31 09:22:23 +0100229std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
230 const WorkloadInfo& info) const
231{
232 return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
233}
234
235std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
236 const ConvertFp16ToFp32QueueDescriptor& descriptor,
237 const WorkloadInfo& info) const
238{
239 return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
240}
241
242std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
243 const ConvertFp32ToFp16QueueDescriptor& descriptor,
244 const WorkloadInfo& info) const
245{
246 return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
247}
248
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100249std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
250 const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
251{
252 return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info);
253}
254
David Beckc2044fe2018-09-05 15:00:38 +0100255std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
256 const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
257{
David Beckf195f032018-09-06 16:46:34 +0100258 return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
David Beckc2044fe2018-09-05 15:00:38 +0100259}
260
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000261std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMaximum(
262 const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const
263{
saoste012df12b32018-11-28 16:57:20 +0000264 return MakeWorkload<RefMaximumFloat32Workload, RefMaximumUint8Workload>(descriptor, info);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000265}
266
narpra01a6bf9122018-09-10 09:50:09 +0100267std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
268 const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const
269{
narpra011e4c31d2018-09-28 11:07:51 +0100270 return MakeWorkload<RefMeanFloat32Workload, RefMeanUint8Workload>(descriptor, info);
narpra01a6bf9122018-09-10 09:50:09 +0100271}
272
jimfly012c9322a2018-09-19 10:59:49 +0100273std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
274 const WorkloadInfo& info) const
275{
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100276 return MakeWorkload<RefPadFloat32Workload, RefPadUint8Workload>(descriptor, info);
jimfly012c9322a2018-09-19 10:59:49 +0100277}
278
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000279std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
280 const WorkloadInfo& info) const
281{
282 return MakeWorkload<RefBatchToSpaceNdFloat32Workload, RefBatchToSpaceNdUint8Workload>(descriptor, info);
283}
jimfly012c9322a2018-09-19 10:59:49 +0100284
Conor Kennedy430b5d82018-11-14 15:28:28 +0000285std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
286 const WorkloadInfo& info) const
287{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000288 return MakeWorkload<RefStridedSliceFloat32Workload, RefStridedSliceUint8Workload>(descriptor, info);
Conor Kennedy430b5d82018-11-14 15:28:28 +0000289}
290
kevmay0190539692018-11-29 08:40:19 +0000291std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
292 const WorkloadInfo& info) const
293{
294 return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
295}
296
telsoa014fcda012018-03-09 14:13:49 +0000297} // namespace armnn