blob: 32dacdfc618fa6c867f621267aecad39f59cc1a6 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <Half.hpp>
Matthew Bentham14e46692018-09-20 15:35:30 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <cl/OpenClTimer.hpp>
11#include <backendsCommon/CpuTensorHandle.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010012
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +010013#include <arm_compute/runtime/CL/CLFunctions.h>
14
15#include <sstream>
16
/// Profiling helper for the CL backend.
///
/// Forwards to ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS (defined elsewhere), tagging the
/// event with the GpuAcc compute device and attaching two instruments: an OpenClTimer and a
/// WallClockTimer.
///
/// @param name Label passed through unchanged as the event name.
#define ARMNN_SCOPED_PROFILING_EVENT_CL(name)                                   \
    ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS(armnn::Compute::GpuAcc, name, \
                                                  armnn::OpenClTimer(),         \
                                                  armnn::WallClockTimer())
telsoa014fcda012018-03-09 14:13:49 +000022
23namespace armnn
24{
25
26template <typename T>
Matthew Benthamca6616c2018-09-21 15:16:53 +010027void CopyArmComputeClTensorData(arm_compute::CLTensor& dstTensor, const T* srcData)
telsoa014fcda012018-03-09 14:13:49 +000028{
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 ARMNN_SCOPED_PROFILING_EVENT_CL("MapClTensorForWriting");
telsoa014fcda012018-03-09 14:13:49 +000031 dstTensor.map(true);
32 }
33
34 {
telsoa01c577f2c2018-08-31 09:22:23 +010035 ARMNN_SCOPED_PROFILING_EVENT_CL("CopyToClTensor");
telsoa014fcda012018-03-09 14:13:49 +000036 armcomputetensorutils::CopyArmComputeITensorData<T>(srcData, dstTensor);
37 }
38
39 dstTensor.unmap();
40}
41
keidav01d74dc912018-12-10 18:16:07 +000042inline auto SetClStridedSliceData(const std::vector<int>& m_begin,
43 const std::vector<int>& m_end,
44 const std::vector<int>& m_stride)
45{
46 arm_compute::Coordinates starts;
47 arm_compute::Coordinates ends;
48 arm_compute::Coordinates strides;
49
50 unsigned int num_dims = static_cast<unsigned int>(m_begin.size());
51
52 for (unsigned int i = 0; i < num_dims; i++) {
53 unsigned int revertedIndex = num_dims - i - 1;
54
55 starts.set(i, static_cast<int>(m_begin[revertedIndex]));
56 ends.set(i, static_cast<int>(m_end[revertedIndex]));
57 strides.set(i, static_cast<int>(m_stride[revertedIndex]));
58 }
59
60 return std::make_tuple(starts, ends, strides);
61}
62
Matthew Bentham785df502018-09-21 10:29:58 +010063inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
64 const ConstCpuTensorHandle* handle)
telsoa01c577f2c2018-08-31 09:22:23 +010065{
66 BOOST_ASSERT(handle);
Matthew Benthamca6616c2018-09-21 15:16:53 +010067
68 armcomputetensorutils::InitialiseArmComputeTensorEmpty(clTensor);
telsoa01c577f2c2018-08-31 09:22:23 +010069 switch(handle->GetTensorInfo().GetDataType())
70 {
71 case DataType::Float16:
Matthew Benthamca6616c2018-09-21 15:16:53 +010072 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<armnn::Half>());
telsoa01c577f2c2018-08-31 09:22:23 +010073 break;
74 case DataType::Float32:
Matthew Benthamca6616c2018-09-21 15:16:53 +010075 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
telsoa01c577f2c2018-08-31 09:22:23 +010076 break;
Matthew Bentham785df502018-09-21 10:29:58 +010077 case DataType::QuantisedAsymm8:
Matthew Benthamca6616c2018-09-21 15:16:53 +010078 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
Matthew Bentham785df502018-09-21 10:29:58 +010079 break;
80 case DataType::Signed32:
Matthew Benthamca6616c2018-09-21 15:16:53 +010081 CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
Matthew Bentham785df502018-09-21 10:29:58 +010082 break;
telsoa01c577f2c2018-08-31 09:22:23 +010083 default:
Matthew Bentham785df502018-09-21 10:29:58 +010084 BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
telsoa01c577f2c2018-08-31 09:22:23 +010085 }
86};
87
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +010088inline RuntimeException WrapClError(const cl::Error& clError, const CheckLocation& location)
89{
90 std::stringstream message;
91 message << "CL error: " << clError.what() << ". Error code: " << clError.err();
92
93 return RuntimeException(message.str(), location);
94}
95
96inline void RunClFunction(arm_compute::IFunction& function, const CheckLocation& location)
97{
98 try
99 {
100 function.run();
101 }
102 catch (cl::Error& error)
103 {
104 throw WrapClError(error, location);
105 }
106}
107
telsoa014fcda012018-03-09 14:13:49 +0000108} //namespace armnn