blob: 6d5fc437ceaf8b204a35161b45a5bfeeb0999d75 [file] [log] [blame]
Pablo Telloe96e4f02018-12-21 16:47:23 +00001/*
Giorgio Arena4bdd1772020-12-17 16:47:07 +00002 * Copyright (c) 2019-2020 Arm Limited.
Pablo Telloe96e4f02018-12-21 16:47:23 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_NON_MAX_SUPPRESSION_FIXTURE
25#define ARM_COMPUTE_TEST_NON_MAX_SUPPRESSION_FIXTURE
26
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/TensorShape.h"
29#include "arm_compute/core/Types.h"
30#include "arm_compute/runtime/Tensor.h"
31#include "tests/AssetsLibrary.h"
32#include "tests/Globals.h"
33#include "tests/IAccessor.h"
34#include "tests/framework/Asserts.h"
35#include "tests/framework/Fixture.h"
36#include "tests/validation/reference/NonMaxSuppression.h"
37
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44template <typename TensorType, typename AccessorType, typename FunctionType>
45
46class NMSValidationFixture : public framework::Fixture
47{
48public:
49 template <typename...>
50 void setup(TensorShape input_shape, unsigned int max_output_size, float score_threshold, float nms_threshold)
51 {
52 ARM_COMPUTE_ERROR_ON(max_output_size == 0);
53 ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() != 2);
54 const TensorShape output_shape(max_output_size);
55 const TensorShape scores_shape(input_shape[1]);
56 _target = compute_target(input_shape, scores_shape, output_shape, max_output_size, score_threshold, nms_threshold);
57 _reference = compute_reference(input_shape, scores_shape, output_shape, max_output_size, score_threshold, nms_threshold);
58 }
59
60protected:
61 template <typename U>
Giorgio Arena4bdd1772020-12-17 16:47:07 +000062 void fill(U &&tensor, int i, float lo, float hi)
Pablo Telloe96e4f02018-12-21 16:47:23 +000063 {
Giorgio Arena4bdd1772020-12-17 16:47:07 +000064 std::uniform_real_distribution<float> distribution(lo, hi);
Pablo Telloe96e4f02018-12-21 16:47:23 +000065 library->fill_boxes(tensor, distribution, i);
66 }
67
68 TensorType compute_target(const TensorShape input_shape, const TensorShape scores_shape, const TensorShape output_shape,
69 unsigned int max_output_size, float score_threshold, float nms_threshold)
70 {
71 // Create tensors
72 TensorType bboxes = create_tensor<TensorType>(input_shape, DataType::F32);
73 TensorType scores = create_tensor<TensorType>(scores_shape, DataType::F32);
74 TensorType indices = create_tensor<TensorType>(output_shape, DataType::S32);
75
76 // Create and configure function
77 FunctionType nms_func;
78 nms_func.configure(&bboxes, &scores, &indices, max_output_size, score_threshold, nms_threshold);
79
80 ARM_COMPUTE_EXPECT(bboxes.info()->is_resizable(), framework::LogLevel::ERRORS);
81 ARM_COMPUTE_EXPECT(indices.info()->is_resizable(), framework::LogLevel::ERRORS);
82 ARM_COMPUTE_EXPECT(scores.info()->is_resizable(), framework::LogLevel::ERRORS);
83
84 // Allocate tensors
85 bboxes.allocator()->allocate();
86 indices.allocator()->allocate();
87 scores.allocator()->allocate();
88
89 ARM_COMPUTE_EXPECT(!bboxes.info()->is_resizable(), framework::LogLevel::ERRORS);
90 ARM_COMPUTE_EXPECT(!indices.info()->is_resizable(), framework::LogLevel::ERRORS);
91 ARM_COMPUTE_EXPECT(!scores.info()->is_resizable(), framework::LogLevel::ERRORS);
92
93 // Fill tensors
94 fill(AccessorType(bboxes), 0, 0.f, 1.f);
95 fill(AccessorType(scores), 1, 0.f, 1.f);
96
97 // Compute function
98 nms_func.run();
99 return indices;
100 }
101
102 SimpleTensor<int> compute_reference(const TensorShape input_shape, const TensorShape scores_shape, const TensorShape output_shape,
103 unsigned int max_output_size, float score_threshold, float nms_threshold)
104 {
105 // Create reference
106 SimpleTensor<float> bboxes{ input_shape, DataType::F32 };
107 SimpleTensor<float> scores{ scores_shape, DataType::F32 };
108 SimpleTensor<int> indices{ output_shape, DataType::S32 };
109
110 // Fill reference
111 fill(bboxes, 0, 0.f, 1.f);
112 fill(scores, 1, 0.f, 1.f);
113
114 return reference::non_max_suppression(bboxes, scores, indices, max_output_size, score_threshold, nms_threshold);
115 }
116
117 TensorType _target{};
118 SimpleTensor<int> _reference{};
119};
120
121} // namespace validation
122} // namespace test
123} // namespace arm_compute
124#endif /* ARM_COMPUTE_TEST_NON_MAX_SUPPRESSION_FIXTURE */