blob: 6312f727075497245fbd079505b583ade40b9702 [file] [log] [blame]
Matthew Sloyan164bf4f2022-10-28 18:02:17 +01001//
Mike Kelly1ec5f852023-04-05 12:51:10 +01002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +01003// SPDX-License-Identifier: MIT
4//
5
6#include <armnn/INetwork.hpp>
7
8#include <GraphUtils.hpp>
9#include <TestUtils.hpp>
10
11#include <doctest/doctest.h>
12
13TEST_SUITE("TosaReferenceOptimizedNetwork")
14{
15
16TEST_CASE("SimpleSupportedOptimizedNetwork")
17{
18 armnn::IRuntime::CreationOptions options;
19 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
20 armnn::INetworkPtr network(armnn::INetwork::Create());
21
22 auto inputLayer1 = network->AddInputLayer(0, "input_1");
23 auto inputLayer2 = network->AddInputLayer(1, "input_2");
Mike Kelly1ec5f852023-04-05 12:51:10 +010024 ARMNN_NO_DEPRECATE_WARN_BEGIN
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010025 auto addLayer = network->AddAdditionLayer("add");
Mike Kelly1ec5f852023-04-05 12:51:10 +010026 ARMNN_NO_DEPRECATE_WARN_END
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010027 auto outputLayer = network->AddOutputLayer(2, "output");
28
29 armnn::TensorInfo tensorInfo{{4}, armnn::DataType::Float32};
30
31 inputLayer1->GetOutputSlot(0).Connect(addLayer->GetInputSlot(0));
32 inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
33
34 inputLayer2->GetOutputSlot(0).Connect(addLayer->GetInputSlot(1));
35 inputLayer2->GetOutputSlot(0).SetTensorInfo(tensorInfo);
36
37 addLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
38 addLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
39
40 std::vector<armnn::BackendId> backends = { "TosaRef" };
41
John Mcloughlin42969272023-04-14 14:43:47 +010042 armnn::OptimizerOptionsOpaque optimizedOptions;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010043 armnn::IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec(), optimizedOptions);
44 CHECK(optNet);
45
46 armnn::Graph& graph = GetGraphForTesting(optNet.get());
47
48 // Check graph layer sequence to ensure that the network has been replaced with a PreCompiledLayer
49 CHECK(CheckSequence(graph.cbegin(), graph.cend(),
50 &IsLayerOfType<armnn::InputLayer>,
51 &IsLayerOfType<armnn::InputLayer>,
52 &IsLayerOfType<armnn::PreCompiledLayer>,
53 &IsLayerOfType<armnn::OutputLayer>));
54}
55
56}