blob: 8a1b585d47225f4b480d4d5608a50b5c0e918712 [file] [log] [blame]
Matthew Sloyan80fbcd52021-01-07 13:28:47 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "ClContextDeserializer.hpp"
#include "ClContextSchema_generated.h"

#include <armnn/Exceptions.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <flatbuffers/flexbuffers.h>

#include <fmt/format.h>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>
20
21namespace armnn
22{
23
24void ClContextDeserializer::Deserialize(arm_compute::CLCompileContext& clCompileContext,
25 cl::Context& context,
26 cl::Device& device,
27 const std::string& filePath)
28{
29 std::ifstream inputFileStream(filePath, std::ios::binary);
30 std::vector<std::uint8_t> binaryContent;
31 while (inputFileStream)
32 {
33 char input;
34 inputFileStream.get(input);
35 if (inputFileStream)
36 {
37 binaryContent.push_back(static_cast<std::uint8_t>(input));
38 }
39 }
40 inputFileStream.close();
41 DeserializeFromBinary(clCompileContext, context, device, binaryContent);
42}
43
44void ClContextDeserializer::DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext,
45 cl::Context& context,
46 cl::Device& device,
47 const std::vector<uint8_t>& binaryContent)
48{
49 if (binaryContent.data() == nullptr)
50 {
51 throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
52 CHECK_LOCATION().AsString()));
53 }
54
55 size_t binaryContentSize = binaryContent.size();
56 flatbuffers::Verifier verifier(binaryContent.data(), binaryContentSize);
57 if (verifier.VerifyBuffer<ClContext>() == false)
58 {
59 throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
60 "flatbuffers format. size:{0} {1}",
61 binaryContentSize,
62 CHECK_LOCATION().AsString()));
63 }
64 auto clContext = GetClContext(binaryContent.data());
65
66 for (Program const* program : *clContext->programs())
67 {
68 auto programName = program->name()->c_str();
69 auto programBinary = program->binary();
70 std::vector<uint8_t> binary(programBinary->begin(), programBinary->begin() + programBinary->size());
71
72 cl::Program::Binaries binaries{ binary };
73 std::vector<cl::Device> devices {device};
74 cl::Program theProgram(context, devices, binaries);
75 theProgram.build();
76 clCompileContext.add_built_program(programName, theProgram);
77 }
78}
79
80} // namespace armnn