blob: 88778b306ab5d5171b1affdbf72375235b9af8c8 [file] [log] [blame]
Finn Williamsb454c5c2021-02-09 15:56:23 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "../Serializer.hpp"
7#include "SerializerTestUtils.hpp"
8
9#include <armnn/Descriptors.hpp>
10#include <armnn/INetwork.hpp>
11#include <armnn/IRuntime.hpp>
12#include <armnnDeserializer/IDeserializer.hpp>
13#include <armnn/utility/IgnoreUnused.hpp>
14
Sadik Armagan1625efc2021-06-10 18:24:34 +010015#include <doctest/doctest.h>
Finn Williamsb454c5c2021-02-09 15:56:23 +000016
Sadik Armagan1625efc2021-06-10 18:24:34 +010017TEST_SUITE("SerializerTests")
18{
Finn Williamsb454c5c2021-02-09 15:56:23 +000019struct ComparisonModel
20{
21 ComparisonModel(const std::string& layerName,
22 const armnn::TensorInfo& inputInfo,
23 const armnn::TensorInfo& outputInfo,
24 armnn::ComparisonDescriptor& descriptor)
25 : m_network(armnn::INetwork::Create())
26 {
27 armnn::IConnectableLayer* const inputLayer0 = m_network->AddInputLayer(0);
28 armnn::IConnectableLayer* const inputLayer1 = m_network->AddInputLayer(1);
29 armnn::IConnectableLayer* const equalLayer = m_network->AddComparisonLayer(descriptor, layerName.c_str());
30 armnn::IConnectableLayer* const outputLayer = m_network->AddOutputLayer(0);
31
32 inputLayer0->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0));
33 inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1));
34 equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
35
36 inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo);
37 inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo);
38 equalLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
39 }
40
41 armnn::INetworkPtr m_network;
42};
43
44class ComparisonLayerVerifier : public LayerVerifierBase
45{
46public:
47 ComparisonLayerVerifier(const std::string& layerName,
48 const std::vector<armnn::TensorInfo>& inputInfos,
49 const std::vector<armnn::TensorInfo>& outputInfos,
50 const armnn::ComparisonDescriptor& descriptor)
51 : LayerVerifierBase(layerName, inputInfos, outputInfos)
52 , m_Descriptor (descriptor) {}
53
54 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
55 const armnn::BaseDescriptor& descriptor,
56 const std::vector<armnn::ConstTensor>& constants,
57 const char* name,
58 const armnn::LayerBindingId id = 0) override
59 {
60 armnn::IgnoreUnused(descriptor, constants, id);
61 switch (layer->GetType())
62 {
63 case armnn::LayerType::Input: break;
64 case armnn::LayerType::Output: break;
65 case armnn::LayerType::Comparison:
66 {
67 VerifyNameAndConnections(layer, name);
68 const armnn::ComparisonDescriptor& layerDescriptor =
69 static_cast<const armnn::ComparisonDescriptor&>(descriptor);
Sadik Armagan1625efc2021-06-10 18:24:34 +010070 CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation);
Finn Williamsb454c5c2021-02-09 15:56:23 +000071 break;
72 }
73 default:
74 {
75 throw armnn::Exception("Unexpected layer type in Comparison test model");
76 }
77 }
78 }
79
80private:
81 armnn::ComparisonDescriptor m_Descriptor;
82};
83
Sadik Armagan1625efc2021-06-10 18:24:34 +010084TEST_CASE("SerializeEqual")
Finn Williamsb454c5c2021-02-09 15:56:23 +000085{
86 const std::string layerName("equal");
87
88 const armnn::TensorShape shape{2, 1, 2, 4};
89 const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32);
90 const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
91
92 armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Equal);
93
94 ComparisonModel model(layerName, inputInfo, outputInfo, descriptor);
95
96 armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network));
Sadik Armagan1625efc2021-06-10 18:24:34 +010097 CHECK(deserializedNetwork);
Finn Williamsb454c5c2021-02-09 15:56:23 +000098
99 ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
100 deserializedNetwork->ExecuteStrategy(verifier);
101}
102
Sadik Armagan1625efc2021-06-10 18:24:34 +0100103TEST_CASE("SerializeGreater")
Finn Williamsb454c5c2021-02-09 15:56:23 +0000104{
105 const std::string layerName("greater");
106
107 const armnn::TensorShape shape{2, 1, 2, 4};
108 const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32);
109 const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
110
111 armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Greater);
112
113 ComparisonModel model(layerName, inputInfo, outputInfo, descriptor);
114
115 armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network));
Sadik Armagan1625efc2021-06-10 18:24:34 +0100116 CHECK(deserializedNetwork);
Finn Williamsb454c5c2021-02-09 15:56:23 +0000117
118 ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
119 deserializedNetwork->ExecuteStrategy(verifier);
120}
121
Sadik Armagan1625efc2021-06-10 18:24:34 +0100122}