blob: 947646cdfc0d098d6f62dff1a1761cc14499fc43 [file] [log] [blame]
John Richardson80127542018-06-07 11:07:00 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_HOG_MULTI_DETECTION_FIXTURE
25#define ARM_COMPUTE_TEST_HOG_MULTI_DETECTION_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29#include "tests/Globals.h"
30#include "tests/Utils.h"
31#include "tests/framework/Fixture.h"
32
33namespace arm_compute
34{
35namespace test
36{
37namespace benchmark
38{
39template <typename TensorType,
40 typename HOGType,
41 typename MultiHOGType,
42 typename DetectionWindowArrayType,
43 typename DetectionWindowStrideType,
44 typename Function,
45 typename Accessor,
46 typename HOGAccessorType,
47 typename Size2DArrayAccessorType>
48class HOGMultiDetectionFixture : public framework::Fixture
49{
50public:
51 template <typename...>
52 void setup(std::string image, std::vector<HOGInfo> models, Format format, BorderMode border_mode, bool non_maxima_suppression)
53 {
54 // Only defined borders supported
55 ARM_COMPUTE_ERROR_ON(border_mode == BorderMode::UNDEFINED);
56
57 std::mt19937 generator(library->seed());
58 std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
59 uint8_t constant_border_value = static_cast<uint8_t>(distribution_u8(generator));
60
61 // Load the image (cached by the library if loaded before)
62 const RawTensor &raw = library->get(image, format);
63
64 // Initialize descriptors vector
65 std::vector<std::vector<float>> descriptors(models.size());
66
67 // Resize detection window_strides for index access
68 detection_window_strides.resize(models.size());
69
70 // Initialiize MultiHOG and detection windows
71 initialize_batch(models, multi_hog, descriptors, detection_window_strides);
72
73 // Create tensors
74 src = create_tensor<TensorType>(raw.shape(), format);
75
76 // Use default values for threshold and min_distance
77 const float threshold = 0.f;
78 const float min_distance = 1.f;
79
80 hog_multi_detection_func.configure(&src,
81 &multi_hog,
82 &detection_windows,
83 &detection_window_strides,
84 border_mode,
85 constant_border_value,
86 threshold,
87 non_maxima_suppression,
88 min_distance);
89
90 // Reset detection windows
91 detection_windows.clear();
92
93 // Allocate tensor
94 src.allocator()->allocate();
95
96 library->fill(Accessor(src), raw);
97 }
98
99 void run()
100 {
101 hog_multi_detection_func.run();
102 }
103
104 void sync()
105 {
106 sync_if_necessary<TensorType>();
107 }
108
109private:
110 void initialize_batch(const std::vector<HOGInfo> &models, MultiHOGType &multi_hog,
111 std::vector<std::vector<float>> &descriptors, DetectionWindowStrideType &detection_window_strides)
112 {
113 for(unsigned i = 0; i < models.size(); ++i)
114 {
115 auto hog_model = reinterpret_cast<HOGType *>(multi_hog.model(i));
116 hog_model->init(models[i]);
117
118 // Initialise descriptor (linear SVM coefficients).
119 std::random_device::result_type seed = 0;
120 descriptors.at(i) = generate_random_real(models[i].descriptor_size(), -0.505f, 0.495f, seed);
121
122 // Copy HOG descriptor values to HOG memory
123 {
124 HOGAccessorType hog_accessor(*hog_model);
125 std::memcpy(hog_accessor.descriptor(), descriptors.at(i).data(), descriptors.at(i).size() * sizeof(float));
126 }
127
128 // Initialize detection window stride
129 Size2DArrayAccessorType accessor(detection_window_strides);
130 accessor.at(i) = models[i].block_stride();
131 }
132 }
133
134private:
135 static const unsigned int model_size = 4;
136 static const unsigned int max_num_detection_windows = 100000;
137
138 MultiHOGType multi_hog{ model_size };
139 DetectionWindowStrideType detection_window_strides{ model_size };
140 DetectionWindowArrayType detection_windows{ max_num_detection_windows };
141
142 TensorType src{};
143 Function hog_multi_detection_func{};
144};
145} // namespace benchmark
146} // namespace test
147} // namespace arm_compute
148#endif /* ARM_COMPUTE_TEST_HOG_MULTI_DETECTION_FIXTURE */