blob: 7658e6d529122e29abc96a6b0b9e3c4d0f6623b4 [file] [log] [blame]
/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_MISC_RANDOM_H
25#define ARM_COMPUTE_MISC_RANDOM_H
Georgios Pinitas587708b2018-12-31 15:43:52 +000026
27#include "arm_compute/core/Error.h"
Giorgio Arena6aeb2172020-12-15 15:45:43 +000028#include "utils/Utils.h"
Georgios Pinitas587708b2018-12-31 15:43:52 +000029
30#include <random>
31#include <type_traits>
32
33namespace arm_compute
34{
35namespace utils
36{
37namespace random
38{
39/** Uniform distribution within a given number of sub-ranges
40 *
41 * @tparam T Distribution primitive type
42 */
43template <typename T>
44class RangedUniformDistribution
45{
46public:
Giorgio Arenaa8e2aeb2021-01-06 11:34:57 +000047 static constexpr bool is_fp_16bit = std::is_same<T, half>::value || std::is_same<T, bfloat16>::value;
48 static constexpr bool is_integral = std::is_integral<T>::value && !is_fp_16bit;
Giorgio Arena6aeb2172020-12-15 15:45:43 +000049
Giorgio Arenaa8e2aeb2021-01-06 11:34:57 +000050 using fp_dist = typename std::conditional<is_fp_16bit, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
Giorgio Arena6aeb2172020-12-15 15:45:43 +000051 using DT = typename std::conditional<is_integral, std::uniform_int_distribution<T>, fp_dist>::type;
Georgios Pinitas587708b2018-12-31 15:43:52 +000052 using result_type = T;
53 using range_pair = std::pair<result_type, result_type>;
54
Georgios Pinitas587708b2018-12-31 15:43:52 +000055 /** Constructor
56 *
57 * @param[in] low lowest value in the range (inclusive)
58 * @param[in] high highest value in the range (inclusive for uniform_int_distribution, exclusive for uniform_real_distribution)
59 * @param[in] exclude_ranges Ranges to exclude from the generator
60 */
61 RangedUniformDistribution(result_type low, result_type high, const std::vector<range_pair> &exclude_ranges)
62 : _distributions(), _selector()
63 {
64 result_type clow = low;
65 for(const auto &erange : exclude_ranges)
66 {
Giorgio Arena6aeb2172020-12-15 15:45:43 +000067 result_type epsilon = is_integral ? result_type(1) : result_type(std::numeric_limits<T>::epsilon());
Georgios Pinitas587708b2018-12-31 15:43:52 +000068
69 ARM_COMPUTE_ERROR_ON(clow > erange.first || clow >= erange.second);
70
71 _distributions.emplace_back(DT(clow, erange.first - epsilon));
72 clow = erange.second + epsilon;
73 }
74 ARM_COMPUTE_ERROR_ON(clow > high);
75 _distributions.emplace_back(DT(clow, high));
76 _selector = std::uniform_int_distribution<uint32_t>(0, _distributions.size() - 1);
77 }
78 /** Generate random number
79 *
80 * @tparam URNG Random number generator object type
81 *
82 * @param[in] g A uniform random number generator object, used as the source of randomness.
83 *
84 * @return A new random number.
85 */
86 template <class URNG>
87 result_type operator()(URNG &g)
88 {
89 unsigned int rand_select = _selector(g);
90 return _distributions[rand_select](g);
91 }
92
93private:
94 std::vector<DT> _distributions;
95 std::uniform_int_distribution<uint32_t> _selector;
96};
97} // namespace random
98} // namespace utils
99} // namespace arm_compute
Michalis Spyrouf4643372019-11-29 16:17:13 +0000100#endif /* ARM_COMPUTE_MISC_RANDOM_H */