blob: f513d28f9d751223f1f64e90a673c6741de9dc4b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ConversionUtils.hpp"
7
///
/// Helper classes
///
11
12namespace armnn_driver
13{
14
15LayerInputHandle::LayerInputHandle()
16 : m_OutputSlot(nullptr)
17 , m_Valid(false)
18{}
19
20LayerInputHandle::LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo)
21 : m_OutputSlot(outputSlot)
22 , m_Valid(valid)
23 , m_TensorInfo(tensorInfo)
24{}
25
26bool LayerInputHandle::IsValid() const
27{
28 return m_Valid;
29}
30
31void LayerInputHandle::Connect(armnn::IInputSlot& inputSlot)
32{
33 BOOST_ASSERT(IsValid());
34 if (m_OutputSlot)
35 {
36 m_OutputSlot->Connect(inputSlot);
37 }
38}
39
40const armnn::TensorInfo& LayerInputHandle::GetTensorInfo() const
41{
42 return m_TensorInfo;
43}
44
45ConstTensorPin::ConstTensorPin(bool optional)
46 : m_Optional(optional)
47{}
48
49ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo,
50 const void* valueStart,
51 uint32_t numBytes,
52 const armnn::PermutationVector& mappings)
53{
54 boost::ignore_unused(numBytes);
55 assert(tensorInfo.GetNumBytes() == numBytes);
56
57 const bool needsSwizzling = (mappings.GetSize() > 0);
58 if (needsSwizzling)
59 {
60 m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
61 SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);
62
63 m_ConstTensor = armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, mappings), m_SwizzledTensorData.data());
64 }
65 else
66 {
67 m_ConstTensor = armnn::ConstTensor(tensorInfo, valueStart);
68 }
69}
70
71bool ConstTensorPin::IsValid() const
72{
73 return m_ConstTensor.GetMemoryArea() != nullptr;
74}
75
76bool ConstTensorPin::IsOptional() const
77{
78 return m_Optional;
79}
80
81const armnn::ConstTensor& ConstTensorPin::GetConstTensor() const
82{
83 return m_ConstTensor;
84}
85
86const armnn::ConstTensor* ConstTensorPin::GetConstTensorPtr() const
87{
88 if (IsValid() && m_ConstTensor.GetNumElements() > 0)
89 {
90 return &m_ConstTensor;
91 }
92 // tensor is either invalid, or has no elements (indicating an optional tensor that was not provided)
93 return nullptr;
94}
95
///
/// Utility functions
///
99
100armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
101 ActivationFn activation,
102 armnn::IConnectableLayer* prevLayer,
103 ConversionData& data)
104{
105 BOOST_ASSERT(prevLayer->GetNumOutputSlots() == 1);
106
107 prevLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
108
109 armnn::IConnectableLayer* activationLayer = prevLayer;
110
111 if (activation != ActivationFn::kActivationNone)
112 {
113 armnn::ActivationDescriptor activationDesc;
114 switch (activation)
115 {
116 case ActivationFn::kActivationRelu:
117 {
118 activationDesc.m_Function = armnn::ActivationFunction::ReLu;
119 break;
120 }
121 case ActivationFn::kActivationRelu1:
122 {
123 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
124 activationDesc.m_A = 1.0f;
125 activationDesc.m_B = -1.0f;
126 break;
127 }
128 case ActivationFn::kActivationRelu6:
129 {
130 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
131 activationDesc.m_A = 6.0f;
132 break;
133 }
134 case ActivationFn::kActivationSigmoid:
135 {
136 activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
137 break;
138 }
139 case ActivationFn::kActivationTanh:
140 {
141 activationDesc.m_Function = armnn::ActivationFunction::TanH;
142 activationDesc.m_A = 1.0f;
143 activationDesc.m_B = 1.0f;
144 break;
145 }
146 default:
147 {
148 Fail("%s: Invalid activation enum value %i", __func__, activation);
149 return nullptr;
150 }
151 }
152
Ferran Balaguerd30093c2019-07-09 17:04:47 +0100153 bool isSupported = false;
154 FORWARD_LAYER_SUPPORT_FUNC(__func__,
155 IsActivationSupported,
156 data.m_Backends,
157 isSupported,
158 prevLayer->GetOutputSlot(0).GetTensorInfo(),
159 tensorInfo,
160 activationDesc);
161 if (!isSupported)
arovir01b0717b52018-09-05 17:03:25 +0100162 {
163 return nullptr;
164 }
165
166 activationLayer = data.m_Network->AddActivationLayer(activationDesc);
167
168 prevLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
169 activationLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
170 }
171
172 return activationLayer;
173}
174
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +0100175} // namespace armnn_driver