blob: d3c6df50edc2d04dd219825d670b3a17a3346509 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <Half.hpp>
Matthew Bentham14e46692018-09-20 15:35:30 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <cl/OpenClTimer.hpp>
11#include <backendsCommon/CpuTensorHandle.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010012
Derek Lambertid466a542020-01-22 15:37:29 +000013#include <armnn/Utils.hpp>
14
Matthew Bentham9b3e7382020-02-05 21:39:55 +000015#include <arm_compute/runtime/CL/CLTensor.h>
16#include <arm_compute/runtime/IFunction.h>
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +010017
18#include <sstream>
19
// Expands to ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS for the GpuAcc compute device,
// forwarding `name` and passing an OpenClTimer and a WallClockTimer as the instruments.
#define ARMNN_SCOPED_PROFILING_EVENT_CL(name) \
    ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS( \
        armnn::Compute::GpuAcc, name, armnn::OpenClTimer(), armnn::WallClockTimer())
telsoa014fcda012018-03-09 14:13:49 +000025
26namespace armnn
27{
28
29template <typename T>
Matthew Benthamca6616c2018-09-21 15:16:53 +010030void CopyArmComputeClTensorData(arm_compute::CLTensor& dstTensor, const T* srcData)
telsoa014fcda012018-03-09 14:13:49 +000031{
32 {
telsoa01c577f2c2018-08-31 09:22:23 +010033 ARMNN_SCOPED_PROFILING_EVENT_CL("MapClTensorForWriting");
telsoa014fcda012018-03-09 14:13:49 +000034 dstTensor.map(true);
35 }
36
37 {
telsoa01c577f2c2018-08-31 09:22:23 +010038 ARMNN_SCOPED_PROFILING_EVENT_CL("CopyToClTensor");
telsoa014fcda012018-03-09 14:13:49 +000039 armcomputetensorutils::CopyArmComputeITensorData<T>(srcData, dstTensor);
40 }
41
42 dstTensor.unmap();
43}
44
keidav01d74dc912018-12-10 18:16:07 +000045inline auto SetClStridedSliceData(const std::vector<int>& m_begin,
46 const std::vector<int>& m_end,
47 const std::vector<int>& m_stride)
48{
49 arm_compute::Coordinates starts;
50 arm_compute::Coordinates ends;
51 arm_compute::Coordinates strides;
52
53 unsigned int num_dims = static_cast<unsigned int>(m_begin.size());
54
55 for (unsigned int i = 0; i < num_dims; i++) {
56 unsigned int revertedIndex = num_dims - i - 1;
57
58 starts.set(i, static_cast<int>(m_begin[revertedIndex]));
59 ends.set(i, static_cast<int>(m_end[revertedIndex]));
60 strides.set(i, static_cast<int>(m_stride[revertedIndex]));
61 }
62
63 return std::make_tuple(starts, ends, strides);
64}
65
Aron Virginas-Tar94c4fef2019-11-25 15:37:08 +000066inline auto SetClSliceData(const std::vector<unsigned int>& m_begin,
67 const std::vector<unsigned int>& m_size)
68{
69 // This function must translate the size vector given to an end vector
70 // expected by the ACL NESlice workload
71 arm_compute::Coordinates starts;
72 arm_compute::Coordinates ends;
73
74 unsigned int num_dims = static_cast<unsigned int>(m_begin.size());
75
76 // For strided slices, we have the relationship size = (end - begin) / stride
77 // For slice, we assume stride to be a vector of all ones, yielding the formula
78 // size = (end - begin) therefore we know end = size + begin
79 for (unsigned int i = 0; i < num_dims; i++)
80 {
81 unsigned int revertedIndex = num_dims - i - 1;
82
83 starts.set(i, static_cast<int>(m_begin[revertedIndex]));
84 ends.set(i, static_cast<int>(m_begin[revertedIndex] + m_size[revertedIndex]));
85 }
86
87 return std::make_tuple(starts, ends);
88}
89
Matthew Bentham785df502018-09-21 10:29:58 +010090inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
91 const ConstCpuTensorHandle* handle)
telsoa01c577f2c2018-08-31 09:22:23 +010092{
93 BOOST_ASSERT(handle);
Matthew Benthamca6616c2018-09-21 15:16:53 +010094
95 armcomputetensorutils::InitialiseArmComputeTensorEmpty(clTensor);
telsoa01c577f2c2018-08-31 09:22:23 +010096 switch(handle->GetTensorInfo().GetDataType())
97 {
98 case DataType::Float16:
Matthew Benthamca6616c2018-09-21 15:16:53 +010099 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<armnn::Half>());
telsoa01c577f2c2018-08-31 09:22:23 +0100100 break;
101 case DataType::Float32:
Matthew Benthamca6616c2018-09-21 15:16:53 +0100102 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
telsoa01c577f2c2018-08-31 09:22:23 +0100103 break;
Derek Lambertif90c56d2020-01-10 17:14:08 +0000104 case DataType::QAsymmU8:
Matthew Benthamca6616c2018-09-21 15:16:53 +0100105 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
Matthew Bentham785df502018-09-21 10:29:58 +0100106 break;
Derek Lambertid466a542020-01-22 15:37:29 +0000107 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis899f64f2019-11-26 16:01:18 +0000108 case DataType::QuantizedSymm8PerAxis:
Derek Lambertid466a542020-01-22 15:37:29 +0000109 ARMNN_FALLTHROUGH;
110 case DataType::QSymmS8:
Keith Davis899f64f2019-11-26 16:01:18 +0000111 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int8_t>());
112 break;
Derek Lambertid466a542020-01-22 15:37:29 +0000113 ARMNN_NO_DEPRECATE_WARN_END
Matthew Bentham785df502018-09-21 10:29:58 +0100114 case DataType::Signed32:
Matthew Benthamca6616c2018-09-21 15:16:53 +0100115 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
Matthew Bentham785df502018-09-21 10:29:58 +0100116 break;
telsoa01c577f2c2018-08-31 09:22:23 +0100117 default:
Matthew Bentham785df502018-09-21 10:29:58 +0100118 BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
telsoa01c577f2c2018-08-31 09:22:23 +0100119 }
120};
121
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +0100122inline RuntimeException WrapClError(const cl::Error& clError, const CheckLocation& location)
123{
124 std::stringstream message;
125 message << "CL error: " << clError.what() << ". Error code: " << clError.err();
126
127 return RuntimeException(message.str(), location);
128}
129
130inline void RunClFunction(arm_compute::IFunction& function, const CheckLocation& location)
131{
132 try
133 {
134 function.run();
135 }
136 catch (cl::Error& error)
137 {
138 throw WrapClError(error, location);
139 }
140}
141
telsoa014fcda012018-03-09 14:13:49 +0000142} //namespace armnn