blob: 878280161a7967d7435bc1acc066998ff2d44aca [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "NeonWorkloadUtils.hpp"
6#include "backends/ArmComputeTensorUtils.hpp"
7#include "backends/ArmComputeUtils.hpp"
8#include "backends/CpuTensorHandle.hpp"
9#include "backends/NeonTensorHandle.hpp"
10
11#include "armnn/Utils.hpp"
12#include "armnn/Exceptions.hpp"
13
telsoa014fcda012018-03-09 14:13:49 +000014#include <cstring>
15#include <boost/assert.hpp>
16#include <boost/cast.hpp>
17#include <boost/format.hpp>
18
19#include "Profiling.hpp"
20
21#include "NeonLayerSupport.hpp"
David Beck10b4dfd2018-09-19 12:03:20 +010022#include "armnn/Types.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010023#include "Half.hpp"
telsoa014fcda012018-03-09 14:13:49 +000024
25using namespace armnn::armcomputetensorutils;
26
27namespace armnn
28{
29
telsoa01c577f2c2018-08-31 09:22:23 +010030// Allocates a tensor and copy the contents in data to the tensor contents.
telsoa014fcda012018-03-09 14:13:49 +000031template<typename T>
32void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const T* data)
33{
34 InitialiseArmComputeTensorEmpty(tensor);
35 CopyArmComputeITensorData(data, tensor);
36}
37
telsoa01c577f2c2018-08-31 09:22:23 +010038template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const Half* data);
telsoa014fcda012018-03-09 14:13:49 +000039template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const float* data);
40template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const uint8_t* data);
41template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const int32_t* data);
42
telsoa01c577f2c2018-08-31 09:22:23 +010043void InitializeArmComputeTensorDataForFloatTypes(arm_compute::Tensor& tensor,
44 const ConstCpuTensorHandle* handle)
45{
46 BOOST_ASSERT(handle);
47 switch(handle->GetTensorInfo().GetDataType())
48 {
49 case DataType::Float16:
50 InitialiseArmComputeTensorData(tensor, handle->GetConstTensor<Half>());
51 break;
52 case DataType::Float32:
53 InitialiseArmComputeTensorData(tensor, handle->GetConstTensor<float>());
54 break;
55 default:
56 BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
57 }
58};
59
telsoa014fcda012018-03-09 14:13:49 +000060} //namespace armnn