blob: 81a66145d92b23a3deb061789fd119024227fa6d [file] [log] [blame]
Matthew Sloyan80fbcd52021-01-07 13:28:47 +00001//
John Mcloughlinc5ee0d72023-03-24 12:07:25 +00002// Copyright © 2020, 2023 Arm Ltd. All rights reserved.
Matthew Sloyan80fbcd52021-01-07 13:28:47 +00003// SPDX-License-Identifier: MIT
4//
5
#include <armnnUtils/Filesystem.hpp>

#include <cl/test/ClContextControlFixture.hpp>

#include <doctest/doctest.h>

#include <fstream>
#include <system_error>
13
14namespace
15{
16
17armnn::INetworkPtr CreateNetwork()
18{
19 // Builds up the structure of the network.
20 armnn::INetworkPtr net(armnn::INetwork::Create());
21
22 armnn::IConnectableLayer* input = net->AddInputLayer(0, "input");
23 armnn::IConnectableLayer* softmax = net->AddSoftmaxLayer(armnn::SoftmaxDescriptor(), "softmax");
24 armnn::IConnectableLayer* output = net->AddOutputLayer(0, "output");
25
26 input->GetOutputSlot(0).Connect(softmax->GetInputSlot(0));
27 softmax->GetOutputSlot(0).Connect(output->GetInputSlot(0));
28
29 // Sets the input and output tensors
30 armnn::TensorInfo inputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 10000.0f, 1);
31 input->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
32
33 armnn::TensorInfo outputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 1.0f/255.0f, 0);
34 softmax->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
35
36 return net;
37}
38
39void RunInference(armnn::NetworkId& netId, armnn::IRuntimePtr& runtime, std::vector<uint8_t>& outputData)
40{
41 // Creates structures for input & output.
42 std::vector<uint8_t> inputData
43 {
44 1, 10, 3, 200, 5 // Some inputs - one of which is sufficiently larger than the others to saturate softmax.
45 };
46
Cathal Corbett5b8093c2021-10-22 11:12:07 +010047 armnn::TensorInfo inputTensorInfo = runtime->GetInputTensorInfo(netId, 0);
48 inputTensorInfo.SetConstant(true);
Matthew Sloyan80fbcd52021-01-07 13:28:47 +000049 armnn::InputTensors inputTensors
50 {
Cathal Corbett5b8093c2021-10-22 11:12:07 +010051 {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
Matthew Sloyan80fbcd52021-01-07 13:28:47 +000052 };
53
54 armnn::OutputTensors outputTensors
55 {
56 {0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
57 };
58
59 // Run inference.
60 runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
61}
62
// Reads the whole file at `binaryFileName` in binary mode and returns its bytes.
// If the file cannot be opened, the stream reads nothing and an empty vector is returned.
std::vector<char> ReadBinaryFile(const std::string& binaryFileName)
{
    std::ifstream stream{binaryFileName, std::ios::binary};

    // Raw stream-buffer iterators copy bytes unformatted (no whitespace skipping).
    using ByteIterator = std::istreambuf_iterator<char>;
    const ByteIterator first{stream};
    const ByteIterator last{};

    return std::vector<char>(first, last);
}
68
69} // anonymous namespace
70
Sadik Armagan1625efc2021-06-10 18:24:34 +010071TEST_CASE_FIXTURE(ClContextControlFixture, "ClContextSerializerTest")
Matthew Sloyan80fbcd52021-01-07 13:28:47 +000072{
73 // Get tmp directory and create blank file.
74 fs::path filePath = armnnUtils::Filesystem::NamedTempFile("Armnn-CachedNetworkFileTest-TempFile.bin");
75 std::string const filePathString{filePath.string()};
76 std::ofstream file { filePathString };
77
78 // Create runtime in which test will run
79 armnn::IRuntime::CreationOptions options;
80 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
81
82 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
83
84 // Create two networks.
85 // net1 will serialize and save context to file.
86 // net2 will deserialize context saved from net1 and load.
87 armnn::INetworkPtr net1 = CreateNetwork();
88 armnn::INetworkPtr net2 = CreateNetwork();
89
90 // Add specific optimizerOptions to each network.
John Mcloughlinc5ee0d72023-03-24 12:07:25 +000091 armnn::OptimizerOptionsOpaque optimizerOptions1;
92 armnn::OptimizerOptionsOpaque optimizerOptions2;
Matthew Sloyan80fbcd52021-01-07 13:28:47 +000093 armnn::BackendOptions modelOptions1("GpuAcc",
94 {{"SaveCachedNetwork", true}, {"CachedNetworkFilePath", filePathString}});
95 armnn::BackendOptions modelOptions2("GpuAcc",
96 {{"SaveCachedNetwork", false}, {"CachedNetworkFilePath", filePathString}});
John Mcloughlinc5ee0d72023-03-24 12:07:25 +000097 optimizerOptions1.AddModelOption(modelOptions1);
98 optimizerOptions2.AddModelOption(modelOptions2);
Matthew Sloyan80fbcd52021-01-07 13:28:47 +000099
100 armnn::IOptimizedNetworkPtr optNet1 = armnn::Optimize(
101 *net1, backends, runtime->GetDeviceSpec(), optimizerOptions1);
102 armnn::IOptimizedNetworkPtr optNet2 = armnn::Optimize(
103 *net2, backends, runtime->GetDeviceSpec(), optimizerOptions2);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100104 CHECK(optNet1);
105 CHECK(optNet2);
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000106
107 // Cached file should be empty until net1 is loaded into runtime.
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108 CHECK(fs::is_empty(filePathString));
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000109
110 // Load net1 into the runtime.
111 armnn::NetworkId netId1;
Sadik Armagan1625efc2021-06-10 18:24:34 +0100112 CHECK(runtime->LoadNetwork(netId1, std::move(optNet1)) == armnn::Status::Success);
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000113
114 // File should now exist and not be empty. It has been serialized.
Sadik Armagan1625efc2021-06-10 18:24:34 +0100115 CHECK(fs::exists(filePathString));
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000116 std::vector<char> dataSerialized = ReadBinaryFile(filePathString);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100117 CHECK(dataSerialized.size() != 0);
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000118
119 // Load net2 into the runtime using file and deserialize.
120 armnn::NetworkId netId2;
Sadik Armagan1625efc2021-06-10 18:24:34 +0100121 CHECK(runtime->LoadNetwork(netId2, std::move(optNet2)) == armnn::Status::Success);
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000122
123 // Run inference and get output data.
124 std::vector<uint8_t> outputData1(5);
125 RunInference(netId1, runtime, outputData1);
126
127 std::vector<uint8_t> outputData2(5);
128 RunInference(netId2, runtime, outputData2);
129
130 // Compare outputs from both networks.
Sadik Armagan1625efc2021-06-10 18:24:34 +0100131 CHECK(std::equal(outputData1.begin(), outputData1.end(), outputData2.begin(), outputData2.end()));
Matthew Sloyan80fbcd52021-01-07 13:28:47 +0000132
133 // Remove temp file created.
134 fs::remove(filePath);
135}