blob: 2a46045f596f283f5fe043a43c71a25197e33e85 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <armnnDeserializer/IDeserializer.hpp>
7#include <armnn/ArmNN.hpp>
8#include <armnn/INetwork.hpp>
9#include "../Serializer.hpp"
10#include <sstream>
11#include <boost/test/unit_test.hpp>
12
13BOOST_AUTO_TEST_SUITE(SerializerTests)
14
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000015class VerifyActivationName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
16{
17public:
18 void VisitActivationLayer(const armnn::IConnectableLayer* layer,
19 const armnn::ActivationDescriptor& activationDescriptor,
20 const char* name) override
21 {
22 BOOST_TEST(name == "activation");
23 }
24};
25
Mike Kellyaf484012019-02-20 16:53:11 +000026BOOST_AUTO_TEST_CASE(ActivationSerialization)
27{
28 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
29
30 armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
31 armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
32
33 // Construct network
34 armnn::INetworkPtr network = armnn::INetwork::Create();
35
36 armnn::ActivationDescriptor descriptor;
37 descriptor.m_Function = armnn::ActivationFunction::ReLu;
38 descriptor.m_A = 0;
39 descriptor.m_B = 0;
40
41 armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input");
42 armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
43 armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output");
44
45 inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
46 inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
47
48 activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
49 activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
50
51 armnnSerializer::Serializer serializer;
52 serializer.Serialize(*network);
53
54 std::stringstream stream;
55 serializer.SaveSerializedToStream(stream);
56
57 std::string const serializerString{stream.str()};
58 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
59
60 armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
61
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000062 VerifyActivationName visitor;
63 deserializedNetwork->Accept(visitor);
64
Mike Kellyaf484012019-02-20 16:53:11 +000065 armnn::IRuntime::CreationOptions options; // default options
66 armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
67 auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
68
69 armnn::NetworkId networkIdentifier;
70
71 // Load graph into runtime
72 run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
73
74 std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
75 armnn::InputTensors inputTensors
76 {
77 {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), inputData.data())}
78 };
79
80 std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
81
82 std::vector<float> outputData(4);
83 armnn::OutputTensors outputTensors
84 {
85 {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
86 };
87 run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
88 BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(), outputData.end(),
89 expectedOutputData.begin(), expectedOutputData.end());
90}
91
92BOOST_AUTO_TEST_SUITE_END()