//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "ClassifierTestCaseData.hpp"

#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Note: TensorFlow (Tf) models require RGB images normalized to the range [0, 1]
// and resized using bilinear interpolation.

/// One entry of an image set: `first` is a string naming the image (presumably a
/// file name within the preprocessor's binary directory — confirm against the
/// loader) and `second` is an unsigned integer (presumably the image's expected
/// class label). `first` is const, so entries cannot be reassigned once built.
using ImageSet = std::pair<const std::string, unsigned int>;

19template <typename TDataType>
20class ImagePreprocessor
21{
22public:
23 using DataType = TDataType;
24 using TTestCaseData = ClassifierTestCaseData<DataType>;
25
26 enum DataFormat
27 {
28 NHWC,
29 NCHW
30 };
31
32 explicit ImagePreprocessor(const std::string& binaryFileDirectory,
33 unsigned int width,
34 unsigned int height,
35 const std::vector<ImageSet>& imageSet,
36 float scale=1.0,
37 int32_t offset=0,
38 const std::array<float, 3> mean={{0, 0, 0}},
39 const std::array<float, 3> stddev={{1, 1, 1}},
40 DataFormat dataFormat=DataFormat::NHWC)
41 : m_BinaryDirectory(binaryFileDirectory)
42 , m_Height(height)
43 , m_Width(width)
44 , m_Scale(scale)
45 , m_Offset(offset)
46 , m_ImageSet(imageSet)
47 , m_Mean(mean)
48 , m_Stddev(stddev)
49 , m_DataFormat(dataFormat)
50 {
51 }
52
53 std::unique_ptr<TTestCaseData> GetTestCaseData(unsigned int testCaseId);
54
55private:
56 unsigned int GetNumImageElements() const { return 3 * m_Width * m_Height; }
57 unsigned int GetNumImageBytes() const { return sizeof(DataType) * GetNumImageElements(); }
58 unsigned int GetLabelAndResizedImageAsFloat(unsigned int testCaseId,
59 std::vector<float> & result);
60
61 std::string m_BinaryDirectory;
62 unsigned int m_Height;
63 unsigned int m_Width;
64 // Quantization parameters
65 float m_Scale;
66 int32_t m_Offset;
67 const std::vector<ImageSet> m_ImageSet;
68
69 const std::array<float, 3> m_Mean;
70 const std::array<float, 3> m_Stddev;
71
72 DataFormat m_DataFormat;
73};