blob: 5433d933322fbffb3e4e2d48be56b78fb34f67cd [file] [log] [blame]
Sadik Armagana097d2a2021-11-24 15:47:28 +00001//
2// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/INetwork.hpp>
9#include <Graph.hpp>
10#include <Runtime.hpp>
11
/// Test helper declared here and defined out of line.
/// NOTE(review): judging by its name and parameters this presumably connects output slot
/// `fromIndex` of `from` to input slot `toIndex` of `to` and applies `tensorInfo` to that
/// output slot — confirm against the definition.
/// Both slot indices default to 0.
void Connect(armnn::IConnectableLayer* from, armnn::IConnectableLayer* to, const armnn::TensorInfo& tensorInfo,
             unsigned int fromIndex = 0, unsigned int toIndex = 0);
14
15template <typename LayerT>
16bool IsLayerOfType(const armnn::Layer* const layer)
17{
18 return (layer->GetType() == armnn::LayerEnumOf<LayerT>());
19}
20
21inline bool CheckSequence(const armnn::Graph::ConstIterator first, const armnn::Graph::ConstIterator last)
22{
23 return (first == last);
24}
25
26/// Checks each unary function in Us evaluates true for each correspondent layer in the sequence [first, last).
27template <typename U, typename... Us>
28bool CheckSequence(const armnn::Graph::ConstIterator first, const armnn::Graph::ConstIterator last, U&& u, Us&&... us)
29{
30 return u(*first) && CheckSequence(std::next(first), last, us...);
31}
32
33template <typename LayerT>
34bool CheckRelatedLayers(armnn::Graph& graph, const std::list<std::string>& testRelatedLayers)
35{
36 for (auto& layer : graph)
37 {
38 if (layer->GetType() == armnn::LayerEnumOf<LayerT>())
39 {
40 auto& relatedLayers = layer->GetRelatedLayerNames();
41 if (!std::equal(relatedLayers.begin(), relatedLayers.end(), testRelatedLayers.begin(),
42 testRelatedLayers.end()))
43 {
44 return false;
45 }
46 }
47 }
48
49 return true;
50}
51
namespace armnn
{
/// Test-only accessor, defined out of line.
/// NOTE(review): presumably returns the Graph held by the optimized network behind `optNetPtr` — confirm.
Graph& GetGraphForTesting(IOptimizedNetwork* optNetPtr);
/// Test-only accessor, defined out of line.
/// NOTE(review): presumably returns the ModelOptions stored in the optimized network — confirm.
ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNetPtr);
/// Test-only accessor, defined out of line.
/// NOTE(review): presumably returns the profiling service associated with `runtime` — confirm.
arm::pipe::ProfilingService& GetProfilingService(RuntimeImpl* runtime);

} // namespace armnn