blob: b1f9d6c70a3a1b7921903c717ffe06062e997471 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
David Beckb4540be2018-09-24 13:18:27 +01005#include <backends/CpuTensorHandle.hpp>
6#include <backends/MemCopyWorkload.hpp>
7#include <backends/MakeWorkloadHelper.hpp>
telsoa014fcda012018-03-09 14:13:49 +00008#include "RefWorkloadFactory.hpp"
David Beckb4540be2018-09-24 13:18:27 +01009#include "workloads/RefWorkloads.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010#include "Layer.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011
12#include <boost/log/trivial.hpp>
13
14namespace armnn
15{
16
17template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
18std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
19 const WorkloadInfo& info) const
20{
telsoa01c577f2c2018-08-31 09:22:23 +010021 return armnn::MakeWorkload<NullWorkload, F32Workload, U8Workload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000022}
23
telsoa01c577f2c2018-08-31 09:22:23 +010024RefWorkloadFactory::RefWorkloadFactory()
telsoa014fcda012018-03-09 14:13:49 +000025{
26}
27
telsoa01c577f2c2018-08-31 09:22:23 +010028bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
29 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000030{
31 return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
32}
33
34std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
35{
36 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
37}
38
Francis Murtagh351d13d2018-09-24 15:01:18 +010039std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
40 DataLayout dataLayout) const
41{
42 return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
43}
44
telsoa014fcda012018-03-09 14:13:49 +000045std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
46 const WorkloadInfo& info) const
47{
48 if (info.m_InputTensorInfos.empty() )
49 {
50 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
51 }
52 if (info.m_OutputTensorInfos.empty())
53 {
54 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
55 }
56
57 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
58 {
59 throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
60 }
61
telsoa01c577f2c2018-08-31 09:22:23 +010062 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000063}
64
65std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
66 const WorkloadInfo& info) const
67{
68 if (info.m_InputTensorInfos.empty() )
69 {
70 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
71 }
72 if (info.m_OutputTensorInfos.empty())
73 {
74 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
75 }
76 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
77 {
78 throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
79 }
80
telsoa01c577f2c2018-08-31 09:22:23 +010081 return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +000082}
83
84std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
85 const WorkloadInfo& info) const
86{
87 return MakeWorkload<RefActivationFloat32Workload, RefActivationUint8Workload>(descriptor, info);
88}
89
90std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
91 const WorkloadInfo& info) const
92{
93 return MakeWorkload<RefSoftmaxFloat32Workload, RefSoftmaxUint8Workload>(descriptor, info);
94}
95
96std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
97 const WorkloadInfo& info) const
98{
99 return MakeWorkload<RefSplitterFloat32Workload, RefSplitterUint8Workload>(descriptor, info);
100}
101
102std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
103 const WorkloadInfo& info) const
104{
105 return MakeWorkload<RefMergerFloat32Workload, RefMergerUint8Workload>(descriptor, info);
106}
107
108std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
109 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
110{
111 return MakeWorkload<RefFullyConnectedFloat32Workload, RefFullyConnectedUint8Workload>(descriptor, info);
112}
113
114std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
115 const WorkloadInfo& info) const
116{
arovir01616e7752018-10-01 17:08:59 +0100117 return armnn::MakeWorkload<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload>
118 (descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000119}
120
121std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
122 const WorkloadInfo& info) const
123{
124 return MakeWorkload<RefPooling2dFloat32Workload, RefPooling2dUint8Workload>(descriptor, info);
125}
126
127std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConvolution2d(
128 const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
129{
130 return MakeWorkload<RefConvolution2dFloat32Workload, RefConvolution2dUint8Workload>(descriptor, info);
131}
132
133std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthwiseConvolution2d(
134 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
135{
136 return MakeWorkload<RefDepthwiseConvolution2dFloat32Workload,
137 RefDepthwiseConvolution2dUint8Workload>(descriptor, info);
138}
139
140std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization(
141 const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
142{
143 return MakeWorkload<RefNormalizationFloat32Workload, NullWorkload>(descriptor, info);
144}
145
146std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
147 const WorkloadInfo& info) const
148{
149 return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info);
150}
151
152std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication(
153 const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
154{
155 return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info);
156}
157
158std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization(
159 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
160{
161 return MakeWorkload<RefBatchNormalizationFloat32Workload, RefBatchNormalizationUint8Workload>(descriptor, info);
162}
163
164std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
165 const WorkloadInfo& info) const
166{
167 if (descriptor.m_Inputs.empty())
168 {
169 throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
170 }
telsoa01c577f2c2018-08-31 09:22:23 +0100171 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
telsoa014fcda012018-03-09 14:13:49 +0000172}
173
174std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
175 const WorkloadInfo& info) const
176{
177 return MakeWorkload<RefResizeBilinearFloat32Workload, RefResizeBilinearUint8Workload>(descriptor, info);
178}
179
180std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
181 const FakeQuantizationQueueDescriptor& descriptor,
182 const WorkloadInfo& info) const
183{
184 return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
185}
186
187std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
188 const WorkloadInfo& info) const
189{
190 return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
191}
192
193std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
194 const WorkloadInfo& info) const
195{
196 return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info);
197}
198
199std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
200 const WorkloadInfo& info) const
201{
202 return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
203}
204
205std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100206 const WorkloadInfo& info) const
telsoa014fcda012018-03-09 14:13:49 +0000207{
208 return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info);
209}
210
telsoa01c577f2c2018-08-31 09:22:23 +0100211std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
212 const WorkloadInfo& info) const
213{
214 return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
215}
216
217std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
218 const ConvertFp16ToFp32QueueDescriptor& descriptor,
219 const WorkloadInfo& info) const
220{
221 return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
222}
223
224std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
225 const ConvertFp32ToFp16QueueDescriptor& descriptor,
226 const WorkloadInfo& info) const
227{
228 return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
229}
230
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100231std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
232 const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
233{
234 return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info);
235}
236
David Beckc2044fe2018-09-05 15:00:38 +0100237std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
238 const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
239{
David Beckf195f032018-09-06 16:46:34 +0100240 return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
David Beckc2044fe2018-09-05 15:00:38 +0100241}
242
narpra01a6bf9122018-09-10 09:50:09 +0100243std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
244 const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const
245{
narpra011e4c31d2018-09-28 11:07:51 +0100246 return MakeWorkload<RefMeanFloat32Workload, RefMeanUint8Workload>(descriptor, info);
narpra01a6bf9122018-09-10 09:50:09 +0100247}
248
jimfly012c9322a2018-09-19 10:59:49 +0100249std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
250 const WorkloadInfo& info) const
251{
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100252 return MakeWorkload<RefPadFloat32Workload, RefPadUint8Workload>(descriptor, info);
jimfly012c9322a2018-09-19 10:59:49 +0100253}
254
255
telsoa014fcda012018-03-09 14:13:49 +0000256} // namespace armnn