Laurent Carlier | 749294b | 2020-06-01 09:03:17 +0100 | [diff] [blame] | 1 | // |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
| 5 | |
| 6 | #pragma once |
| 7 | |
Colm Donelan | 0c47974 | 2021-12-10 12:43:54 +0000 | [diff] [blame] | 8 | #include <armnn/backends/TensorHandle.hpp> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 9 | |
| 10 | #include <armnn/Tensor.hpp> |
| 11 | #include <armnn/Types.hpp> |
Jan Eilers | bb446e5 | 2020-04-02 13:56:54 +0100 | [diff] [blame] | 12 | #include <armnn/utility/PolymorphicDowncast.hpp> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 13 | |
Matthew Bentham | 4cefc41 | 2019-06-18 16:14:34 +0100 | [diff] [blame] | 14 | #include <reference/RefTensorHandle.hpp> |
| 15 | |
Narumol Prangnawarat | 7ddbbae | 2020-03-13 10:26:05 +0000 | [diff] [blame] | 16 | #include <BFloat16.hpp> |
Matthew Bentham | 4cefc41 | 2019-06-18 16:14:34 +0100 | [diff] [blame] | 17 | #include <Half.hpp> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 18 | |
| 19 | namespace armnn |
| 20 | { |
| 21 | |
| 22 | //////////////////////////////////////////// |
| 23 | /// float32 helpers |
| 24 | //////////////////////////////////////////// |
| 25 | |
| 26 | inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle) |
| 27 | { |
Matthew Bentham | 4cefc41 | 2019-06-18 16:14:34 +0100 | [diff] [blame] | 28 | // We know that reference workloads use RefTensorHandles for inputs and outputs |
| 29 | const RefTensorHandle* refTensorHandle = |
Jan Eilers | bb446e5 | 2020-04-02 13:56:54 +0100 | [diff] [blame] | 30 | PolymorphicDowncast<const RefTensorHandle*>(tensorHandle); |
Matthew Bentham | 4cefc41 | 2019-06-18 16:14:34 +0100 | [diff] [blame] | 31 | return refTensorHandle->GetTensorInfo(); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 32 | } |
| 33 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 34 | template <typename DataType, typename PayloadType> |
| 35 | const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data) |
| 36 | { |
| 37 | const ITensorHandle* tensorHandle = data.m_Inputs[idx]; |
Matthew Bentham | 4cefc41 | 2019-06-18 16:14:34 +0100 | [diff] [blame] | 38 | return reinterpret_cast<const DataType*>(tensorHandle->Map()); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 39 | } |
| 40 | |
| 41 | template <typename DataType, typename PayloadType> |
| 42 | DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data) |
| 43 | { |
Matthew Bentham | 4cefc41 | 2019-06-18 16:14:34 +0100 | [diff] [blame] | 44 | ITensorHandle* tensorHandle = data.m_Outputs[idx]; |
| 45 | return reinterpret_cast<DataType*>(tensorHandle->Map()); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 46 | } |
| 47 | |
Finn Williams | 0109794 | 2021-04-26 12:06:34 +0100 | [diff] [blame] | 48 | template <typename DataType> |
| 49 | DataType* GetOutputTensorData(ITensorHandle* tensorHandle) |
| 50 | { |
| 51 | return reinterpret_cast<DataType*>(tensorHandle->Map()); |
| 52 | } |
| 53 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 54 | template <typename PayloadType> |
| 55 | const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data) |
| 56 | { |
| 57 | return GetInputTensorData<float>(idx, data); |
| 58 | } |
| 59 | |
| 60 | template <typename PayloadType> |
| 61 | float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data) |
| 62 | { |
| 63 | return GetOutputTensorData<float>(idx, data); |
| 64 | } |
| 65 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 66 | template <typename PayloadType> |
| 67 | const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data) |
| 68 | { |
| 69 | return GetInputTensorData<Half>(idx, data); |
| 70 | } |
| 71 | |
| 72 | template <typename PayloadType> |
| 73 | Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data) |
| 74 | { |
| 75 | return GetOutputTensorData<Half>(idx, data); |
| 76 | } |
| 77 | |
Narumol Prangnawarat | 7ddbbae | 2020-03-13 10:26:05 +0000 | [diff] [blame] | 78 | template <typename PayloadType> |
| 79 | const BFloat16* GetInputTensorDataBFloat16(unsigned int idx, const PayloadType& data) |
| 80 | { |
| 81 | return GetInputTensorData<BFloat16>(idx, data); |
| 82 | } |
| 83 | |
Narumol Prangnawarat | ea54a01 | 2020-03-16 16:36:10 +0000 | [diff] [blame] | 84 | template <typename PayloadType> |
| 85 | BFloat16* GetOutputTensorDataBFloat16(unsigned int idx, const PayloadType& data) |
| 86 | { |
| 87 | return GetOutputTensorData<BFloat16>(idx, data); |
| 88 | } |
| 89 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 90 | //////////////////////////////////////////// |
| 91 | /// u8 helpers |
| 92 | //////////////////////////////////////////// |
| 93 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 94 | template<typename T> |
| 95 | std::vector<float> Dequantize(const T* quant, const TensorInfo& info) |
| 96 | { |
| 97 | std::vector<float> ret(info.GetNumElements()); |
| 98 | for (size_t i = 0; i < info.GetNumElements(); i++) |
| 99 | { |
| 100 | ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset()); |
| 101 | } |
| 102 | return ret; |
| 103 | } |
| 104 | |
Nattapat Chaimanowong | 8a54ac0 | 2019-03-29 15:25:04 +0000 | [diff] [blame] | 105 | template<typename T> |
| 106 | inline void Dequantize(const T* inputData, float* outputData, const TensorInfo& info) |
| 107 | { |
| 108 | for (unsigned int i = 0; i < info.GetNumElements(); i++) |
| 109 | { |
| 110 | outputData[i] = Dequantize<T>(inputData[i], info.GetQuantizationScale(), info.GetQuantizationOffset()); |
| 111 | } |
| 112 | } |
| 113 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 114 | inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info) |
| 115 | { |
| 116 | for (size_t i = 0; i < info.GetNumElements(); i++) |
| 117 | { |
| 118 | quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset()); |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | } //namespace armnn |