Jeremy Johnson | 0a6d1de | 2023-09-27 14:59:43 +0100 | [diff] [blame] | 1 | # Copyright (c) 2024, ARM Limited. |
| 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | import hashlib |
| 4 | import logging |
| 5 | |
| 6 | import generator.tosa_utils as gtu |
| 7 | import numpy as np |
| 8 | from tosa.DType import DType |
| 9 | |
# Named logger used by the generators in this module (see debug output when
# seeding). basicConfig() gives the root logger a default handler if it has
# none yet, so the debug messages have somewhere to go once enabled.
logger = logging.getLogger("tosa_verif_build_tests")
logging.basicConfig()
| 12 | |
| 13 | |
class TosaRandomGenerator(np.random.Generator):
    """Equivalent to numpy.default_rng, with support for TOSA data types.

    The generator is backed by a PCG64 bit generator seeded with an integer,
    so the same seed always reproduces the same sequence of values.
    """

    def __init__(self, seed, restrict_range_by_type=None):
        """Create random generator with TOSA type support.

        seed: integer seed (converted with int())
        restrict_range_by_type: optional dictionary of DTypes with (low, high)
            range tuples, see TosaHashRandomGenerator.__init__().
            When omitted, a new empty dictionary is created for this
            instance, so generators never share a default dictionary.
        """
        # Avoid a shared mutable default argument: a fresh dict per instance.
        if restrict_range_by_type is None:
            restrict_range_by_type = {}
        self._restrict_range_by_type = restrict_range_by_type
        self._seed = int(seed)
        self._bitgen = np.random.PCG64(self._seed)
        super().__init__(self._bitgen)

    @property
    def seed(self):
        """Integer seed this generator was created with."""
        return self._seed

    @property
    def hexSeed(self):
        """Seed formatted as a hexadecimal string, e.g. "0x2a"."""
        return hex(self._seed)

    def dTypeRange(self, dtype, high_inclusive=False):
        """Returns range tuple for given dtype.

        dtype: DType
        high_inclusive: True for inclusive high values (integer types only)
        Returns: dtype value range boundaries tuple (low, high)
            The high boundary is excluded in the range unless high_inclusive
            is True. Floating point ranges are returned exactly as stored in
            restrict_range_by_type, whatever high_inclusive is.
        Raises: Exception if dtype is neither in restrict_range_by_type nor
            one of the integer/BOOL types with a built-in default range.
        """
        if dtype in self._restrict_range_by_type:
            rng = self._restrict_range_by_type[dtype]
        elif dtype == DType.BOOL:
            rng = (0, 2)
        elif dtype == DType.UINT8:
            rng = (0, 256)
        elif dtype == DType.UINT16:
            rng = (0, 65536)
        elif dtype == DType.INT4:
            # TOSA specific INT4 weight range from -7 to 7
            rng = (-7, 8)
        elif dtype == DType.INT8:
            rng = (-128, 128)
        elif dtype == DType.INT16:
            rng = (-32768, 32768)
        elif dtype == DType.INT32:
            rng = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            rng = (-(1 << 47), (1 << 47))
        else:
            # Float types and SHAPE should be in _restrict_range_by_type dict
            raise Exception("Unknown supported dtype: {}".format(dtype))

        if dtype in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ):
            # Floating point - range is always inclusive
            return rng
        if high_inclusive:
            # Integer inclusive range: low <= value <= high
            return (rng[0], rng[1] - 1)
        # Integer exclusive high: low <= value < high
        return rng

    def randInt(self, low=0, high=256):
        """Return one random np.int32 in the half-open range [low, high)."""
        return np.int32(self.integers(low=low, high=high, size=1))[0]

    def randNumberDType(self, dtype):
        """Return a single random value of the given TOSA dtype.

        The value is drawn from dTypeRange(dtype) (exclusive high for
        integers). FP32/FP16 values are sampled uniformly. BF16/FP8 values
        are sampled as FP32 and then passed through the matching gtu
        conversion helper; the returned type is whatever that helper yields.
        BOOL is chosen from [False, True]. INT48/SHAPE give np.int64 and all
        other integer types give np.int32.
        """
        low, high = self.dTypeRange(dtype)

        if dtype == DType.FP32:
            return np.float32(self.uniform(low=low, high=high))
        elif dtype == DType.FP16:
            return np.float16(self.uniform(low=low, high=high))
        elif dtype == DType.BF16:
            rand_f32 = np.float32(self.uniform(low=low, high=high))
            return gtu.vect_f32_to_bf16(rand_f32)
        elif dtype == DType.FP8E4M3:
            rand_f32 = np.float32(self.uniform(low=low, high=high))
            return gtu.vect_f32_to_fp8e4m3(rand_f32)
        elif dtype == DType.FP8E5M2:
            rand_f32 = np.float32(self.uniform(low=low, high=high))
            return gtu.vect_f32_to_fp8e5m2(rand_f32)
        elif dtype == DType.BOOL:
            return self.choice([False, True])
        elif dtype == DType.INT48 or dtype == DType.SHAPE:
            # Special size: values may exceed the int32 range
            return np.int64(self.integers(low, high, size=1))[0]

        return np.int32(self.integers(low, high, size=1))[0]

    def randTensor(self, shape, dtype, data_range=None):
        """Return a random numpy array of the given shape for a TOSA dtype.

        shape: output shape passed to numpy
        dtype: DType of the values to generate
        data_range: optional (low, high) tuple overriding dTypeRange(dtype);
            high is exclusive for integer types. Ignored for BOOL.
        Returns: array whose numpy dtype depends on the TOSA dtype
            (INT4/INT8 -> int8, UINT8 -> uint8, INT16 -> int16,
            UINT16 -> uint16, INT48/SHAPE -> int64, FP16 -> float16,
            FP32/BF16/FP8 -> float32, BOOL -> bool, other ints -> int32).
        """
        if data_range is None:
            low, high = self.dTypeRange(dtype)
        else:
            low, high = data_range

        if dtype == DType.BOOL:
            return np.bool_(self.choice(a=[False, True], size=shape))
        elif dtype in (DType.INT4, DType.INT8):
            # INT4 values are held in an int8 array
            return np.int8(self.integers(low=low, high=high, size=shape))
        elif dtype == DType.UINT8:
            return np.uint8(self.integers(low=low, high=high, size=shape))
        elif dtype == DType.INT16:
            return np.int16(self.integers(low=low, high=high, size=shape))
        elif dtype == DType.UINT16:
            return np.uint16(self.integers(low=low, high=high, size=shape))
        elif dtype in (DType.INT48, DType.SHAPE):
            return np.int64(self.integers(low=low, high=high, size=shape))
        elif dtype in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ):
            f_tensor = self.uniform(low=low, high=high, size=shape)

            if dtype == DType.FP16:
                return np.float16(f_tensor)

            f32_tensor = np.float32(f_tensor)
            if dtype == DType.BF16:
                # Floor the last 16 bits of each f32 value
                return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
            elif dtype == DType.FP8E4M3:
                return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
            elif dtype == DType.FP8E5M2:
                return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
            return f32_tensor
        else:
            # All other integer types
            return np.int32(self.integers(low=low, high=high, size=shape))
| 151 | |
| 152 | |
class TosaHashRandomGenerator(TosaRandomGenerator):
    """Hash seeded TOSA random number generator.

    The effective seed is the given integer seed plus the MD5 digest
    (interpreted as an integer) of the "__"-joined string form of seed_list,
    so the same seed and seed_list always reproduce the same sequence.
    """

    def __init__(self, seed, seed_list, restrict_range_by_type=None):
        """Create TOSA random generator seeding it with a hashable list.

        seed: integer starting seed
        seed_list: list of hashable items to add to starting seed; each item
            is converted with str() and the strings are joined with "__"
        restrict_range_by_type: dictionary of DTypes with (low, high) range tuples
            This must contain entries for SHAPE and all Floating Point data types.
            NOTE: For integers, the high value must be the exclusive value
            When omitted, a new empty dictionary is used for this instance.
        """
        # Avoid a shared mutable default argument: a fresh dict per instance.
        if restrict_range_by_type is None:
            restrict_range_by_type = {}

        # Create a single string from the seed items and hash it
        self._seed_string = "__".join(str(s) for s in seed_list)
        seed_bytes = bytes(self._seed_string, "utf-8")
        try:
            # MD5 is only used to derive a seed, not for security; flagging it
            # lets this run on FIPS-restricted OpenSSL builds.
            self._hash = hashlib.md5(seed_bytes, usedforsecurity=False)
        except TypeError:
            # Python < 3.9 does not accept the usedforsecurity keyword
            self._hash = hashlib.md5(seed_bytes)

        # Add the hash value to the given seed
        seed += int(self._hash.hexdigest(), 16)

        # Lazy %-formatting: the message is only built when DEBUG is enabled
        logger.debug("Seed=%s Seed string=%s", seed, self._seed_string)
        super().__init__(seed, restrict_range_by_type)