blob: 66b3d2685a6a350ae696b7bab2376257ce51dbc4 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
#include "Optimization.hpp"

#include <armnnUtils/FloatingPointConverter.hpp>

#include <backendsCommon/TensorHandle.hpp>

#include <armnn/utility/IgnoreUnused.hpp>

#include <BFloat16.hpp>
#include <Half.hpp>

#include <memory>
#include <vector>
arovir01616e7752018-10-01 17:08:59 +010018
telsoa01c577f2c2018-08-31 09:22:23 +010019namespace armnn
20{
21namespace optimizations
22{
23
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000024struct BFloat16ToFloat32
25{
James Conroy1f58f032021-04-27 17:13:27 +010026 static void Func(std::shared_ptr<ConstTensorHandle>& handle)
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000027 {
28 const TensorInfo& info = handle->GetTensorInfo();
29
30 if (info.GetDataType() == DataType::BFloat16)
31 {
32 std::vector<float> newValues(info.GetNumElements());
33
Finn Williams4422cec2021-03-22 17:51:06 +000034 armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(handle->GetConstTensor<BFloat16>(),
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000035 info.GetNumElements(),
36 newValues.data());
37
38 TensorInfo newInfo(info.GetShape(), DataType::Float32);
39 ConstTensor newInput(newInfo, newValues);
James Conroy1f58f032021-04-27 17:13:27 +010040 handle.reset(new ScopedTensorHandle(newInput));
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000041 }
42 }
43};
44
telsoa01c577f2c2018-08-31 09:22:23 +010045struct Float16ToFloat32
46{
James Conroy1f58f032021-04-27 17:13:27 +010047 static void Func(std::shared_ptr<ConstTensorHandle>& handle)
telsoa01c577f2c2018-08-31 09:22:23 +010048 {
49 const TensorInfo& info = handle->GetTensorInfo();
50
51 if (info.GetDataType() == DataType::Float16)
52 {
53 std::vector<float> newValues(info.GetNumElements());
54
Finn Williams4422cec2021-03-22 17:51:06 +000055 armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetConstTensor<Half>(),
telsoa01c577f2c2018-08-31 09:22:23 +010056 info.GetNumElements(),
57 newValues.data());
58
59 TensorInfo newInfo(info.GetShape(), DataType::Float32);
60 ConstTensor newInput(newInfo, newValues);
James Conroy1f58f032021-04-27 17:13:27 +010061 handle.reset(new ScopedTensorHandle(newInput));
telsoa01c577f2c2018-08-31 09:22:23 +010062 }
63 }
64};
65
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000066struct Float32ToBFloat16
67{
James Conroy1f58f032021-04-27 17:13:27 +010068 static void Func(std::shared_ptr<ConstTensorHandle>& handle)
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000069 {
70 const TensorInfo& info = handle->GetTensorInfo();
71
72 if (info.GetDataType() == DataType::Float32)
73 {
74 std::vector<BFloat16> newValues(info.GetNumElements());
75
Finn Williams4422cec2021-03-22 17:51:06 +000076 armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(handle->GetConstTensor<float>(),
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000077 info.GetNumElements(),
78 newValues.data());
79
80 TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
81 ConstTensor newInput(newInfo, newValues);
James Conroy1f58f032021-04-27 17:13:27 +010082 handle.reset(new ScopedTensorHandle(newInput));
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +000083 }
84 }
85};
86
telsoa01c577f2c2018-08-31 09:22:23 +010087struct Float32ToFloat16
88{
James Conroy1f58f032021-04-27 17:13:27 +010089 static void Func(std::shared_ptr<ConstTensorHandle>& handle)
telsoa01c577f2c2018-08-31 09:22:23 +010090 {
91 const TensorInfo& info = handle->GetTensorInfo();
92
93 if (info.GetDataType() == DataType::Float32)
94 {
95 std::vector<Half> newValues(info.GetNumElements());
96
Finn Williams4422cec2021-03-22 17:51:06 +000097 armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(),
telsoa01c577f2c2018-08-31 09:22:23 +010098 info.GetNumElements(),
99 newValues.data());
100
101 TensorInfo newInfo(info.GetShape(), DataType::Float16);
102 ConstTensor newInput(newInfo, newValues);
James Conroy1f58f032021-04-27 17:13:27 +0100103 handle.reset(new ScopedTensorHandle(newInput));
telsoa01c577f2c2018-08-31 09:22:23 +0100104 }
105 }
106};
107
108template<typename Converter, typename Predicate>
109class ConvertConstants : public Optimization
110{
111public:
112 ConvertConstants() = default;
113 ConvertConstants(const ConvertConstants&) = default;
114 virtual ~ConvertConstants() = default;
115
116 void Run(Graph& graph, Layer& layer) const override
117 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000118 IgnoreUnused(graph);
telsoa01c577f2c2018-08-31 09:22:23 +0100119 if (Predicate::Test(layer))
120 {
121 layer.OperateOnConstantTensors(Converter::Func);
122 }
123 }
124protected:
125};
126
127struct IsFloat32Layer
128{
129 static bool Test(const Layer& layer)
130 {
131 return layer.GetDataType() == DataType::Float32;
132 }
133};
134
135struct IsFloat16Layer
136{
137 static bool Test(const Layer& layer)
138 {
139 return layer.GetDataType() == DataType::Float16;
140 }
141};
142
Narumol Prangnawaratbc7ffb52020-03-20 15:01:01 +0000143struct IsBFloat16Layer
144{
145 static bool Test(const Layer& layer)
146 {
147 return layer.GetDataType() == DataType::BFloat16;
148 }
149};
150
151using ConvertConstantsBFloatToFloat = ConvertConstants<BFloat16ToFloat32, IsFloat32Layer>;
152using ConvertConstantsFloatToBFloat = ConvertConstants<Float32ToBFloat16, IsBFloat16Layer>;
153
telsoa01c577f2c2018-08-31 09:22:23 +0100154using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
155using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;
156
157} //namespace optimizations
158} //namespace armnn