blob: c3260c81422a80a05a683e91c4517d7c5a64fefa [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <backendsCommon/CpuTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
10#include <armnn/Tensor.hpp>
11#include <armnn/Types.hpp>
12
Matthew Bentham4cefc412019-06-18 16:14:34 +010013#include <reference/RefTensorHandle.hpp>
14
15#include <Half.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016#include <boost/polymorphic_cast.hpp>
17
18namespace armnn
19{
20
////////////////////////////////////////////
/// Tensor-handle and tensor-data access helpers (float32 / Half)
////////////////////////////////////////////
24
25inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
26{
Matthew Bentham4cefc412019-06-18 16:14:34 +010027 // We know that reference workloads use RefTensorHandles for inputs and outputs
28 const RefTensorHandle* refTensorHandle =
29 boost::polymorphic_downcast<const RefTensorHandle*>(tensorHandle);
30 return refTensorHandle->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +000031}
32
telsoa014fcda012018-03-09 14:13:49 +000033template <typename DataType, typename PayloadType>
34const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data)
35{
36 const ITensorHandle* tensorHandle = data.m_Inputs[idx];
Matthew Bentham4cefc412019-06-18 16:14:34 +010037 return reinterpret_cast<const DataType*>(tensorHandle->Map());
telsoa014fcda012018-03-09 14:13:49 +000038}
39
40template <typename DataType, typename PayloadType>
41DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data)
42{
Matthew Bentham4cefc412019-06-18 16:14:34 +010043 ITensorHandle* tensorHandle = data.m_Outputs[idx];
44 return reinterpret_cast<DataType*>(tensorHandle->Map());
telsoa014fcda012018-03-09 14:13:49 +000045}
46
47template <typename PayloadType>
48const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data)
49{
50 return GetInputTensorData<float>(idx, data);
51}
52
53template <typename PayloadType>
54float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data)
55{
56 return GetOutputTensorData<float>(idx, data);
57}
58
telsoa01c577f2c2018-08-31 09:22:23 +010059template <typename PayloadType>
60const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data)
61{
62 return GetInputTensorData<Half>(idx, data);
63}
64
65template <typename PayloadType>
66Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data)
67{
68 return GetOutputTensorData<Half>(idx, data);
69}
70
////////////////////////////////////////////
/// Quantization / dequantization helpers
////////////////////////////////////////////
74
telsoa014fcda012018-03-09 14:13:49 +000075template<typename T>
76std::vector<float> Dequantize(const T* quant, const TensorInfo& info)
77{
78 std::vector<float> ret(info.GetNumElements());
79 for (size_t i = 0; i < info.GetNumElements(); i++)
80 {
81 ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
82 }
83 return ret;
84}
85
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +000086template<typename T>
87inline void Dequantize(const T* inputData, float* outputData, const TensorInfo& info)
88{
89 for (unsigned int i = 0; i < info.GetNumElements(); i++)
90 {
91 outputData[i] = Dequantize<T>(inputData[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
92 }
93}
94
telsoa014fcda012018-03-09 14:13:49 +000095inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info)
96{
97 for (size_t i = 0; i < info.GetNumElements(); i++)
98 {
99 quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
100 }
101}
102
103} //namespace armnn