blob: 6ba5085b361871c37ec77240a3f11bc4f4aa4755 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "Cifar10Database.hpp"
6
7#include <boost/numeric/conversion/cast.hpp>
8#include <boost/log/trivial.hpp>
9#include <fstream>
10#include <vector>
11
// Size in bytes of one record in a CIFAR-10 binary batch file:
// a single label byte followed by three 32x32 colour planes (one byte per pixel each).
constexpr unsigned int g_kCifar10ImageByteSize = 1u + 3u * (32u * 32u);
13
14Cifar10Database::Cifar10Database(const std::string& binaryFileDirectory, bool rgbPack)
15 : m_BinaryDirectory(binaryFileDirectory), m_RgbPack(rgbPack)
16{
17}
18
19std::unique_ptr<Cifar10Database::TTestCaseData> Cifar10Database::GetTestCaseData(unsigned int testCaseId)
20{
21 std::vector<unsigned char> I(g_kCifar10ImageByteSize);
22
23 std::string fullpath = m_BinaryDirectory + std::string("test_batch.bin");
24
25 std::ifstream fileStream(fullpath, std::ios::binary);
26 if (!fileStream.is_open())
27 {
28 BOOST_LOG_TRIVIAL(fatal) << "Failed to load " << fullpath;
29 return nullptr;
30 }
31
32 fileStream.seekg(testCaseId * g_kCifar10ImageByteSize, std::ios_base::beg);
33 fileStream.read(reinterpret_cast<char*>(&I[0]), g_kCifar10ImageByteSize);
34
35 if (!fileStream.good())
36 {
37 BOOST_LOG_TRIVIAL(fatal) << "Failed to read " << fullpath;
38 return nullptr;
39 }
40
41
42 std::vector<float> inputImageData;
43 inputImageData.resize(g_kCifar10ImageByteSize - 1);
44
45 unsigned int step;
46 unsigned int countR_o;
47 unsigned int countG_o;
48 unsigned int countB_o;
49 unsigned int countR = 1;
50 unsigned int countG = 1 + 32 * 32;
51 unsigned int countB = 1 + 2 * 32 * 32;
52
53 if (m_RgbPack)
54 {
55 countR_o = 0;
56 countG_o = 1;
57 countB_o = 2;
58 step = 3;
59 }
60 else
61 {
62 countR_o = 0;
63 countG_o = 32 * 32;
64 countB_o = 2 * 32 * 32;
65 step = 1;
66 }
67
68 for (unsigned int h = 0; h < 32; h++)
69 {
70 for (unsigned int w = 0; w < 32; w++)
71 {
72 inputImageData[countR_o] = boost::numeric_cast<float>(I[countR++]);
73 inputImageData[countG_o] = boost::numeric_cast<float>(I[countG++]);
74 inputImageData[countB_o] = boost::numeric_cast<float>(I[countB++]);
75
76 countR_o += step;
77 countG_o += step;
78 countB_o += step;
79 }
80 }
81
82 const unsigned int label = boost::numeric_cast<unsigned int>(I[0]);
83 return std::make_unique<TTestCaseData>(label, std::move(inputImageData));
84}