blob: acd0d1bce48589e37b7d8b7ce2ed3444e415fd65 [file] [log] [blame]
Sadik Armagana097d2a2021-11-24 15:47:28 +00001//
Mike Kelly4cc341c2023-07-07 15:43:06 +01002// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagana097d2a2021-11-24 15:47:28 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/INetwork.hpp>
#include <Graph.hpp>
#include <Runtime.hpp>

#include <algorithm>
#include <iterator>
#include <list>
#include <string>
#include <utility>
11
12void Connect(armnn::IConnectableLayer* from, armnn::IConnectableLayer* to, const armnn::TensorInfo& tensorInfo,
13 unsigned int fromIndex = 0, unsigned int toIndex = 0);
14
Mike Kelly4cc341c2023-07-07 15:43:06 +010015class LayerNameAndTypeCheck
16{
17public:
18 LayerNameAndTypeCheck(armnn::LayerType layerType, const char* name)
19 : m_layerType(layerType)
20 , m_name(name)
21 {}
22
23 bool operator()(const armnn::Layer* const layer)
24 {
25 return (layer->GetNameStr() == m_name &&
26 layer->GetType() == m_layerType);
27 }
28private:
29 armnn::LayerType m_layerType;
30 const char* m_name;
31};
32
Sadik Armagana097d2a2021-11-24 15:47:28 +000033template <typename LayerT>
34bool IsLayerOfType(const armnn::Layer* const layer)
35{
36 return (layer->GetType() == armnn::LayerEnumOf<LayerT>());
37}
38
39inline bool CheckSequence(const armnn::Graph::ConstIterator first, const armnn::Graph::ConstIterator last)
40{
41 return (first == last);
42}
43
44/// Checks each unary function in Us evaluates true for each correspondent layer in the sequence [first, last).
45template <typename U, typename... Us>
46bool CheckSequence(const armnn::Graph::ConstIterator first, const armnn::Graph::ConstIterator last, U&& u, Us&&... us)
47{
48 return u(*first) && CheckSequence(std::next(first), last, us...);
49}
50
51template <typename LayerT>
52bool CheckRelatedLayers(armnn::Graph& graph, const std::list<std::string>& testRelatedLayers)
53{
54 for (auto& layer : graph)
55 {
56 if (layer->GetType() == armnn::LayerEnumOf<LayerT>())
57 {
58 auto& relatedLayers = layer->GetRelatedLayerNames();
59 if (!std::equal(relatedLayers.begin(), relatedLayers.end(), testRelatedLayers.begin(),
60 testRelatedLayers.end()))
61 {
62 return false;
63 }
64 }
65 }
66
67 return true;
68}
69
70namespace armnn
71{
72Graph& GetGraphForTesting(IOptimizedNetwork* optNetPtr);
73ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNetPtr);
Jim Flynnaf947722022-03-02 11:04:47 +000074arm::pipe::IProfilingService& GetProfilingService(RuntimeImpl* runtime);
Sadik Armagana097d2a2021-11-24 15:47:28 +000075
Jim Flynnaf947722022-03-02 11:04:47 +000076} // namespace armnn