blob: abc63ae64d6dca0d07017e6f43bb60465cc6fac8 [file] [log] [blame]
Mike Kellyaf484012019-02-20 16:53:11 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Jan Eilers8eb25602020-03-09 12:13:48 +00006#include "../Serializer.hpp"
7
Matthew Benthamff130e22020-01-17 11:47:42 +00008#include <armnn/Descriptors.hpp>
Mike Kellyaf484012019-02-20 16:53:11 +00009#include <armnn/INetwork.hpp>
Matthew Benthamff130e22020-01-17 11:47:42 +000010#include <armnn/IRuntime.hpp>
11#include <armnnDeserializer/IDeserializer.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000012#include <armnn/utility/IgnoreUnused.hpp>
Matthew Benthamff130e22020-01-17 11:47:42 +000013
Jan Eilers8eb25602020-03-09 12:13:48 +000014#include <boost/test/unit_test.hpp>
Matthew Benthamff130e22020-01-17 11:47:42 +000015
Mike Kellyaf484012019-02-20 16:53:11 +000016#include <sstream>
Derek Lamberti859f9ce2019-12-10 22:05:21 +000017
Mike Kellyaf484012019-02-20 16:53:11 +000018BOOST_AUTO_TEST_SUITE(SerializerTests)
19
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000020class VerifyActivationName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
21{
22public:
23 void VisitActivationLayer(const armnn::IConnectableLayer* layer,
24 const armnn::ActivationDescriptor& activationDescriptor,
25 const char* name) override
26 {
Jan Eilers8eb25602020-03-09 12:13:48 +000027 IgnoreUnused(layer, activationDescriptor);
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000028 BOOST_TEST(name == "activation");
29 }
30};
31
Mike Kellyaf484012019-02-20 16:53:11 +000032BOOST_AUTO_TEST_CASE(ActivationSerialization)
33{
34 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
35
36 armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
37 armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
38
39 // Construct network
40 armnn::INetworkPtr network = armnn::INetwork::Create();
41
42 armnn::ActivationDescriptor descriptor;
43 descriptor.m_Function = armnn::ActivationFunction::ReLu;
44 descriptor.m_A = 0;
45 descriptor.m_B = 0;
46
47 armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input");
48 armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
49 armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output");
50
51 inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
52 inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
53
54 activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
55 activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
56
57 armnnSerializer::Serializer serializer;
58 serializer.Serialize(*network);
59
60 std::stringstream stream;
61 serializer.SaveSerializedToStream(stream);
62
63 std::string const serializerString{stream.str()};
64 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
65
66 armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
67
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000068 VerifyActivationName visitor;
69 deserializedNetwork->Accept(visitor);
70
Mike Kellyaf484012019-02-20 16:53:11 +000071 armnn::IRuntime::CreationOptions options; // default options
72 armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
73 auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
74
75 armnn::NetworkId networkIdentifier;
76
77 // Load graph into runtime
78 run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
79
80 std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
81 armnn::InputTensors inputTensors
82 {
83 {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), inputData.data())}
84 };
85
86 std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
87
88 std::vector<float> outputData(4);
89 armnn::OutputTensors outputTensors
90 {
91 {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
92 };
93 run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
94 BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(), outputData.end(),
95 expectedOutputData.begin(), expectedOutputData.end());
96}
97
98BOOST_AUTO_TEST_SUITE_END()