Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 1 | /* |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 2 | * Copyright (c) 2018-2020 Arm Limited. |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
| 24 | #include "arm_compute/core/CPP/kernels/CPPBoxWithNonMaximaSuppressionLimitKernel.h" |
| 25 | |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 26 | #include "arm_compute/core/Helpers.h" |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 27 | #include "src/core/helpers/WindowHelpers.h" |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 28 | |
| 29 | #include <algorithm> |
| 30 | #include <cmath> |
| 31 | |
| 32 | namespace arm_compute |
| 33 | { |
| 34 | namespace |
| 35 | { |
| 36 | template <typename T> |
| 37 | std::vector<int> SoftNMS(const ITensor *proposals, std::vector<std::vector<T>> &scores_in, std::vector<int> inds, const BoxNMSLimitInfo &info, int class_id) |
| 38 | { |
| 39 | std::vector<int> keep; |
| 40 | const int proposals_width = proposals->info()->dimension(1); |
| 41 | |
| 42 | std::vector<T> x1(proposals_width); |
| 43 | std::vector<T> y1(proposals_width); |
| 44 | std::vector<T> x2(proposals_width); |
| 45 | std::vector<T> y2(proposals_width); |
| 46 | std::vector<T> areas(proposals_width); |
| 47 | |
| 48 | for(int i = 0; i < proposals_width; ++i) |
| 49 | { |
| 50 | x1[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4, i))); |
| 51 | y1[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4 + 1, i))); |
| 52 | x2[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4 + 2, i))); |
| 53 | y2[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4 + 3, i))); |
| 54 | areas[i] = (x2[i] - x1[i] + 1.0) * (y2[i] - y1[i] + 1.0); |
| 55 | } |
| 56 | |
Manuel Bottini | 5209be5 | 2019-02-13 16:34:56 +0000 | [diff] [blame] | 57 | // Note: Soft NMS scores have already been initialized with input scores |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 58 | |
| 59 | while(!inds.empty()) |
| 60 | { |
| 61 | // Find proposal with max score among remaining proposals |
| 62 | int max_pos = 0; |
| 63 | for(unsigned int i = 1; i < inds.size(); ++i) |
| 64 | { |
| 65 | if(scores_in[class_id][inds.at(i)] > scores_in[class_id][inds.at(max_pos)]) |
| 66 | { |
| 67 | max_pos = i; |
| 68 | } |
| 69 | } |
| 70 | int element = inds.at(max_pos); |
| 71 | keep.push_back(element); |
| 72 | std::swap(inds.at(0), inds.at(max_pos)); |
| 73 | |
| 74 | // Remove first element and compute IoU of the remaining boxes with identified max box |
| 75 | inds.erase(inds.begin()); |
| 76 | |
| 77 | std::vector<int> sorted_indices_temp; |
| 78 | for(auto idx : inds) |
| 79 | { |
| 80 | const auto xx1 = std::max(x1[idx], x1[element]); |
| 81 | const auto yy1 = std::max(y1[idx], y1[element]); |
| 82 | const auto xx2 = std::min(x2[idx], x2[element]); |
| 83 | const auto yy2 = std::min(y2[idx], y2[element]); |
| 84 | |
| 85 | const auto w = std::max((xx2 - xx1 + 1.f), 0.f); |
| 86 | const auto h = std::max((yy2 - yy1 + 1.f), 0.f); |
| 87 | const auto inter = w * h; |
| 88 | const auto ovr = inter / (areas[element] + areas[idx] - inter); |
| 89 | |
| 90 | // Update scores based on computed IoU, overlap threshold and NMS method |
| 91 | T weight; |
| 92 | switch(info.soft_nms_method()) |
| 93 | { |
| 94 | case NMSType::LINEAR: |
| 95 | weight = (ovr > info.nms()) ? (1.f - ovr) : 1.f; |
| 96 | break; |
| 97 | case NMSType::GAUSSIAN: // Gaussian |
| 98 | weight = std::exp(-1.f * ovr * ovr / info.soft_nms_sigma()); |
| 99 | break; |
| 100 | case NMSType::ORIGINAL: // Original NMS |
| 101 | weight = (ovr > info.nms()) ? 0.f : 1.f; |
| 102 | break; |
| 103 | default: |
| 104 | ARM_COMPUTE_ERROR("Not supported"); |
| 105 | } |
| 106 | |
| 107 | // Discard boxes with new scores below min threshold and update pending indices |
| 108 | scores_in[class_id][idx] *= weight; |
| 109 | if(scores_in[class_id][idx] >= info.soft_nms_min_score_thres()) |
| 110 | { |
| 111 | sorted_indices_temp.push_back(idx); |
| 112 | } |
| 113 | } |
| 114 | inds = sorted_indices_temp; |
| 115 | } |
| 116 | |
| 117 | return keep; |
| 118 | } |
| 119 | |
| 120 | template <typename T> |
| 121 | std::vector<int> NonMaximaSuppression(const ITensor *proposals, std::vector<int> sorted_indices, const BoxNMSLimitInfo &info, int class_id) |
| 122 | { |
| 123 | std::vector<int> keep; |
| 124 | |
| 125 | const int proposals_width = proposals->info()->dimension(1); |
| 126 | |
| 127 | std::vector<T> x1(proposals_width); |
| 128 | std::vector<T> y1(proposals_width); |
| 129 | std::vector<T> x2(proposals_width); |
| 130 | std::vector<T> y2(proposals_width); |
| 131 | std::vector<T> areas(proposals_width); |
| 132 | |
| 133 | for(int i = 0; i < proposals_width; ++i) |
| 134 | { |
| 135 | x1[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4, i))); |
| 136 | y1[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4 + 1, i))); |
| 137 | x2[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4 + 2, i))); |
| 138 | y2[i] = *reinterpret_cast<T *>(proposals->ptr_to_element(Coordinates(class_id * 4 + 3, i))); |
| 139 | areas[i] = (x2[i] - x1[i] + 1.0) * (y2[i] - y1[i] + 1.0); |
| 140 | } |
| 141 | |
| 142 | while(!sorted_indices.empty()) |
| 143 | { |
| 144 | int i = sorted_indices.at(0); |
| 145 | keep.push_back(i); |
| 146 | |
| 147 | std::vector<int> sorted_indices_temp = sorted_indices; |
| 148 | std::vector<int> new_indices; |
| 149 | sorted_indices_temp.erase(sorted_indices_temp.begin()); |
| 150 | |
| 151 | for(unsigned int j = 0; j < sorted_indices_temp.size(); ++j) |
| 152 | { |
Manuel Bottini | 5209be5 | 2019-02-13 16:34:56 +0000 | [diff] [blame] | 153 | const float xx1 = std::max(x1[sorted_indices_temp.at(j)], x1[i]); |
| 154 | const float yy1 = std::max(y1[sorted_indices_temp.at(j)], y1[i]); |
| 155 | const float xx2 = std::min(x2[sorted_indices_temp.at(j)], x2[i]); |
| 156 | const float yy2 = std::min(y2[sorted_indices_temp.at(j)], y2[i]); |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 157 | |
Manuel Bottini | 5209be5 | 2019-02-13 16:34:56 +0000 | [diff] [blame] | 158 | const float w = std::max((xx2 - xx1 + 1.f), 0.f); |
| 159 | const float h = std::max((yy2 - yy1 + 1.f), 0.f); |
| 160 | const float inter = w * h; |
| 161 | const float ovr = inter / (areas[i] + areas[sorted_indices_temp.at(j)] - inter); |
| 162 | const float ctr_x = xx1 + (w / 2); |
| 163 | const float ctr_y = yy1 + (h / 2); |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 164 | |
Manuel Bottini | 5209be5 | 2019-02-13 16:34:56 +0000 | [diff] [blame] | 165 | // If suppress_size is specified, filter the boxes based on their size and position |
| 166 | const bool keep_size = !info.suppress_size() || (w >= info.min_size() && h >= info.min_size() && ctr_x < info.im_width() && ctr_y < info.im_height()); |
| 167 | if(ovr <= info.nms() && keep_size) |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 168 | { |
| 169 | new_indices.push_back(j); |
| 170 | } |
| 171 | } |
| 172 | |
| 173 | const unsigned int new_indices_size = new_indices.size(); |
| 174 | std::vector<int> new_sorted_indices(new_indices_size); |
| 175 | for(unsigned int i = 0; i < new_indices_size; ++i) |
| 176 | { |
| 177 | new_sorted_indices[i] = sorted_indices[new_indices[i] + 1]; |
| 178 | } |
| 179 | sorted_indices = new_sorted_indices; |
| 180 | } |
| 181 | |
| 182 | return keep; |
| 183 | } |
| 184 | } // namespace |
| 185 | |
| 186 | CPPBoxWithNonMaximaSuppressionLimitKernel::CPPBoxWithNonMaximaSuppressionLimitKernel() |
| 187 | : _scores_in(nullptr), _boxes_in(nullptr), _batch_splits_in(nullptr), _scores_out(nullptr), _boxes_out(nullptr), _classes(nullptr), _batch_splits_out(nullptr), _keeps(nullptr), _keeps_size(nullptr), |
| 188 | _info() |
| 189 | { |
| 190 | } |
| 191 | |
| 192 | bool CPPBoxWithNonMaximaSuppressionLimitKernel::is_parallelisable() const |
| 193 | { |
| 194 | return false; |
| 195 | } |
| 196 | |
| 197 | template <typename T> |
| 198 | void CPPBoxWithNonMaximaSuppressionLimitKernel::run_nmslimit() |
| 199 | { |
| 200 | const int batch_size = _batch_splits_in == nullptr ? 1 : _batch_splits_in->info()->dimension(0); |
| 201 | const int num_classes = _scores_in->info()->dimension(0); |
| 202 | const int scores_count = _scores_in->info()->dimension(1); |
| 203 | std::vector<int> total_keep_per_batch(batch_size); |
| 204 | std::vector<std::vector<int>> keeps(num_classes); |
| 205 | int total_keep_count = 0; |
| 206 | |
| 207 | std::vector<std::vector<T>> in_scores(num_classes, std::vector<T>(scores_count)); |
| 208 | for(int i = 0; i < scores_count; ++i) |
| 209 | { |
| 210 | for(int j = 0; j < num_classes; ++j) |
| 211 | { |
| 212 | in_scores[j][i] = *reinterpret_cast<const T *>(_scores_in->ptr_to_element(Coordinates(j, i))); |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | int offset = 0; |
| 217 | int cur_start_idx = 0; |
| 218 | for(int b = 0; b < batch_size; ++b) |
| 219 | { |
| 220 | const int num_boxes = _batch_splits_in == nullptr ? 1 : static_cast<int>(*reinterpret_cast<T *>(_batch_splits_in->ptr_to_element(Coordinates(b)))); |
Manuel Bottini | 5209be5 | 2019-02-13 16:34:56 +0000 | [diff] [blame] | 221 | // Skip first class if there is more than 1 except if the number of classes is 1. |
| 222 | const int j_start = (num_classes == 1 ? 0 : 1); |
| 223 | for(int j = j_start; j < num_classes; ++j) |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 224 | { |
| 225 | std::vector<T> cur_scores(scores_count); |
| 226 | std::vector<int> inds; |
| 227 | for(int i = 0; i < scores_count; ++i) |
| 228 | { |
| 229 | const T score = in_scores[j][i]; |
| 230 | cur_scores[i] = score; |
| 231 | |
| 232 | if(score > _info.score_thresh()) |
| 233 | { |
| 234 | inds.push_back(i); |
| 235 | } |
| 236 | } |
| 237 | if(_info.soft_nms_enabled()) |
| 238 | { |
| 239 | keeps[j] = SoftNMS(_boxes_in, in_scores, inds, _info, j); |
| 240 | } |
| 241 | else |
| 242 | { |
| 243 | std::sort(inds.data(), inds.data() + inds.size(), |
| 244 | [&cur_scores](int lhs, int rhs) |
| 245 | { |
| 246 | return cur_scores[lhs] > cur_scores[rhs]; |
| 247 | }); |
| 248 | |
| 249 | keeps[j] = NonMaximaSuppression<T>(_boxes_in, inds, _info, j); |
| 250 | } |
| 251 | total_keep_count += keeps[j].size(); |
| 252 | } |
| 253 | |
| 254 | if(_info.detections_per_im() > 0 && total_keep_count > _info.detections_per_im()) |
| 255 | { |
| 256 | // merge all scores (represented by indices) together and sort |
| 257 | auto get_all_scores_sorted = [&in_scores, &keeps, total_keep_count]() |
| 258 | { |
| 259 | std::vector<T> ret(total_keep_count); |
| 260 | |
| 261 | int ret_idx = 0; |
| 262 | for(unsigned int i = 1; i < keeps.size(); ++i) |
| 263 | { |
| 264 | auto &cur_keep = keeps[i]; |
| 265 | for(auto &ckv : cur_keep) |
| 266 | { |
| 267 | ret[ret_idx++] = in_scores[i][ckv]; |
| 268 | } |
| 269 | } |
| 270 | |
| 271 | std::sort(ret.data(), ret.data() + ret.size()); |
| 272 | |
| 273 | return ret; |
| 274 | }; |
| 275 | |
| 276 | auto all_scores_sorted = get_all_scores_sorted(); |
| 277 | const T image_thresh = all_scores_sorted[all_scores_sorted.size() - _info.detections_per_im()]; |
| 278 | for(int j = 1; j < num_classes; ++j) |
| 279 | { |
| 280 | auto &cur_keep = keeps[j]; |
| 281 | std::vector<int> new_keeps_j; |
| 282 | for(auto &k : cur_keep) |
| 283 | { |
| 284 | if(in_scores[j][k] >= image_thresh) |
| 285 | { |
| 286 | new_keeps_j.push_back(k); |
| 287 | } |
| 288 | } |
| 289 | keeps[j] = new_keeps_j; |
| 290 | } |
| 291 | total_keep_count = _info.detections_per_im(); |
| 292 | } |
| 293 | |
| 294 | total_keep_per_batch[b] = total_keep_count; |
| 295 | |
| 296 | // Write results |
| 297 | int cur_out_idx = 0; |
Manuel Bottini | 5209be5 | 2019-02-13 16:34:56 +0000 | [diff] [blame] | 298 | for(int j = j_start; j < num_classes; ++j) |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 299 | { |
| 300 | auto &cur_keep = keeps[j]; |
| 301 | auto cur_out_scores = reinterpret_cast<T *>(_scores_out->ptr_to_element(Coordinates(cur_start_idx + cur_out_idx))); |
| 302 | auto cur_out_classes = reinterpret_cast<T *>(_classes->ptr_to_element(Coordinates(cur_start_idx + cur_out_idx))); |
| 303 | const int box_column = (cur_start_idx + cur_out_idx) * 4; |
| 304 | |
| 305 | for(unsigned int k = 0; k < cur_keep.size(); ++k) |
| 306 | { |
| 307 | cur_out_scores[k] = in_scores[j][cur_keep[k]]; |
| 308 | cur_out_classes[k] = static_cast<T>(j); |
| 309 | auto cur_out_box_row0 = reinterpret_cast<T *>(_boxes_out->ptr_to_element(Coordinates(box_column + 0, k))); |
| 310 | auto cur_out_box_row1 = reinterpret_cast<T *>(_boxes_out->ptr_to_element(Coordinates(box_column + 1, k))); |
| 311 | auto cur_out_box_row2 = reinterpret_cast<T *>(_boxes_out->ptr_to_element(Coordinates(box_column + 2, k))); |
| 312 | auto cur_out_box_row3 = reinterpret_cast<T *>(_boxes_out->ptr_to_element(Coordinates(box_column + 3, k))); |
| 313 | *cur_out_box_row0 = *reinterpret_cast<const T *>(_boxes_in->ptr_to_element(Coordinates(j * 4 + 0, cur_keep[k]))); |
| 314 | *cur_out_box_row1 = *reinterpret_cast<const T *>(_boxes_in->ptr_to_element(Coordinates(j * 4 + 1, cur_keep[k]))); |
| 315 | *cur_out_box_row2 = *reinterpret_cast<const T *>(_boxes_in->ptr_to_element(Coordinates(j * 4 + 2, cur_keep[k]))); |
| 316 | *cur_out_box_row3 = *reinterpret_cast<const T *>(_boxes_in->ptr_to_element(Coordinates(j * 4 + 3, cur_keep[k]))); |
| 317 | } |
| 318 | |
| 319 | cur_out_idx += cur_keep.size(); |
| 320 | } |
| 321 | |
| 322 | if(_keeps != nullptr) |
| 323 | { |
| 324 | cur_out_idx = 0; |
| 325 | for(int j = 0; j < num_classes; ++j) |
| 326 | { |
| 327 | for(unsigned int i = 0; i < keeps[j].size(); ++i) |
| 328 | { |
| 329 | *reinterpret_cast<T *>(_keeps->ptr_to_element(Coordinates(cur_start_idx + cur_out_idx + i))) = static_cast<T>(keeps[j].at(i)); |
| 330 | } |
Michele Di Giorgio | c8df89f | 2018-11-16 10:02:26 +0000 | [diff] [blame] | 331 | *reinterpret_cast<uint32_t *>(_keeps_size->ptr_to_element(Coordinates(j + b * num_classes))) = keeps[j].size(); |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 332 | cur_out_idx += keeps[j].size(); |
| 333 | } |
| 334 | } |
| 335 | |
| 336 | offset += num_boxes; |
| 337 | cur_start_idx += total_keep_count; |
| 338 | } |
| 339 | |
| 340 | if(_batch_splits_out != nullptr) |
| 341 | { |
| 342 | for(int b = 0; b < batch_size; ++b) |
| 343 | { |
| 344 | *reinterpret_cast<float *>(_batch_splits_out->ptr_to_element(Coordinates(b))) = total_keep_per_batch[b]; |
| 345 | } |
| 346 | } |
| 347 | } |
| 348 | |
| 349 | void CPPBoxWithNonMaximaSuppressionLimitKernel::configure(const ITensor *scores_in, const ITensor *boxes_in, const ITensor *batch_splits_in, ITensor *scores_out, ITensor *boxes_out, ITensor *classes, |
| 350 | ITensor *batch_splits_out, ITensor *keeps, ITensor *keeps_size, const BoxNMSLimitInfo info) |
| 351 | { |
| 352 | ARM_COMPUTE_ERROR_ON_NULLPTR(scores_in, boxes_in, scores_out, boxes_out, classes); |
| 353 | ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(scores_in, 1, DataType::F16, DataType::F32); |
Michele Di Giorgio | 70ad619 | 2019-09-06 17:51:37 +0100 | [diff] [blame] | 354 | ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(scores_in, boxes_in, scores_out); |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 355 | const unsigned int num_classes = scores_in->info()->dimension(0); |
| 356 | |
| 357 | ARM_COMPUTE_UNUSED(num_classes); |
| 358 | ARM_COMPUTE_ERROR_ON_MSG((4 * num_classes) != boxes_in->info()->dimension(0), "First dimension of input boxes must be of size 4*num_classes"); |
| 359 | ARM_COMPUTE_ERROR_ON_MSG(scores_in->info()->dimension(1) != boxes_in->info()->dimension(1), "Input scores and input boxes must have the same number of rows"); |
Michele Di Giorgio | c8df89f | 2018-11-16 10:02:26 +0000 | [diff] [blame] | 360 | |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 361 | ARM_COMPUTE_ERROR_ON(scores_out->info()->dimension(0) != boxes_out->info()->dimension(1)); |
| 362 | ARM_COMPUTE_ERROR_ON(boxes_out->info()->dimension(0) != 4); |
Michele Di Giorgio | 6b612f5 | 2019-09-05 12:30:22 +0100 | [diff] [blame] | 363 | ARM_COMPUTE_ERROR_ON(scores_out->info()->dimension(0) != classes->info()->dimension(0)); |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 364 | if(keeps != nullptr) |
| 365 | { |
| 366 | ARM_COMPUTE_ERROR_ON_MSG(keeps_size == nullptr, "keeps_size cannot be nullptr if keeps has to be provided as output"); |
| 367 | ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(scores_in, keeps); |
Michele Di Giorgio | c8df89f | 2018-11-16 10:02:26 +0000 | [diff] [blame] | 368 | ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(keeps_size, 1, DataType::U32); |
Michalis Spyrou | 2709d61 | 2018-09-19 09:46:47 +0100 | [diff] [blame] | 369 | ARM_COMPUTE_ERROR_ON(scores_out->info()->dimension(0) != keeps->info()->dimension(0)); |
| 370 | ARM_COMPUTE_ERROR_ON(num_classes != keeps_size->info()->dimension(0)); |
| 371 | } |
| 372 | if(batch_splits_in != nullptr) |
| 373 | { |
| 374 | ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(scores_in, batch_splits_in); |
| 375 | } |
| 376 | if(batch_splits_out != nullptr) |
| 377 | { |
| 378 | ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(scores_in, batch_splits_out); |
| 379 | } |
| 380 | |
| 381 | _scores_in = scores_in; |
| 382 | _boxes_in = boxes_in; |
| 383 | _batch_splits_in = batch_splits_in; |
| 384 | _scores_out = scores_out; |
| 385 | _boxes_out = boxes_out; |
| 386 | _classes = classes; |
| 387 | _batch_splits_out = batch_splits_out; |
| 388 | _keeps = keeps; |
| 389 | _keeps_size = keeps_size; |
| 390 | _info = info; |
| 391 | |
| 392 | // Configure kernel window |
| 393 | Window win = calculate_max_window(*scores_in->info(), Steps(scores_in->info()->dimension(0))); |
| 394 | |
| 395 | IKernel::configure(win); |
| 396 | } |
| 397 | |
| 398 | void CPPBoxWithNonMaximaSuppressionLimitKernel::run(const Window &window, const ThreadInfo &info) |
| 399 | { |
| 400 | ARM_COMPUTE_UNUSED(info); |
| 401 | ARM_COMPUTE_UNUSED(window); |
| 402 | ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); |
| 403 | ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(IKernel::window(), window); |
| 404 | |
| 405 | switch(_scores_in->info()->data_type()) |
| 406 | { |
| 407 | case DataType::F32: |
| 408 | run_nmslimit<float>(); |
| 409 | break; |
| 410 | case DataType::F16: |
| 411 | run_nmslimit<half>(); |
| 412 | break; |
| 413 | default: |
| 414 | ARM_COMPUTE_ERROR("Not supported"); |
| 415 | } |
| 416 | } |
| 417 | } // namespace arm_compute |