blob: e5f210773d933dec005bc84d6821ff587bed0c04 [file] [log] [blame]
//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Mike Kelly4cc341c2023-07-07 15:43:06 +01005
telsoa014fcda012018-03-09 14:13:49 +00006#pragma once
7
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01008#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01009#include <Half.hpp>
10
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010011#include <armnn/utility/Assert.hpp>
12
Derek Lambertic81855f2019-06-13 17:34:19 +010013#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010015#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <arm_compute/runtime/MemoryGroup.h>
18#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000019#include <arm_compute/runtime/Tensor.h>
20#include <arm_compute/runtime/SubTensor.h>
21#include <arm_compute/core/TensorShape.h>
22#include <arm_compute/core/Coordinates.h>
Mike Kelly4cc341c2023-07-07 15:43:06 +010023#include "armnn/TypesUtils.hpp"
telsoa014fcda012018-03-09 14:13:49 +000024
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
Mike Kelly4cc341c2023-07-07 15:43:06 +010027class NeonTensorHandleDecorator;
telsoa014fcda012018-03-09 14:13:49 +000028
Derek Lambertic81855f2019-06-13 17:34:19 +010029class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000030{
31public:
32 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010033 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
34 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010035 m_IsImportEnabled(false),
36 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
telsoa014fcda012018-03-09 14:13:49 +000037 {
38 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
39 }
40
David Monahan3fb7e102019-08-20 11:25:29 +010041 NeonTensorHandle(const TensorInfo& tensorInfo,
42 DataLayout dataLayout,
43 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
44 : m_ImportFlags(importFlags),
45 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010046 m_IsImportEnabled(false),
47 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
48
David Monahan3fb7e102019-08-20 11:25:29 +010049
Francis Murtagh351d13d2018-09-24 15:01:18 +010050 {
51 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
52 }
53
telsoa014fcda012018-03-09 14:13:49 +000054 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
55 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010056
telsoa014fcda012018-03-09 14:13:49 +000057 virtual void Allocate() override
58 {
David Monahan3fb7e102019-08-20 11:25:29 +010059 // If we have enabled Importing, don't Allocate the tensor
60 if (!m_IsImportEnabled)
61 {
62 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
63 }
telsoa014fcda012018-03-09 14:13:49 +000064 };
65
telsoa01c577f2c2018-08-31 09:22:23 +010066 virtual void Manage() override
67 {
David Monahan3fb7e102019-08-20 11:25:29 +010068 // If we have enabled Importing, don't manage the tensor
69 if (!m_IsImportEnabled)
70 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010071 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010072 m_MemoryGroup->manage(&m_Tensor);
73 }
telsoa01c577f2c2018-08-31 09:22:23 +010074 }
75
telsoa01c577f2c2018-08-31 09:22:23 +010076 virtual ITensorHandle* GetParent() const override { return nullptr; }
77
telsoa014fcda012018-03-09 14:13:49 +000078 virtual arm_compute::DataType GetDataType() const override
79 {
80 return m_Tensor.info()->data_type();
81 }
82
telsoa01c577f2c2018-08-31 09:22:23 +010083 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
84 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010085 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010086 }
87
88 virtual const void* Map(bool /* blocking = true */) const override
89 {
90 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
91 }
telsoa01c577f2c2018-08-31 09:22:23 +010092
David Monahan3fb7e102019-08-20 11:25:29 +010093 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010094
95 TensorShape GetStrides() const override
96 {
97 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
98 }
99
100 TensorShape GetShape() const override
101 {
102 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
103 }
104
David Monahan3fb7e102019-08-20 11:25:29 +0100105 void SetImportFlags(MemorySourceFlags importFlags)
106 {
107 m_ImportFlags = importFlags;
108 }
109
110 MemorySourceFlags GetImportFlags() const override
111 {
112 return m_ImportFlags;
113 }
114
115 void SetImportEnabledFlag(bool importEnabledFlag)
116 {
117 m_IsImportEnabled = importEnabledFlag;
118 }
119
David Monahan0fa10502022-01-13 10:48:33 +0000120 bool CanBeImported(void* memory, MemorySource source) override
121 {
David Monahan3826ab62022-02-21 12:26:16 +0000122 if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
David Monahan0fa10502022-01-13 10:48:33 +0000123 {
124 return false;
125 }
126 return true;
127 }
128
David Monahan3fb7e102019-08-20 11:25:29 +0100129 virtual bool Import(void* memory, MemorySource source) override
130 {
Mike Kelly4cc341c2023-07-07 15:43:06 +0100131 if (m_ImportFlags& static_cast<MemorySourceFlags>(source))
David Monahan3fb7e102019-08-20 11:25:29 +0100132 {
133 if (source == MemorySource::Malloc && m_IsImportEnabled)
134 {
David Monahan0fa10502022-01-13 10:48:33 +0000135 if (!CanBeImported(memory, source))
David Monahan3fb7e102019-08-20 11:25:29 +0100136 {
137 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
138 }
139
140 // m_Tensor not yet Allocated
141 if (!m_Imported && !m_Tensor.buffer())
142 {
143 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
144 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
145 // with the Status error message
146 m_Imported = bool(status);
147 if (!m_Imported)
148 {
149 throw MemoryImportException(status.error_description());
150 }
151 return m_Imported;
152 }
153
154 // m_Tensor.buffer() initially allocated with Allocate().
155 if (!m_Imported && m_Tensor.buffer())
156 {
157 throw MemoryImportException(
158 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
159 }
160
161 // m_Tensor.buffer() previously imported.
162 if (m_Imported)
163 {
164 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
165 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
166 // with the Status error message
167 m_Imported = bool(status);
168 if (!m_Imported)
169 {
170 throw MemoryImportException(status.error_description());
171 }
172 return m_Imported;
173 }
174 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100175 else
176 {
177 throw MemoryImportException("NeonTensorHandle::Import is disabled");
178 }
179 }
180 else
181 {
182 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100183 }
184 return false;
185 }
186
Mike Kelly4cc341c2023-07-07 15:43:06 +0100187 virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
188
telsoa014fcda012018-03-09 14:13:49 +0000189private:
David Beck09e2f272018-10-30 11:38:41 +0000190 // Only used for testing
191 void CopyOutTo(void* memory) const override
192 {
193 switch (this->GetDataType())
194 {
195 case arm_compute::DataType::F32:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<float*>(memory));
198 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000199 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000200 case arm_compute::DataType::QASYMM8:
201 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
202 static_cast<uint8_t*>(memory));
203 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100204 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100205 case arm_compute::DataType::QASYMM8_SIGNED:
206 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
207 static_cast<int8_t*>(memory));
208 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100209 case arm_compute::DataType::BFLOAT16:
210 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
211 static_cast<armnn::BFloat16*>(memory));
212 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100213 case arm_compute::DataType::F16:
214 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
215 static_cast<armnn::Half*>(memory));
216 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100217 case arm_compute::DataType::S16:
218 case arm_compute::DataType::QSYMM16:
219 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
220 static_cast<int16_t*>(memory));
221 break;
James Conroyd47a0642019-09-17 14:22:06 +0100222 case arm_compute::DataType::S32:
223 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
224 static_cast<int32_t*>(memory));
225 break;
David Beck09e2f272018-10-30 11:38:41 +0000226 default:
227 {
228 throw armnn::UnimplementedException();
229 }
230 }
231 }
232
233 // Only used for testing
234 void CopyInFrom(const void* memory) override
235 {
236 switch (this->GetDataType())
237 {
238 case arm_compute::DataType::F32:
239 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
240 this->GetTensor());
241 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000242 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000243 case arm_compute::DataType::QASYMM8:
244 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
245 this->GetTensor());
246 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100247 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100248 case arm_compute::DataType::QASYMM8_SIGNED:
Cathal Corbett06902652022-04-14 17:55:11 +0100249 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100250 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
251 this->GetTensor());
252 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100253 case arm_compute::DataType::BFLOAT16:
254 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
255 this->GetTensor());
256 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100257 case arm_compute::DataType::F16:
258 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
259 this->GetTensor());
260 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100261 case arm_compute::DataType::S16:
262 case arm_compute::DataType::QSYMM16:
263 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
264 this->GetTensor());
265 break;
James Conroyd47a0642019-09-17 14:22:06 +0100266 case arm_compute::DataType::S32:
267 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
268 this->GetTensor());
269 break;
David Beck09e2f272018-10-30 11:38:41 +0000270 default:
271 {
272 throw armnn::UnimplementedException();
273 }
274 }
275 }
276
telsoa014fcda012018-03-09 14:13:49 +0000277 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100278 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100279 MemorySourceFlags m_ImportFlags;
280 bool m_Imported;
281 bool m_IsImportEnabled;
Finn Williamsb1aad422021-10-28 19:07:32 +0100282 const uintptr_t m_TypeAlignment;
Mike Kelly4cc341c2023-07-07 15:43:06 +0100283 std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
telsoa014fcda012018-03-09 14:13:49 +0000284};
285
Derek Lambertic81855f2019-06-13 17:34:19 +0100286class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000287{
288public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100289 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100290 const arm_compute::TensorShape& shape,
291 const arm_compute::Coordinates& coords)
Mike Kelly4cc341c2023-07-07 15:43:06 +0100292 : m_Tensor(&parent->GetTensor(), shape, coords, true)
telsoa014fcda012018-03-09 14:13:49 +0000293 {
telsoa01c577f2c2018-08-31 09:22:23 +0100294 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000295 }
296
297 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
298 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100299
300 virtual void Allocate() override {}
301 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000302
telsoa01c577f2c2018-08-31 09:22:23 +0100303 virtual ITensorHandle* GetParent() const override { return parentHandle; }
304
telsoa014fcda012018-03-09 14:13:49 +0000305 virtual arm_compute::DataType GetDataType() const override
306 {
307 return m_Tensor.info()->data_type();
308 }
309
telsoa01c577f2c2018-08-31 09:22:23 +0100310 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
311
312 virtual const void* Map(bool /* blocking = true */) const override
313 {
314 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
315 }
316 virtual void Unmap() const override {}
317
318 TensorShape GetStrides() const override
319 {
320 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
321 }
322
323 TensorShape GetShape() const override
324 {
325 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
326 }
David Beck09e2f272018-10-30 11:38:41 +0000327
Mike Kelly4cc341c2023-07-07 15:43:06 +0100328 virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
329 {
330 return nullptr;
331 };
332
telsoa014fcda012018-03-09 14:13:49 +0000333private:
David Beck09e2f272018-10-30 11:38:41 +0000334 // Only used for testing
335 void CopyOutTo(void* memory) const override
336 {
337 switch (this->GetDataType())
338 {
339 case arm_compute::DataType::F32:
340 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
341 static_cast<float*>(memory));
342 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000343 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000344 case arm_compute::DataType::QASYMM8:
345 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
346 static_cast<uint8_t*>(memory));
347 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100348 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100349 case arm_compute::DataType::QASYMM8_SIGNED:
350 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
351 static_cast<int8_t*>(memory));
352 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100353 case arm_compute::DataType::S16:
354 case arm_compute::DataType::QSYMM16:
355 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
356 static_cast<int16_t*>(memory));
357 break;
James Conroyd47a0642019-09-17 14:22:06 +0100358 case arm_compute::DataType::S32:
359 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
360 static_cast<int32_t*>(memory));
361 break;
David Beck09e2f272018-10-30 11:38:41 +0000362 default:
363 {
364 throw armnn::UnimplementedException();
365 }
366 }
367 }
368
369 // Only used for testing
370 void CopyInFrom(const void* memory) override
371 {
372 switch (this->GetDataType())
373 {
374 case arm_compute::DataType::F32:
375 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
376 this->GetTensor());
377 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000378 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000379 case arm_compute::DataType::QASYMM8:
380 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
381 this->GetTensor());
382 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100383 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100384 case arm_compute::DataType::QASYMM8_SIGNED:
385 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
386 this->GetTensor());
387 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100388 case arm_compute::DataType::S16:
389 case arm_compute::DataType::QSYMM16:
390 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
391 this->GetTensor());
392 break;
James Conroyd47a0642019-09-17 14:22:06 +0100393 case arm_compute::DataType::S32:
394 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
395 this->GetTensor());
396 break;
David Beck09e2f272018-10-30 11:38:41 +0000397 default:
398 {
399 throw armnn::UnimplementedException();
400 }
401 }
402 }
403
telsoa01c577f2c2018-08-31 09:22:23 +0100404 arm_compute::SubTensor m_Tensor;
405 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000406};
407
Mike Kelly4cc341c2023-07-07 15:43:06 +0100408/// NeonTensorDecorator wraps an existing Neon tensor allowing us to override the TensorInfo for it
409class NeonTensorDecorator : public arm_compute::ITensor
410{
411public:
412 NeonTensorDecorator();
413
414 NeonTensorDecorator(arm_compute::ITensor* original, const TensorInfo& info);
415
416 ~NeonTensorDecorator() = default;
417
418 NeonTensorDecorator(const NeonTensorDecorator&) = delete;
419
420 NeonTensorDecorator& operator=(const NeonTensorDecorator&) = delete;
421
422 NeonTensorDecorator(NeonTensorDecorator&&) = default;
423
424 NeonTensorDecorator& operator=(NeonTensorDecorator&&) = default;
425
426 // Inherited methods overridden:
427 arm_compute::ITensorInfo* info() const override;
428
429 arm_compute::ITensorInfo* info() override;
430
431 uint8_t* buffer() const override;
432
433private:
434 arm_compute::ITensor* m_Original;
435 mutable arm_compute::TensorInfo m_TensorInfo;
436};
437
438class NeonTensorHandleDecorator : public IAclTensorHandle
439{
440public:
441 NeonTensorHandleDecorator(IAclTensorHandle* parent, const TensorInfo& info)
442 : m_Tensor(&parent->GetTensor(), info)
443 {
444 parentHandle = parent;
445 }
446
447 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
448 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
449
450 virtual void Allocate() override {}
451 virtual void Manage() override {}
452
453 virtual ITensorHandle* GetParent() const override { return nullptr; }
454
455 virtual arm_compute::DataType GetDataType() const override
456 {
457 return m_Tensor.info()->data_type();
458 }
459
460 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
461
462 virtual const void* Map(bool /* blocking = true */) const override
463 {
464 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
465 }
466 virtual void Unmap() const override {}
467
468 TensorShape GetStrides() const override
469 {
470 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
471 }
472
473 TensorShape GetShape() const override
474 {
475 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
476 }
477
478 virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
479 {
480 return nullptr;
481 };
482
483private:
484 // Only used for testing
485 void CopyOutTo(void* memory) const override
486 {
487 switch (this->GetDataType())
488 {
489 case arm_compute::DataType::F32:
490 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
491 static_cast<float*>(memory));
492 break;
493 case arm_compute::DataType::U8:
494 case arm_compute::DataType::QASYMM8:
495 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
496 static_cast<uint8_t*>(memory));
497 break;
498 case arm_compute::DataType::QSYMM8:
499 case arm_compute::DataType::QASYMM8_SIGNED:
500 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
501 static_cast<int8_t*>(memory));
502 break;
503 case arm_compute::DataType::S16:
504 case arm_compute::DataType::QSYMM16:
505 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
506 static_cast<int16_t*>(memory));
507 break;
508 case arm_compute::DataType::S32:
509 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
510 static_cast<int32_t*>(memory));
511 break;
512 default:
513 {
514 throw armnn::UnimplementedException();
515 }
516 }
517 }
518
519 // Only used for testing
520 void CopyInFrom(const void* memory) override
521 {
522 switch (this->GetDataType())
523 {
524 case arm_compute::DataType::F32:
525 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
526 this->GetTensor());
527 break;
528 case arm_compute::DataType::U8:
529 case arm_compute::DataType::QASYMM8:
530 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
531 this->GetTensor());
532 break;
533 case arm_compute::DataType::QSYMM8:
534 case arm_compute::DataType::QASYMM8_SIGNED:
535 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
536 this->GetTensor());
537 break;
538 case arm_compute::DataType::S16:
539 case arm_compute::DataType::QSYMM16:
540 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
541 this->GetTensor());
542 break;
543 case arm_compute::DataType::S32:
544 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
545 this->GetTensor());
546 break;
547 default:
548 {
549 throw armnn::UnimplementedException();
550 }
551 }
552 }
553
554 NeonTensorDecorator m_Tensor;
555 ITensorHandle* parentHandle = nullptr;
556};
557
558
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100559} // namespace armnn