//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <backendsCommon/CpuTensorHandle.hpp>

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <Half.hpp>

#include <boost/polymorphic_cast.hpp>

namespace armnn
{

////////////////////////////////////////////
/// float32 helpers
////////////////////////////////////////////

inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
{
    // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate.
    const ConstCpuTensorHandle* cpuTensorHandle =
        boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle);
    return cpuTensorHandle->GetTensorInfo();
}

template <typename DataType>
inline const DataType* GetConstCpuData(const ITensorHandle* tensorHandle)
{
    // We know that reference workloads use (Const)CpuTensorHandles only, so this cast is legitimate.
    const ConstCpuTensorHandle* cpuTensorHandle =
        boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle);
    return cpuTensorHandle->GetConstTensor<DataType>();
}

template <typename DataType>
inline DataType* GetCpuData(const ITensorHandle* tensorHandle)
{
    // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate.
    const CpuTensorHandle* cpuTensorHandle = boost::polymorphic_downcast<const CpuTensorHandle*>(tensorHandle);
    return cpuTensorHandle->GetTensor<DataType>();
}

template <typename DataType, typename PayloadType>
const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data)
{
    const ITensorHandle* tensorHandle = data.m_Inputs[idx];
    return GetConstCpuData<DataType>(tensorHandle);
}

template <typename DataType, typename PayloadType>
DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data)
{
    const ITensorHandle* tensorHandle = data.m_Outputs[idx];
    return GetCpuData<DataType>(tensorHandle);
}

template <typename PayloadType>
const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data)
{
    return GetInputTensorData<float>(idx, data);
}

template <typename PayloadType>
float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data)
{
    return GetOutputTensorData<float>(idx, data);
}
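
// Illustrative sketch (not part of this header's API): a typical float32 reference
// workload's Execute() combines these accessors with GetTensorInfo, e.g.
//
//     const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
//     const float* inputData      = GetInputTensorDataFloat(0, data);
//     float*       outputData     = GetOutputTensorDataFloat(0, data);
//     for (unsigned int i = 0; i < inputInfo.GetNumElements(); ++i)
//     {
//         outputData[i] = inputData[i]; // element-wise copy as a placeholder computation
//     }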
73
telsoa01c577f2c2018-08-31 09:22:23 +010074template <typename PayloadType>
75const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data)
76{
77 return GetInputTensorData<Half>(idx, data);
78}
79
80template <typename PayloadType>
81Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data)
82{
83 return GetOutputTensorData<Half>(idx, data);
84}
85
telsoa014fcda012018-03-09 14:13:49 +000086////////////////////////////////////////////
87/// u8 helpers
88////////////////////////////////////////////
89
90inline const uint8_t* GetConstCpuU8Data(const ITensorHandle* tensorHandle)
91{
92 // We know that reference workloads use (Const)CpuTensorHandles only, so this cast is legitimate.
93 const ConstCpuTensorHandle* cpuTensorHandle =
94 boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle);
95 return cpuTensorHandle->GetConstTensor<uint8_t>();
96};
97
98inline uint8_t* GetCpuU8Data(const ITensorHandle* tensorHandle)
99{
100 // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate.
101 const CpuTensorHandle* cpuTensorHandle = boost::polymorphic_downcast<const CpuTensorHandle*>(tensorHandle);
102 return cpuTensorHandle->GetTensor<uint8_t>();
103};
104
105template <typename PayloadType>
106const uint8_t* GetInputTensorDataU8(unsigned int idx, const PayloadType& data)
107{
108 const ITensorHandle* tensorHandle = data.m_Inputs[idx];
109 return GetConstCpuU8Data(tensorHandle);
110}
111
112template <typename PayloadType>
113uint8_t* GetOutputTensorDataU8(unsigned int idx, const PayloadType& data)
114{
115 const ITensorHandle* tensorHandle = data.m_Outputs[idx];
116 return GetCpuU8Data(tensorHandle);
117}
118
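// Dequantizes a buffer of quantized values to float using the tensor's affine
// quantization parameters: real = scale * (quantized - offset).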
template<typename T>
std::vector<float> Dequantize(const T* quant, const TensorInfo& info)
{
    std::vector<float> ret(info.GetNumElements());
    for (size_t i = 0; i < info.GetNumElements(); i++)
    {
        ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
    }
    return ret;
}

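// Quantizes a buffer of float values to uint8_t using the tensor's affine
// quantization parameters; armnn::Quantize rounds and clamps to the representable range.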
inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info)
{
    for (size_t i = 0; i < info.GetNumElements(); i++)
    {
        quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
    }
}
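
// Illustrative sketch (not a specific ArmNN workload): u8 reference workloads
// typically dequantize their inputs to float, run the float32 implementation,
// then requantize the result, e.g.
//
//     const TensorInfo& inputInfo  = GetTensorInfo(data.m_Inputs[0]);
//     const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);
//     std::vector<float> dequantizedInput = Dequantize(GetInputTensorDataU8(0, data), inputInfo);
//     std::vector<float> results(outputInfo.GetNumElements());
//     // ... compute into 'results' using the float32 code path ...
//     Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);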

} // namespace armnn