blob: 341752dd670ddebd664b01ad68abaa12c9a86326 [file] [log] [blame]
Mike Kellyaf484012019-02-20 16:53:11 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Jan Eilers8eb25602020-03-09 12:13:48 +00006#include "../Serializer.hpp"
7
Matthew Benthamff130e22020-01-17 11:47:42 +00008#include <armnn/Descriptors.hpp>
Mike Kellyaf484012019-02-20 16:53:11 +00009#include <armnn/INetwork.hpp>
Matthew Benthamff130e22020-01-17 11:47:42 +000010#include <armnn/IRuntime.hpp>
11#include <armnnDeserializer/IDeserializer.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000012#include <armnn/utility/IgnoreUnused.hpp>
Matthew Benthamff130e22020-01-17 11:47:42 +000013
Sadik Armagan1625efc2021-06-10 18:24:34 +010014#include <doctest/doctest.h>
Matthew Benthamff130e22020-01-17 11:47:42 +000015
Mike Kellyaf484012019-02-20 16:53:11 +000016#include <sstream>
Derek Lamberti859f9ce2019-12-10 22:05:21 +000017
Sadik Armagan1625efc2021-06-10 18:24:34 +010018TEST_SUITE("SerializerTests")
19{
Finn Williamsb454c5c2021-02-09 15:56:23 +000020class VerifyActivationName : public armnn::IStrategy
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000021{
22public:
Finn Williamsb454c5c2021-02-09 15:56:23 +000023 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
24 const armnn::BaseDescriptor& descriptor,
25 const std::vector<armnn::ConstTensor>& constants,
26 const char* name,
27 const armnn::LayerBindingId id = 0) override
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000028 {
Finn Williamsb454c5c2021-02-09 15:56:23 +000029 IgnoreUnused(layer, descriptor, constants, id);
30 if (layer->GetType() == armnn::LayerType::Activation)
31 {
Sadik Armagan1625efc2021-06-10 18:24:34 +010032 CHECK(std::string(name) == "activation");
Finn Williamsb454c5c2021-02-09 15:56:23 +000033 }
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000034 }
35};
36
Sadik Armagan1625efc2021-06-10 18:24:34 +010037TEST_CASE("ActivationSerialization")
Mike Kellyaf484012019-02-20 16:53:11 +000038{
39 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
40
41 armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
42 armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
43
44 // Construct network
45 armnn::INetworkPtr network = armnn::INetwork::Create();
46
47 armnn::ActivationDescriptor descriptor;
48 descriptor.m_Function = armnn::ActivationFunction::ReLu;
49 descriptor.m_A = 0;
50 descriptor.m_B = 0;
51
52 armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input");
53 armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
54 armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output");
55
56 inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
57 inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
58
59 activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
60 activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
61
Finn Williams85d36712021-01-26 22:30:06 +000062 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
63
64 serializer->Serialize(*network);
Mike Kellyaf484012019-02-20 16:53:11 +000065
66 std::stringstream stream;
Finn Williams85d36712021-01-26 22:30:06 +000067 serializer->SaveSerializedToStream(stream);
Mike Kellyaf484012019-02-20 16:53:11 +000068
69 std::string const serializerString{stream.str()};
70 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
71
72 armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
73
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000074 VerifyActivationName visitor;
Finn Williamsb454c5c2021-02-09 15:56:23 +000075 deserializedNetwork->ExecuteStrategy(visitor);
Éanna Ó Catháin633f8592019-02-25 16:26:29 +000076
Mike Kellyaf484012019-02-20 16:53:11 +000077 armnn::IRuntime::CreationOptions options; // default options
78 armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
79 auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
80
81 armnn::NetworkId networkIdentifier;
82
83 // Load graph into runtime
84 run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
85
86 std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
Cathal Corbett5b8093c2021-10-22 11:12:07 +010087 armnn::TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
88 inputTensorInfo.SetConstant(true);
Mike Kellyaf484012019-02-20 16:53:11 +000089 armnn::InputTensors inputTensors
90 {
Cathal Corbett5b8093c2021-10-22 11:12:07 +010091 {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
Mike Kellyaf484012019-02-20 16:53:11 +000092 };
93
94 std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
95
96 std::vector<float> outputData(4);
97 armnn::OutputTensors outputTensors
98 {
99 {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
100 };
101 run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100102 CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutputData.begin(), expectedOutputData.end()));
Mike Kellyaf484012019-02-20 16:53:11 +0000103}
104
Sadik Armagan1625efc2021-06-10 18:24:34 +0100105}