blob: 2b46aae26e80f7639cccd5e5003933ecfb82cd8a [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7
8#include "../QuantizationDataSet.hpp"
9#include <iostream>
10#include <fstream>
11#include <vector>
12#include <map>
13
14#define BOOST_FILESYSTEM_NO_DEPRECATED
15
16#include <boost/filesystem/operations.hpp>
17#include <boost/filesystem/fstream.hpp>
18#include <boost/filesystem/path.hpp>
19#include <boost/optional/optional.hpp>
20
21
22using namespace armnnQuantizer;
23
24struct CsvTestHelper {
25
26 CsvTestHelper()
27 {
28 BOOST_TEST_MESSAGE("setup fixture");
29 }
30
31 ~CsvTestHelper()
32 {
33 BOOST_TEST_MESSAGE("teardown fixture");
34 TearDown();
35 }
36
37 std::string CreateTempCsvFile(std::map<int, std::vector<float>> csvData)
38 {
39 boost::filesystem::path fileDir = boost::filesystem::temp_directory_path();
40 boost::filesystem::path p{fileDir / boost::filesystem::unique_path("%%%%-%%%%-%%%%.csv")};
41
42 boost::filesystem::path tensorInput1{fileDir / boost::filesystem::unique_path("input_0_0.raw")};
43 boost::filesystem::path tensorInput2{fileDir / boost::filesystem::unique_path("input_1_0.raw")};
44 boost::filesystem::path tensorInput3{fileDir / boost::filesystem::unique_path("input_2_0.raw")};
45
46 try
47 {
48 boost::filesystem::ofstream ofs{p};
49
50 boost::filesystem::ofstream ofs1{tensorInput1};
51 boost::filesystem::ofstream ofs2{tensorInput2};
52 boost::filesystem::ofstream ofs3{tensorInput3};
53
54
55 for(auto entry : csvData.at(0))
56 {
57 ofs1 << entry << " ";
58 }
59 for(auto entry : csvData.at(1))
60 {
61 ofs2 << entry << " ";
62 }
63 for(auto entry : csvData.at(2))
64 {
65 ofs3 << entry << " ";
66 }
67
68 ofs << "0, 0, " << tensorInput1.c_str() << std::endl;
69 ofs << "2, 0, " << tensorInput3.c_str() << std::endl;
70 ofs << "1, 0, " << tensorInput2.c_str() << std::endl;
71
72 ofs.close();
73 ofs1.close();
74 ofs2.close();
75 ofs3.close();
76 }
77 catch (std::exception &e)
78 {
79 std::cerr << "Unable to write to file at location [" << p.c_str() << "] : " << e.what() << std::endl;
80 BOOST_TEST(false);
81 }
82
83 m_CsvFile = p;
84 return p.string();
85 }
86
87 void TearDown()
88 {
89 RemoveCsvFile();
90 }
91
92 void RemoveCsvFile()
93 {
94 if (m_CsvFile)
95 {
96 try
97 {
98 boost::filesystem::remove(*m_CsvFile);
99 }
100 catch (std::exception &e)
101 {
102 std::cerr << "Unable to delete file [" << *m_CsvFile << "] : " << e.what() << std::endl;
103 BOOST_TEST(false);
104 }
105 }
106 }
107
108 boost::optional<boost::filesystem::path> m_CsvFile;
109};
110
111
112BOOST_AUTO_TEST_SUITE(QuantizationDataSetTests)
113
114BOOST_FIXTURE_TEST_CASE(CheckDataSet, CsvTestHelper)
115{
116
117 std::map<int, std::vector<float>> csvData;
118 csvData.insert(std::pair<int, std::vector<float>>(0, { 0.111111f, 0.222222f, 0.333333f }));
119 csvData.insert(std::pair<int, std::vector<float>>(1, { 0.444444f, 0.555555f, 0.666666f }));
120 csvData.insert(std::pair<int, std::vector<float>>(2, { 0.777777f, 0.888888f, 0.999999f }));
121
122 std::string myCsvFile = CsvTestHelper::CreateTempCsvFile(csvData);
123 QuantizationDataSet dataSet(myCsvFile);
124 BOOST_TEST(!dataSet.IsEmpty());
125
126 int csvRow = 0;
127 for(armnnQuantizer::QuantizationInput input : dataSet)
128 {
129 BOOST_TEST(input.GetPassId() == csvRow);
130
131 BOOST_TEST(input.GetLayerBindingIds().size() == 1);
132 BOOST_TEST(input.GetLayerBindingIds()[0] == 0);
133 BOOST_TEST(input.GetDataForEntry(0).size() == 3);
134
135 // Check that QuantizationInput data for binding ID 0 corresponds to float values
136 // used for populating the CSV file using by QuantizationDataSet
137 BOOST_TEST(input.GetDataForEntry(0).at(0) == csvData.at(csvRow).at(0));
138 BOOST_TEST(input.GetDataForEntry(0).at(1) == csvData.at(csvRow).at(1));
139 BOOST_TEST(input.GetDataForEntry(0).at(2) == csvData.at(csvRow).at(2));
140 ++csvRow;
141 }
142}
143
144BOOST_AUTO_TEST_SUITE_END();