blob: 5014e7527b4b348ba412267a9c4b0a0e9e4eeb5c [file] [log] [blame]
Sadik Armagan1153d1e2020-04-01 15:09:39 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ConversionUtils_1_2.hpp"
9
10using Half = half_float::half;
11
12namespace armnn_driver
13{
14
15using namespace armnn;
16using namespace android::nn;
17
18template<typename HalPolicy,
19 typename HalOperation = typename HalPolicy::Operation,
20 typename HalModel = typename HalPolicy::Model>
21bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
22{
23 using HalOperandType = typename HalPolicy::OperandType;
24
25 LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
26 if (!input0.IsValid())
27 {
28 return Fail("%s: Operation has invalid inputs", __func__);
29 }
30
31 // Determine data type of input tensor
32 HalOperandType inputType;
33 if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
34 {
35 return Fail("%s: Operation has invalid inputs", __func__);
36 }
37
38 ActivationDescriptor desc;
39 desc.m_Function = ActivationFunction::Elu;
40
41 // Read alpha
42 if (inputType == HalOperandType::TENSOR_FLOAT16)
43 {
44 Half alpha;
45
46 if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
47 {
48 return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
49 }
50
51 desc.m_A = static_cast<float>(alpha);
52 }
53 else if (inputType == HalOperandType::TENSOR_FLOAT32)
54 {
55 if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
56 {
57 return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
58 }
59 }
60 else
61 {
62 return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
63 }
64
65 return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
66}
67
68} // armnn_driver namespace