blob: 42657341fde73f688a9e18f4940033ea3d71efc9 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly4cc341c2023-07-07 15:43:06 +01002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Mike Kelly4cc341c2023-07-07 15:43:06 +01005
telsoa014fcda012018-03-09 14:13:49 +00006#pragma once
7
Derek Lambertic81855f2019-06-13 17:34:19 +01008#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
David Beck09e2f272018-10-30 11:38:41 +000011#include <Half.hpp>
12
Jan Eilers3c9e0452020-04-10 13:00:44 +010013#include <armnn/utility/PolymorphicDowncast.hpp>
14
telsoa014fcda012018-03-09 14:13:49 +000015#include <arm_compute/runtime/CL/CLTensor.h>
16#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <arm_compute/runtime/IMemoryGroup.h>
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010018#include <arm_compute/runtime/MemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000019#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
Cathal Corbettd9e55f02023-01-11 13:03:21 +000022#include <aclCommon/IClTensorHandle.hpp>
Narumol Prangnawarat9ef36142022-01-25 15:15:34 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
Mike Kelly4cc341c2023-07-07 15:43:06 +010026class ClTensorHandleDecorator;
telsoa014fcda012018-03-09 14:13:49 +000027
telsoa014fcda012018-03-09 14:13:49 +000028class ClTensorHandle : public IClTensorHandle
29{
30public:
31 ClTensorHandle(const TensorInfo& tensorInfo)
David Monahan66dbf5b2021-03-11 11:34:54 +000032 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
33 m_Imported(false),
34 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000035 {
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37 }
38
David Monahan66dbf5b2021-03-11 11:34:54 +000039 ClTensorHandle(const TensorInfo& tensorInfo,
40 DataLayout dataLayout,
41 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
42 : m_ImportFlags(importFlags),
43 m_Imported(false),
44 m_IsImportEnabled(false)
Francis Murtagh351d13d2018-09-24 15:01:18 +010045 {
46 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
47 }
48
telsoa014fcda012018-03-09 14:13:49 +000049 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
50 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
David Monahan66dbf5b2021-03-11 11:34:54 +000051 virtual void Allocate() override
52 {
53 // If we have enabled Importing, don't allocate the tensor
David Monahane4a41dc2021-04-14 16:55:36 +010054 if (m_IsImportEnabled)
55 {
56 throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
57 }
58 else
David Monahan66dbf5b2021-03-11 11:34:54 +000059 {
60 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
61 }
62
63 }
telsoa014fcda012018-03-09 14:13:49 +000064
telsoa01c577f2c2018-08-31 09:22:23 +010065 virtual void Manage() override
66 {
David Monahan66dbf5b2021-03-11 11:34:54 +000067 // If we have enabled Importing, don't manage the tensor
David Monahane4a41dc2021-04-14 16:55:36 +010068 if (m_IsImportEnabled)
69 {
70 throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
71 }
72 else
David Monahan66dbf5b2021-03-11 11:34:54 +000073 {
74 assert(m_MemoryGroup != nullptr);
75 m_MemoryGroup->manage(&m_Tensor);
76 }
telsoa01c577f2c2018-08-31 09:22:23 +010077 }
telsoa014fcda012018-03-09 14:13:49 +000078
telsoa01c577f2c2018-08-31 09:22:23 +010079 virtual const void* Map(bool blocking = true) const override
80 {
81 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
82 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
83 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010084
telsoa01c577f2c2018-08-31 09:22:23 +010085 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
86
telsoa01c577f2c2018-08-31 09:22:23 +010087 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000088
89 virtual arm_compute::DataType GetDataType() const override
90 {
91 return m_Tensor.info()->data_type();
92 }
93
telsoa01c577f2c2018-08-31 09:22:23 +010094 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
95 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010096 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010097 }
98
99 TensorShape GetStrides() const override
100 {
101 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
102 }
103
104 TensorShape GetShape() const override
105 {
106 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
107 }
David Beck09e2f272018-10-30 11:38:41 +0000108
David Monahan66dbf5b2021-03-11 11:34:54 +0000109 void SetImportFlags(MemorySourceFlags importFlags)
110 {
111 m_ImportFlags = importFlags;
112 }
113
114 MemorySourceFlags GetImportFlags() const override
115 {
116 return m_ImportFlags;
117 }
118
119 void SetImportEnabledFlag(bool importEnabledFlag)
120 {
121 m_IsImportEnabled = importEnabledFlag;
122 }
123
124 virtual bool Import(void* memory, MemorySource source) override
125 {
126 armnn::IgnoreUnused(memory);
Mike Kelly4cc341c2023-07-07 15:43:06 +0100127 if (m_ImportFlags& static_cast<MemorySourceFlags>(source))
David Monahan66dbf5b2021-03-11 11:34:54 +0000128 {
129 throw MemoryImportException("ClTensorHandle::Incorrect import flag");
130 }
131 m_Imported = false;
132 return false;
133 }
134
Nikhil Raj60ab9762022-01-13 09:34:44 +0000135 virtual bool CanBeImported(void* memory, MemorySource source) override
136 {
137 // This TensorHandle can never import.
138 armnn::IgnoreUnused(memory, source);
139 return false;
140 }
141
Mike Kelly4cc341c2023-07-07 15:43:06 +0100142 virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
143
telsoa014fcda012018-03-09 14:13:49 +0000144private:
David Beck09e2f272018-10-30 11:38:41 +0000145 // Only used for testing
146 void CopyOutTo(void* memory) const override
147 {
148 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
149 switch(this->GetDataType())
150 {
151 case arm_compute::DataType::F32:
152 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
153 static_cast<float*>(memory));
154 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000155 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000156 case arm_compute::DataType::QASYMM8:
157 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
158 static_cast<uint8_t*>(memory));
159 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100160 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000161 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
162 case arm_compute::DataType::QASYMM8_SIGNED:
163 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
164 static_cast<int8_t*>(memory));
165 break;
David Beck09e2f272018-10-30 11:38:41 +0000166 case arm_compute::DataType::F16:
167 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
168 static_cast<armnn::Half*>(memory));
169 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100170 case arm_compute::DataType::S16:
171 case arm_compute::DataType::QSYMM16:
172 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
173 static_cast<int16_t*>(memory));
174 break;
James Conroy2dc05722019-09-19 17:00:31 +0100175 case arm_compute::DataType::S32:
176 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
177 static_cast<int32_t*>(memory));
178 break;
David Beck09e2f272018-10-30 11:38:41 +0000179 default:
180 {
181 throw armnn::UnimplementedException();
182 }
183 }
184 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
185 }
186
187 // Only used for testing
188 void CopyInFrom(const void* memory) override
189 {
190 this->Map(true);
191 switch(this->GetDataType())
192 {
193 case arm_compute::DataType::F32:
194 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
195 this->GetTensor());
196 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000197 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000198 case arm_compute::DataType::QASYMM8:
199 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
200 this->GetTensor());
201 break;
202 case arm_compute::DataType::F16:
203 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
204 this->GetTensor());
205 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100206 case arm_compute::DataType::S16:
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100207 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000208 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
209 case arm_compute::DataType::QASYMM8_SIGNED:
210 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
211 this->GetTensor());
212 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100213 case arm_compute::DataType::QSYMM16:
214 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
215 this->GetTensor());
216 break;
James Conroy2dc05722019-09-19 17:00:31 +0100217 case arm_compute::DataType::S32:
218 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
219 this->GetTensor());
220 break;
David Beck09e2f272018-10-30 11:38:41 +0000221 default:
222 {
223 throw armnn::UnimplementedException();
224 }
225 }
226 this->Unmap();
227 }
228
telsoa014fcda012018-03-09 14:13:49 +0000229 arm_compute::CLTensor m_Tensor;
Narumol Prangnawarat680f9912019-10-01 11:32:10 +0100230 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan66dbf5b2021-03-11 11:34:54 +0000231 MemorySourceFlags m_ImportFlags;
232 bool m_Imported;
233 bool m_IsImportEnabled;
Mike Kelly4cc341c2023-07-07 15:43:06 +0100234 std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated;
telsoa014fcda012018-03-09 14:13:49 +0000235};
236
237class ClSubTensorHandle : public IClTensorHandle
238{
239public:
telsoa01c577f2c2018-08-31 09:22:23 +0100240 ClSubTensorHandle(IClTensorHandle* parent,
241 const arm_compute::TensorShape& shape,
242 const arm_compute::Coordinates& coords)
243 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000244 {
telsoa01c577f2c2018-08-31 09:22:23 +0100245 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000246 }
247
248 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
249 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000250
telsoa01c577f2c2018-08-31 09:22:23 +0100251 virtual void Allocate() override {}
252 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000253
telsoa01c577f2c2018-08-31 09:22:23 +0100254 virtual const void* Map(bool blocking = true) const override
255 {
256 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
257 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
258 }
259 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
260
telsoa01c577f2c2018-08-31 09:22:23 +0100261 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000262
263 virtual arm_compute::DataType GetDataType() const override
264 {
265 return m_Tensor.info()->data_type();
266 }
267
telsoa01c577f2c2018-08-31 09:22:23 +0100268 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
269
270 TensorShape GetStrides() const override
271 {
272 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
273 }
274
275 TensorShape GetShape() const override
276 {
277 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
278 }
279
telsoa014fcda012018-03-09 14:13:49 +0000280private:
David Beck09e2f272018-10-30 11:38:41 +0000281 // Only used for testing
282 void CopyOutTo(void* memory) const override
283 {
284 const_cast<ClSubTensorHandle*>(this)->Map(true);
285 switch(this->GetDataType())
286 {
287 case arm_compute::DataType::F32:
288 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
289 static_cast<float*>(memory));
290 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000291 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000292 case arm_compute::DataType::QASYMM8:
293 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
294 static_cast<uint8_t*>(memory));
295 break;
296 case arm_compute::DataType::F16:
297 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
298 static_cast<armnn::Half*>(memory));
299 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100300 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000301 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
302 case arm_compute::DataType::QASYMM8_SIGNED:
303 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
304 static_cast<int8_t*>(memory));
305 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100306 case arm_compute::DataType::S16:
307 case arm_compute::DataType::QSYMM16:
308 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
309 static_cast<int16_t*>(memory));
310 break;
James Conroy2dc05722019-09-19 17:00:31 +0100311 case arm_compute::DataType::S32:
312 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
313 static_cast<int32_t*>(memory));
314 break;
David Beck09e2f272018-10-30 11:38:41 +0000315 default:
316 {
317 throw armnn::UnimplementedException();
318 }
319 }
320 const_cast<ClSubTensorHandle*>(this)->Unmap();
321 }
322
323 // Only used for testing
324 void CopyInFrom(const void* memory) override
325 {
326 this->Map(true);
327 switch(this->GetDataType())
328 {
329 case arm_compute::DataType::F32:
330 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
331 this->GetTensor());
332 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000333 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000334 case arm_compute::DataType::QASYMM8:
335 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
336 this->GetTensor());
337 break;
338 case arm_compute::DataType::F16:
339 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
340 this->GetTensor());
341 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100342 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000343 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
344 case arm_compute::DataType::QASYMM8_SIGNED:
345 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
346 this->GetTensor());
347 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100348 case arm_compute::DataType::S16:
349 case arm_compute::DataType::QSYMM16:
350 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
351 this->GetTensor());
352 break;
James Conroy2dc05722019-09-19 17:00:31 +0100353 case arm_compute::DataType::S32:
354 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
355 this->GetTensor());
356 break;
David Beck09e2f272018-10-30 11:38:41 +0000357 default:
358 {
359 throw armnn::UnimplementedException();
360 }
361 }
362 this->Unmap();
363 }
364
telsoa01c577f2c2018-08-31 09:22:23 +0100365 mutable arm_compute::CLSubTensor m_Tensor;
366 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000367};
368
Mike Kelly4cc341c2023-07-07 15:43:06 +0100369/** ClTensorDecorator wraps an existing CL tensor allowing us to override the TensorInfo for it */
370class ClTensorDecorator : public arm_compute::ICLTensor
371{
372public:
373 ClTensorDecorator();
374
375 ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& info);
376
377 ~ClTensorDecorator() = default;
378
379 ClTensorDecorator(const ClTensorDecorator&) = delete;
380
381 ClTensorDecorator& operator=(const ClTensorDecorator&) = delete;
382
383 ClTensorDecorator(ClTensorDecorator&&) = default;
384
385 ClTensorDecorator& operator=(ClTensorDecorator&&) = default;
386
387 arm_compute::ICLTensor* parent();
388
389 void map(bool blocking = true);
390 using arm_compute::ICLTensor::map;
391
392 void unmap();
393 using arm_compute::ICLTensor::unmap;
394
395 virtual arm_compute::ITensorInfo* info() const override;
396 virtual arm_compute::ITensorInfo* info() override;
397 const cl::Buffer& cl_buffer() const override;
398 arm_compute::CLQuantization quantization() const override;
399
400protected:
401 // Inherited methods overridden:
402 uint8_t* do_map(cl::CommandQueue& q, bool blocking) override;
403 void do_unmap(cl::CommandQueue& q) override;
404
405private:
406 arm_compute::ICLTensor* m_Original;
407 mutable arm_compute::TensorInfo m_TensorInfo;
408};
409
410class ClTensorHandleDecorator : public IClTensorHandle
411{
412public:
413 ClTensorHandleDecorator(IClTensorHandle* parent, const TensorInfo& info)
414 : m_Tensor(&parent->GetTensor(), info)
415 {
416 m_OriginalHandle = parent;
417 }
418
419 arm_compute::ICLTensor& GetTensor() override { return m_Tensor; }
420 arm_compute::ICLTensor const& GetTensor() const override { return m_Tensor; }
421
422 virtual void Allocate() override {}
423 virtual void Manage() override {}
424
425 virtual const void* Map(bool blocking = true) const override
426 {
427 m_Tensor.map(blocking);
428 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
429 }
430
431 virtual void Unmap() const override
432 {
433 m_Tensor.unmap();
434 }
435
436 virtual ITensorHandle* GetParent() const override { return nullptr; }
437
438 virtual arm_compute::DataType GetDataType() const override
439 {
440 return m_Tensor.info()->data_type();
441 }
442
443 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
444
445 TensorShape GetStrides() const override
446 {
447 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
448 }
449
450 TensorShape GetShape() const override
451 {
452 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
453 }
454
455private:
456 // Only used for testing
457 void CopyOutTo(void* memory) const override
458 {
459 const_cast<ClTensorHandleDecorator*>(this)->Map(true);
460 switch(this->GetDataType())
461 {
462 case arm_compute::DataType::F32:
463 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
464 static_cast<float*>(memory));
465 break;
466 case arm_compute::DataType::U8:
467 case arm_compute::DataType::QASYMM8:
468 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
469 static_cast<uint8_t*>(memory));
470 break;
471 case arm_compute::DataType::F16:
472 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
473 static_cast<armnn::Half*>(memory));
474 break;
475 case arm_compute::DataType::QSYMM8:
476 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
477 case arm_compute::DataType::QASYMM8_SIGNED:
478 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
479 static_cast<int8_t*>(memory));
480 break;
481 case arm_compute::DataType::S16:
482 case arm_compute::DataType::QSYMM16:
483 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
484 static_cast<int16_t*>(memory));
485 break;
486 case arm_compute::DataType::S32:
487 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
488 static_cast<int32_t*>(memory));
489 break;
490 default:
491 {
492 throw armnn::UnimplementedException();
493 }
494 }
495 const_cast<ClTensorHandleDecorator*>(this)->Unmap();
496 }
497
498 // Only used for testing
499 void CopyInFrom(const void* memory) override
500 {
501 this->Map(true);
502 switch(this->GetDataType())
503 {
504 case arm_compute::DataType::F32:
505 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
506 this->GetTensor());
507 break;
508 case arm_compute::DataType::U8:
509 case arm_compute::DataType::QASYMM8:
510 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
511 this->GetTensor());
512 break;
513 case arm_compute::DataType::F16:
514 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
515 this->GetTensor());
516 break;
517 case arm_compute::DataType::QSYMM8:
518 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
519 case arm_compute::DataType::QASYMM8_SIGNED:
520 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
521 this->GetTensor());
522 break;
523 case arm_compute::DataType::S16:
524 case arm_compute::DataType::QSYMM16:
525 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
526 this->GetTensor());
527 break;
528 case arm_compute::DataType::S32:
529 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
530 this->GetTensor());
531 break;
532 default:
533 {
534 throw armnn::UnimplementedException();
535 }
536 }
537 this->Unmap();
538 }
539
540 mutable ClTensorDecorator m_Tensor;
541 IClTensorHandle* m_OriginalHandle = nullptr;
542};
543
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000544} // namespace armnn