blob: 9e71f68e53ac392267d4735108c1c222d8d15da3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include <boost/test/unit_test.hpp>

#include "../QuantizationDataSet.hpp"

#include <armnn/Optional.hpp>
#include <Filesystem.hpp>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
16
Nina Drozd59e15b02019-04-25 15:45:20 +010017
18using namespace armnnQuantizer;
19
20struct CsvTestHelper {
21
22 CsvTestHelper()
23 {
24 BOOST_TEST_MESSAGE("setup fixture");
25 }
26
27 ~CsvTestHelper()
28 {
29 BOOST_TEST_MESSAGE("teardown fixture");
30 TearDown();
31 }
32
33 std::string CreateTempCsvFile(std::map<int, std::vector<float>> csvData)
34 {
Francis Murtagh532a29d2020-06-29 11:50:01 +010035 fs::path fileDir = fs::temp_directory_path();
36 fs::path p = armnnUtils::Filesystem::NamedTempFile("Armnn-QuantizationCreateTempCsvFileTest-TempFile.csv");
Nina Drozd59e15b02019-04-25 15:45:20 +010037
Francis Murtagh532a29d2020-06-29 11:50:01 +010038 fs::path tensorInput1{fileDir / "input_0_0.raw"};
39 fs::path tensorInput2{fileDir / "input_1_0.raw"};
40 fs::path tensorInput3{fileDir / "input_2_0.raw"};
Nina Drozd59e15b02019-04-25 15:45:20 +010041
42 try
43 {
Francis Murtagh532a29d2020-06-29 11:50:01 +010044 std::ofstream ofs{p};
Nina Drozd59e15b02019-04-25 15:45:20 +010045
Francis Murtagh532a29d2020-06-29 11:50:01 +010046 std::ofstream ofs1{tensorInput1};
47 std::ofstream ofs2{tensorInput2};
48 std::ofstream ofs3{tensorInput3};
Nina Drozd59e15b02019-04-25 15:45:20 +010049
50
51 for(auto entry : csvData.at(0))
52 {
53 ofs1 << entry << " ";
54 }
55 for(auto entry : csvData.at(1))
56 {
57 ofs2 << entry << " ";
58 }
59 for(auto entry : csvData.at(2))
60 {
61 ofs3 << entry << " ";
62 }
63
64 ofs << "0, 0, " << tensorInput1.c_str() << std::endl;
65 ofs << "2, 0, " << tensorInput3.c_str() << std::endl;
66 ofs << "1, 0, " << tensorInput2.c_str() << std::endl;
67
68 ofs.close();
69 ofs1.close();
70 ofs2.close();
71 ofs3.close();
72 }
73 catch (std::exception &e)
74 {
75 std::cerr << "Unable to write to file at location [" << p.c_str() << "] : " << e.what() << std::endl;
76 BOOST_TEST(false);
77 }
78
79 m_CsvFile = p;
80 return p.string();
81 }
82
83 void TearDown()
84 {
85 RemoveCsvFile();
86 }
87
88 void RemoveCsvFile()
89 {
90 if (m_CsvFile)
91 {
92 try
93 {
Francis Murtagh532a29d2020-06-29 11:50:01 +010094 fs::remove(m_CsvFile.value());
Nina Drozd59e15b02019-04-25 15:45:20 +010095 }
96 catch (std::exception &e)
97 {
Francis Murtagh532a29d2020-06-29 11:50:01 +010098 std::cerr << "Unable to delete file [" << m_CsvFile.value() << "] : " << e.what() << std::endl;
Nina Drozd59e15b02019-04-25 15:45:20 +010099 BOOST_TEST(false);
100 }
101 }
102 }
103
Francis Murtagh532a29d2020-06-29 11:50:01 +0100104 armnn::Optional<fs::path> m_CsvFile;
Nina Drozd59e15b02019-04-25 15:45:20 +0100105};
106
107
108BOOST_AUTO_TEST_SUITE(QuantizationDataSetTests)
109
110BOOST_FIXTURE_TEST_CASE(CheckDataSet, CsvTestHelper)
111{
112
113 std::map<int, std::vector<float>> csvData;
114 csvData.insert(std::pair<int, std::vector<float>>(0, { 0.111111f, 0.222222f, 0.333333f }));
115 csvData.insert(std::pair<int, std::vector<float>>(1, { 0.444444f, 0.555555f, 0.666666f }));
116 csvData.insert(std::pair<int, std::vector<float>>(2, { 0.777777f, 0.888888f, 0.999999f }));
117
118 std::string myCsvFile = CsvTestHelper::CreateTempCsvFile(csvData);
119 QuantizationDataSet dataSet(myCsvFile);
120 BOOST_TEST(!dataSet.IsEmpty());
121
122 int csvRow = 0;
123 for(armnnQuantizer::QuantizationInput input : dataSet)
124 {
125 BOOST_TEST(input.GetPassId() == csvRow);
126
127 BOOST_TEST(input.GetLayerBindingIds().size() == 1);
128 BOOST_TEST(input.GetLayerBindingIds()[0] == 0);
129 BOOST_TEST(input.GetDataForEntry(0).size() == 3);
130
131 // Check that QuantizationInput data for binding ID 0 corresponds to float values
132 // used for populating the CSV file using by QuantizationDataSet
133 BOOST_TEST(input.GetDataForEntry(0).at(0) == csvData.at(csvRow).at(0));
134 BOOST_TEST(input.GetDataForEntry(0).at(1) == csvData.at(csvRow).at(1));
135 BOOST_TEST(input.GetDataForEntry(0).at(2) == csvData.at(csvRow).at(2));
136 ++csvRow;
137 }
138}
139
140BOOST_AUTO_TEST_SUITE_END();