blob: 582c691a18a0e629f2e37fada6d5910e7441f583 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
David Beckb4540be2018-09-24 13:18:27 +01005#include <backends/CpuTensorHandle.hpp>
6#include <backends/MemCopyWorkload.hpp>
7#include <backends/MakeWorkloadHelper.hpp>
telsoa014fcda012018-03-09 14:13:49 +00008#include "RefWorkloadFactory.hpp"
David Beckb4540be2018-09-24 13:18:27 +01009#include "workloads/RefWorkloads.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010#include "Layer.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011
12#include <boost/log/trivial.hpp>
13
14namespace armnn
15{
16
17template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
18std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
19 const WorkloadInfo& info) const
20{
telsoa01c577f2c2018-08-31 09:22:23 +010021 return armnn::MakeWorkload<NullWorkload, F32Workload, U8Workload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000022}
23
telsoa01c577f2c2018-08-31 09:22:23 +010024RefWorkloadFactory::RefWorkloadFactory()
telsoa014fcda012018-03-09 14:13:49 +000025{
26}
27
telsoa01c577f2c2018-08-31 09:22:23 +010028bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
29 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000030{
31 return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
32}
33
34std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
35{
36 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
37}
38
Francis Murtagh351d13d2018-09-24 15:01:18 +010039std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
40 DataLayout dataLayout) const
41{
42 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
43}
44
telsoa014fcda012018-03-09 14:13:49 +000045std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
46 const WorkloadInfo& info) const
47{
48 if (info.m_InputTensorInfos.empty() )
49 {
50 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
51 }
52 if (info.m_OutputTensorInfos.empty())
53 {
54 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
55 }
56
57 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
58 {
59 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
60 }
61
telsoa01c577f2c2018-08-31 09:22:23 +010062 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000063}
64
65std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
66 const WorkloadInfo& info) const
67{
68 if (info.m_InputTensorInfos.empty() )
69 {
70 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
71 }
72 if (info.m_OutputTensorInfos.empty())
73 {
74 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
75 }
76 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
77 {
78 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
79 }
80
telsoa01c577f2c2018-08-31 09:22:23 +010081 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000082}
83
84std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
85 const WorkloadInfo& info) const
86{
87 return MakeWorkload<RefActivationFloat32Workload, RefActivationUint8Workload>(descriptor, info);
88}
89
90std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
91 const WorkloadInfo& info) const
92{
93 return MakeWorkload<RefSoftmaxFloat32Workload, RefSoftmaxUint8Workload>(descriptor, info);
94}
95
96std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
97 const WorkloadInfo& info) const
98{
99 return MakeWorkload<RefSplitterFloat32Workload, RefSplitterUint8Workload>(descriptor, info);
100}
101
102std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
103 const WorkloadInfo& info) const
104{
105 return MakeWorkload<RefMergerFloat32Workload, RefMergerUint8Workload>(descriptor, info);
106}
107
108std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
109 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
110{
111 return MakeWorkload<RefFullyConnectedFloat32Workload, RefFullyConnectedUint8Workload>(descriptor, info);
112}
113
114std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
115 const WorkloadInfo& info) const
116{
117 return MakeWorkload<RefPermuteFloat32Workload, RefPermuteUint8Workload>(descriptor, info);
118}
119
120std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
121 const WorkloadInfo& info) const
122{
123 return MakeWorkload<RefPooling2dFloat32Workload, RefPooling2dUint8Workload>(descriptor, info);
124}
125
126std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConvolution2d(
127 const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
128{
129 return MakeWorkload<RefConvolution2dFloat32Workload, RefConvolution2dUint8Workload>(descriptor, info);
130}
131
132std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthwiseConvolution2d(
133 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
134{
135 return MakeWorkload<RefDepthwiseConvolution2dFloat32Workload,
136 RefDepthwiseConvolution2dUint8Workload>(descriptor, info);
137}
138
139std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization(
140 const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
141{
142 return MakeWorkload<RefNormalizationFloat32Workload, NullWorkload>(descriptor, info);
143}
144
145std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
146 const WorkloadInfo& info) const
147{
148 return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info);
149}
150
151std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication(
152 const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
153{
154 return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info);
155}
156
157std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization(
158 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
159{
160 return MakeWorkload<RefBatchNormalizationFloat32Workload, RefBatchNormalizationUint8Workload>(descriptor, info);
161}
162
163std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
164 const WorkloadInfo& info) const
165{
166 if (descriptor.m_Inputs.empty())
167 {
168 throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
169 }
telsoa01c577f2c2018-08-31 09:22:23 +0100170 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000171}
172
173std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
174 const WorkloadInfo& info) const
175{
176 return MakeWorkload<RefResizeBilinearFloat32Workload, RefResizeBilinearUint8Workload>(descriptor, info);
177}
178
179std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
180 const FakeQuantizationQueueDescriptor& descriptor,
181 const WorkloadInfo& info) const
182{
183 return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
184}
185
186std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
187 const WorkloadInfo& info) const
188{
189 return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
190}
191
192std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
193 const WorkloadInfo& info) const
194{
195 return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info);
196}
197
198std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
199 const WorkloadInfo& info) const
200{
201 return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
202}
203
204std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100205 const WorkloadInfo& info) const
telsoa014fcda012018-03-09 14:13:49 +0000206{
207 return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info);
208}
209
telsoa01c577f2c2018-08-31 09:22:23 +0100210std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
211 const WorkloadInfo& info) const
212{
213 return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
214}
215
216std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
217 const ConvertFp16ToFp32QueueDescriptor& descriptor,
218 const WorkloadInfo& info) const
219{
220 return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
221}
222
223std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
224 const ConvertFp32ToFp16QueueDescriptor& descriptor,
225 const WorkloadInfo& info) const
226{
227 return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
228}
229
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100230std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
231 const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
232{
233 return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info);
234}
235
David Beckc2044fe2018-09-05 15:00:38 +0100236std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
237 const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
238{
David Beckf195f032018-09-06 16:46:34 +0100239 return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
David Beckc2044fe2018-09-05 15:00:38 +0100240}
241
narpra01a6bf9122018-09-10 09:50:09 +0100242std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
243 const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const
244{
narpra011e4c31d2018-09-28 11:07:51 +0100245 return MakeWorkload<RefMeanFloat32Workload, RefMeanUint8Workload>(descriptor, info);
narpra01a6bf9122018-09-10 09:50:09 +0100246}
247
jimfly012c9322a2018-09-19 10:59:49 +0100248std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
249 const WorkloadInfo& info) const
250{
251 return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
252}
253
254
telsoa014fcda012018-03-09 14:13:49 +0000255} // namespace armnn