blob: 3bf83bd9be7cfc6bf85050b9deaf6b02a1a74b78 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005#include <backendsCommon/CpuTensorHandle.hpp>
6#include <backendsCommon/MemCopyWorkload.hpp>
7#include <backendsCommon/MakeWorkloadHelper.hpp>
telsoa014fcda012018-03-09 14:13:49 +00008#include "RefWorkloadFactory.hpp"
David Beck79141b92018-10-23 16:09:36 +01009#include "RefBackendId.hpp"
David Beckb4540be2018-09-24 13:18:27 +010010#include "workloads/RefWorkloads.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011#include "Layer.hpp"
telsoa014fcda012018-03-09 14:13:49 +000012
13#include <boost/log/trivial.hpp>
14
15namespace armnn
16{
17
David Beck79141b92018-10-23 16:09:36 +010018namespace
19{
20static const BackendId s_Id{RefBackendId()};
21}
22
telsoa014fcda012018-03-09 14:13:49 +000023template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
24std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
25 const WorkloadInfo& info) const
26{
kevmay012b4d88e2019-01-24 14:05:09 +000027 return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload>(descriptor,
28 info);
telsoa014fcda012018-03-09 14:13:49 +000029}
30
telsoa01c577f2c2018-08-31 09:22:23 +010031RefWorkloadFactory::RefWorkloadFactory()
telsoa014fcda012018-03-09 14:13:49 +000032{
33}
34
David Beck79141b92018-10-23 16:09:36 +010035const BackendId& RefWorkloadFactory::GetBackendId() const
36{
37 return s_Id;
38}
39
David Beck29c75de2018-10-23 13:35:58 +010040bool RefWorkloadFactory::IsLayerSupported(const Layer& layer,
41 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +010042 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000043{
David Beck79141b92018-10-23 16:09:36 +010044 return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +000045}
46
47std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
48{
49 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
50}
51
Francis Murtagh351d13d2018-09-24 15:01:18 +010052std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
53 DataLayout dataLayout) const
54{
55 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
56}
57
telsoa014fcda012018-03-09 14:13:49 +000058std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
59 const WorkloadInfo& info) const
60{
61 if (info.m_InputTensorInfos.empty() )
62 {
63 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
64 }
65 if (info.m_OutputTensorInfos.empty())
66 {
67 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
68 }
69
70 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
71 {
72 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
73 }
74
telsoa01c577f2c2018-08-31 09:22:23 +010075 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000076}
77
78std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
79 const WorkloadInfo& info) const
80{
81 if (info.m_InputTensorInfos.empty() )
82 {
83 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
84 }
85 if (info.m_OutputTensorInfos.empty())
86 {
87 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
88 }
89 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
90 {
91 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
92 }
93
kevmay012b4d88e2019-01-24 14:05:09 +000094 return MakeWorkloadHelper<CopyMemGenericWorkload, CopyMemGenericWorkload,
95 CopyMemGenericWorkload, NullWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000096}
97
98std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
99 const WorkloadInfo& info) const
100{
101 return MakeWorkload<RefActivationFloat32Workload, RefActivationUint8Workload>(descriptor, info);
102}
103
104std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
105 const WorkloadInfo& info) const
106{
107 return MakeWorkload<RefSoftmaxFloat32Workload, RefSoftmaxUint8Workload>(descriptor, info);
108}
109
110std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
111 const WorkloadInfo& info) const
112{
113 return MakeWorkload<RefSplitterFloat32Workload, RefSplitterUint8Workload>(descriptor, info);
114}
115
116std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
117 const WorkloadInfo& info) const
118{
119 return MakeWorkload<RefMergerFloat32Workload, RefMergerUint8Workload>(descriptor, info);
120}
121
122std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
123 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
124{
125 return MakeWorkload<RefFullyConnectedFloat32Workload, RefFullyConnectedUint8Workload>(descriptor, info);
126}
127
128std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
129 const WorkloadInfo& info) const
130{
narpra01db2b1602019-01-23 15:23:11 +0000131 return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload,
kevmay012b4d88e2019-01-24 14:05:09 +0000132 NullWorkload, NullWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000133}
134
135std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
136 const WorkloadInfo& info) const
137{
138 return MakeWorkload<RefPooling2dFloat32Workload, RefPooling2dUint8Workload>(descriptor, info);
139}
140
141std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConvolution2d(
142 const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
143{
144 return MakeWorkload<RefConvolution2dFloat32Workload, RefConvolution2dUint8Workload>(descriptor, info);
145}
146
147std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthwiseConvolution2d(
148 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
149{
150 return MakeWorkload<RefDepthwiseConvolution2dFloat32Workload,
151 RefDepthwiseConvolution2dUint8Workload>(descriptor, info);
152}
153
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000154std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDetectionPostProcess(
155 const armnn::DetectionPostProcessQueueDescriptor& descriptor, const armnn::WorkloadInfo& info) const
156{
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000157 const DataType dataType = info.m_InputTensorInfos[0].GetDataType();
158 switch (dataType)
159 {
160 case DataType::Float32:
161 return std::make_unique<RefDetectionPostProcessFloat32Workload>(descriptor, info);
162 case DataType::QuantisedAsymm8:
163 return std::make_unique<RefDetectionPostProcessUint8Workload>(descriptor, info);
164 default:
165 return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
166 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000167}
168
telsoa014fcda012018-03-09 14:13:49 +0000169std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization(
170 const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
171{
172 return MakeWorkload<RefNormalizationFloat32Workload, NullWorkload>(descriptor, info);
173}
174
175std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
176 const WorkloadInfo& info) const
177{
178 return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info);
179}
180
181std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication(
182 const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
183{
184 return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info);
185}
186
187std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization(
188 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
189{
190 return MakeWorkload<RefBatchNormalizationFloat32Workload, RefBatchNormalizationUint8Workload>(descriptor, info);
191}
192
193std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
194 const WorkloadInfo& info) const
195{
196 if (descriptor.m_Inputs.empty())
197 {
198 throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
199 }
telsoa01c577f2c2018-08-31 09:22:23 +0100200 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000201}
202
203std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
204 const WorkloadInfo& info) const
205{
206 return MakeWorkload<RefResizeBilinearFloat32Workload, RefResizeBilinearUint8Workload>(descriptor, info);
207}
208
209std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
210 const FakeQuantizationQueueDescriptor& descriptor,
211 const WorkloadInfo& info) const
212{
213 return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
214}
215
216std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
217 const WorkloadInfo& info) const
218{
219 return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
220}
221
222std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
223 const WorkloadInfo& info) const
224{
narpra01db2b1602019-01-23 15:23:11 +0000225 return MakeWorkloadHelper<NullWorkload, RefConstantFloat32Workload, RefConstantUint8Workload,
kevmay012b4d88e2019-01-24 14:05:09 +0000226 RefConstantInt32Workload, NullWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000227}
228
229std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
230 const WorkloadInfo& info) const
231{
232 return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
233}
234
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000235std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
236 const WorkloadInfo& info) const
237{
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000238 return MakeWorkload<RefSpaceToBatchNdFloat32Workload, RefSpaceToBatchNdUint8Workload>(descriptor, info);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000239}
240
telsoa014fcda012018-03-09 14:13:49 +0000241std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100242 const WorkloadInfo& info) const
telsoa014fcda012018-03-09 14:13:49 +0000243{
244 return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info);
245}
246
telsoa01c577f2c2018-08-31 09:22:23 +0100247std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
248 const WorkloadInfo& info) const
249{
250 return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
251}
252
253std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
254 const ConvertFp16ToFp32QueueDescriptor& descriptor,
255 const WorkloadInfo& info) const
256{
257 return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
258}
259
260std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
261 const ConvertFp32ToFp16QueueDescriptor& descriptor,
262 const WorkloadInfo& info) const
263{
264 return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
265}
266
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100267std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
268 const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
269{
270 return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info);
271}
272
David Beckc2044fe2018-09-05 15:00:38 +0100273std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
274 const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
275{
David Beckf195f032018-09-06 16:46:34 +0100276 return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
David Beckc2044fe2018-09-05 15:00:38 +0100277}
278
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000279std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMaximum(
280 const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const
281{
saoste012df12b32018-11-28 16:57:20 +0000282 return MakeWorkload<RefMaximumFloat32Workload, RefMaximumUint8Workload>(descriptor, info);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000283}
284
narpra01a6bf9122018-09-10 09:50:09 +0100285std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
286 const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const
287{
narpra011e4c31d2018-09-28 11:07:51 +0100288 return MakeWorkload<RefMeanFloat32Workload, RefMeanUint8Workload>(descriptor, info);
narpra01a6bf9122018-09-10 09:50:09 +0100289}
290
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000291std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMinimum(
292 const MinimumQueueDescriptor& descriptor, const WorkloadInfo& info) const
293{
294 return MakeWorkload<RefMinimumFloat32Workload, RefMinimumUint8Workload>(descriptor, info);
295}
296
jimfly012c9322a2018-09-19 10:59:49 +0100297std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
298 const WorkloadInfo& info) const
299{
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100300 return MakeWorkload<RefPadFloat32Workload, RefPadUint8Workload>(descriptor, info);
jimfly012c9322a2018-09-19 10:59:49 +0100301}
302
FrancisMurtagh20995952018-12-17 12:11:36 +0000303std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
304 const WorkloadInfo& info) const
305{
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000306 return MakeWorkload<RefEqualFloat32Workload, RefEqualUint8Workload>(descriptor, info);
FrancisMurtagh20995952018-12-17 12:11:36 +0000307}
308
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000309std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
310 const WorkloadInfo& info) const
311{
312 return MakeWorkload<RefBatchToSpaceNdFloat32Workload, RefBatchToSpaceNdUint8Workload>(descriptor, info);
313}
jimfly012c9322a2018-09-19 10:59:49 +0100314
Conor Kennedy430b5d82018-11-14 15:28:28 +0000315std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
316 const WorkloadInfo& info) const
317{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000318 return MakeWorkload<RefStridedSliceFloat32Workload, RefStridedSliceUint8Workload>(descriptor, info);
Conor Kennedy430b5d82018-11-14 15:28:28 +0000319}
320
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000321std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
322 const WorkloadInfo& info) const
323{
FrancisMurtagh878f0232018-12-19 10:56:15 +0000324 return MakeWorkload<RefGreaterFloat32Workload, RefGreaterUint8Workload>(descriptor, info);
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000325}
326
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000327std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
328 const WorkloadInfo& info) const
329{
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000330 return MakeWorkload<RefDebugFloat32Workload, RefDebugUint8Workload>(descriptor, info);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000331}
332
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000333std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
334 const WorkloadInfo& info) const
335{
336 return MakeWorkload<RefRsqrtFloat32Workload, NullWorkload>(descriptor, info);
337}
338
narpra014951d842019-01-18 16:53:53 +0000339std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
340 const armnn::WorkloadInfo& info) const
341{
342 return MakeWorkload<RefGatherFloat32Workload, RefGatherUint8Workload>(descriptor, info);
343}
344
Matteo Martincigh49124022019-01-11 13:25:59 +0000345std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
346 const WorkloadInfo& info) const
347{
348 return nullptr;
349}
350
351} // namespace armnn