blob: 7168effe0cfbbfc04688eea46d4955b120b2df5e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "Optimization.hpp"
9#include "backends/CpuTensorHandle.hpp"
10#include "Half.hpp"
11#include "FloatingPointConverter.hpp"
12
13namespace armnn
14{
15namespace optimizations
16{
17
18struct Float16ToFloat32
19{
20 static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
21 {
22 const TensorInfo& info = handle->GetTensorInfo();
23
24 if (info.GetDataType() == DataType::Float16)
25 {
26 std::vector<float> newValues(info.GetNumElements());
27
28 armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetTensor<Half>(),
29 info.GetNumElements(),
30 newValues.data());
31
32 TensorInfo newInfo(info.GetShape(), DataType::Float32);
33 ConstTensor newInput(newInfo, newValues);
34 handle.reset(new ScopedCpuTensorHandle(newInput));
35 }
36 }
37};
38
39struct Float32ToFloat16
40{
41 static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
42 {
43 const TensorInfo& info = handle->GetTensorInfo();
44
45 if (info.GetDataType() == DataType::Float32)
46 {
47 std::vector<Half> newValues(info.GetNumElements());
48
49 armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetTensor<float>(),
50 info.GetNumElements(),
51 newValues.data());
52
53 TensorInfo newInfo(info.GetShape(), DataType::Float16);
54 ConstTensor newInput(newInfo, newValues);
55 handle.reset(new ScopedCpuTensorHandle(newInput));
56 }
57 }
58};
59
60template<typename Converter, typename Predicate>
61class ConvertConstants : public Optimization
62{
63public:
64 ConvertConstants() = default;
65 ConvertConstants(const ConvertConstants&) = default;
66 virtual ~ConvertConstants() = default;
67
68 void Run(Graph& graph, Layer& layer) const override
69 {
70 if (Predicate::Test(layer))
71 {
72 layer.OperateOnConstantTensors(Converter::Func);
73 }
74 }
75protected:
76};
77
78struct IsFloat32Layer
79{
80 static bool Test(const Layer& layer)
81 {
82 return layer.GetDataType() == DataType::Float32;
83 }
84};
85
86struct IsFloat16Layer
87{
88 static bool Test(const Layer& layer)
89 {
90 return layer.GetDataType() == DataType::Float16;
91 }
92};
93
// Converts Float16 constant tensors to Float32, but only on layers whose own data type is Float32.
using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
// Converts Float32 constant tensors to Float16, but only on layers whose own data type is Float16.
using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;
96
97} //namespace optimizations
98} //namespace armnn