blob: 56e72a3cb234cdb6684b5914e3f0320d826249ae [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ConversionUtils.hpp"
Mike Kelly4a956582020-02-28 10:32:09 +00007#include <armnnUtils/Permute.hpp>
arovir01b0717b52018-09-05 17:03:25 +01008
///
/// Helper classes
///
12
13namespace armnn_driver
14{
15
16LayerInputHandle::LayerInputHandle()
17 : m_OutputSlot(nullptr)
18 , m_Valid(false)
19{}
20
21LayerInputHandle::LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo)
22 : m_OutputSlot(outputSlot)
23 , m_Valid(valid)
24 , m_TensorInfo(tensorInfo)
25{}
26
27bool LayerInputHandle::IsValid() const
28{
29 return m_Valid;
30}
31
32void LayerInputHandle::Connect(armnn::IInputSlot& inputSlot)
33{
34 BOOST_ASSERT(IsValid());
35 if (m_OutputSlot)
36 {
37 m_OutputSlot->Connect(inputSlot);
38 }
39}
40
41const armnn::TensorInfo& LayerInputHandle::GetTensorInfo() const
42{
43 return m_TensorInfo;
44}
45
46ConstTensorPin::ConstTensorPin(bool optional)
47 : m_Optional(optional)
48{}
49
50ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo,
51 const void* valueStart,
52 uint32_t numBytes,
53 const armnn::PermutationVector& mappings)
54{
55 boost::ignore_unused(numBytes);
56 assert(tensorInfo.GetNumBytes() == numBytes);
57
58 const bool needsSwizzling = (mappings.GetSize() > 0);
59 if (needsSwizzling)
60 {
61 m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
62 SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);
63
64 m_ConstTensor = armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, mappings), m_SwizzledTensorData.data());
65 }
66 else
67 {
68 m_ConstTensor = armnn::ConstTensor(tensorInfo, valueStart);
69 }
70}
71
72bool ConstTensorPin::IsValid() const
73{
74 return m_ConstTensor.GetMemoryArea() != nullptr;
75}
76
77bool ConstTensorPin::IsOptional() const
78{
79 return m_Optional;
80}
81
82const armnn::ConstTensor& ConstTensorPin::GetConstTensor() const
83{
84 return m_ConstTensor;
85}
86
87const armnn::ConstTensor* ConstTensorPin::GetConstTensorPtr() const
88{
89 if (IsValid() && m_ConstTensor.GetNumElements() > 0)
90 {
91 return &m_ConstTensor;
92 }
93 // tensor is either invalid, or has no elements (indicating an optional tensor that was not provided)
94 return nullptr;
95}
96
///
/// Utility functions
///
100
101armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
102 ActivationFn activation,
103 armnn::IConnectableLayer* prevLayer,
104 ConversionData& data)
105{
106 BOOST_ASSERT(prevLayer->GetNumOutputSlots() == 1);
107
108 prevLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
109
110 armnn::IConnectableLayer* activationLayer = prevLayer;
111
112 if (activation != ActivationFn::kActivationNone)
113 {
114 armnn::ActivationDescriptor activationDesc;
115 switch (activation)
116 {
117 case ActivationFn::kActivationRelu:
118 {
119 activationDesc.m_Function = armnn::ActivationFunction::ReLu;
120 break;
121 }
122 case ActivationFn::kActivationRelu1:
123 {
124 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
125 activationDesc.m_A = 1.0f;
126 activationDesc.m_B = -1.0f;
127 break;
128 }
129 case ActivationFn::kActivationRelu6:
130 {
131 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
132 activationDesc.m_A = 6.0f;
133 break;
134 }
135 case ActivationFn::kActivationSigmoid:
136 {
137 activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
138 break;
139 }
140 case ActivationFn::kActivationTanh:
141 {
142 activationDesc.m_Function = armnn::ActivationFunction::TanH;
143 activationDesc.m_A = 1.0f;
144 activationDesc.m_B = 1.0f;
145 break;
146 }
147 default:
148 {
149 Fail("%s: Invalid activation enum value %i", __func__, activation);
150 return nullptr;
151 }
152 }
153
Ferran Balaguerd30093c2019-07-09 17:04:47 +0100154 bool isSupported = false;
155 FORWARD_LAYER_SUPPORT_FUNC(__func__,
156 IsActivationSupported,
157 data.m_Backends,
158 isSupported,
159 prevLayer->GetOutputSlot(0).GetTensorInfo(),
160 tensorInfo,
161 activationDesc);
162 if (!isSupported)
arovir01b0717b52018-09-05 17:03:25 +0100163 {
164 return nullptr;
165 }
166
167 activationLayer = data.m_Network->AddActivationLayer(activationDesc);
168
169 prevLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
170 activationLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
171 }
172
173 return activationLayer;
174}
175
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +0100176} // namespace armnn_driver