blob: 9ac0b3986e7bc0b89789138477c056c1262cdd41 [file] [log] [blame]
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001//
Sadik Armagana097d2a2021-11-24 15:47:28 +00002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
Matteo Martincighbf0e7222019-06-20 17:17:45 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "TestUtils.hpp"
7
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Matteo Martincighbf0e7222019-06-20 17:17:45 +01009
10using namespace armnn;
11
12void Connect(armnn::IConnectableLayer* from, armnn::IConnectableLayer* to, const armnn::TensorInfo& tensorInfo,
13 unsigned int fromIndex, unsigned int toIndex)
14{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010015 ARMNN_ASSERT(from);
16 ARMNN_ASSERT(to);
Matteo Martincighbf0e7222019-06-20 17:17:45 +010017
Cathal Corbettb8cc2b92021-10-08 14:43:11 +010018 try
19 {
20 from->GetOutputSlot(fromIndex).Connect(to->GetInputSlot(toIndex));
21 }
22 catch (const std::out_of_range& exc)
23 {
24 std::ostringstream message;
25
26 if (to->GetType() == armnn::LayerType::FullyConnected && toIndex == 2)
27 {
28 message << "Tried to connect bias to FullyConnected layer when bias is not enabled: ";
29 }
30
31 message << "Failed to connect to input slot "
32 << toIndex
33 << " on "
34 << GetLayerTypeAsCString(to->GetType())
35 << " layer "
36 << std::quoted(to->GetName())
37 << " as the slot does not exist or is unavailable";
38 throw LayerValidationException(message.str());
39 }
40
Matteo Martincighbf0e7222019-06-20 17:17:45 +010041 from->GetOutputSlot(fromIndex).SetTensorInfo(tensorInfo);
42}
Sadik Armaganea41b572020-03-19 18:16:46 +000043
44namespace armnn
45{
46
Francis Murtagh3d2b4b22021-02-15 18:23:17 +000047Graph& GetGraphForTesting(IOptimizedNetwork* optNet)
48{
49 return optNet->pOptimizedNetworkImpl->GetGraph();
50}
51
52ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNet)
53{
54 return optNet->pOptimizedNetworkImpl->GetModelOptions();
55}
56
Kevin Mayd92a6e42021-02-04 10:27:41 +000057profiling::ProfilingService& GetProfilingService(armnn::RuntimeImpl* runtime)
Sadik Armaganea41b572020-03-19 18:16:46 +000058{
59 return runtime->m_ProfilingService;
60}
61
62}