Richard Burton | dc0c6ed | 2020-04-08 16:39:05 +0100 | [diff] [blame] | 1 | # Copyright © 2020 Arm Ltd. All rights reserved. |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | import pytest |
| 4 | import numpy as np |
| 5 | |
| 6 | import pyarmnn as ann |
| 7 | |
# Import the generated module so that we can test for the Dequantize_* and
# Quantize_* functions, which are not available in the public API.
| 10 | import pyarmnn._generated.pyarmnn as gen_ann |
| 11 | |
| 12 | |
@pytest.mark.parametrize(
    'method',
    [
        'Quantize_int8_t',
        'Quantize_uint8_t',
        'Quantize_int16_t',
        'Quantize_int32_t',
        'Dequantize_int8_t',
        'Dequantize_uint8_t',
        'Dequantize_int16_t',
        'Dequantize_int32_t',
    ],
)
def test_quantize_exists(method):
    """Check that the generated module exposes `method` as a callable."""
    # Presence is asserted first so that a missing name fails here instead of
    # reaching the getattr() lookup below.
    assert method in dir(gen_ann)
    assert callable(getattr(gen_ann, method))
| 23 | |
| 24 | |
@pytest.mark.parametrize('dt, lower, upper', [('uint8', 0, 255),
                                              ('int8', -128, 127),
                                              ('int16', -32768, 32767),
                                              ('int32', -2147483648, 2147483647)])
def test_quantize_uint8_output(dt, lower, upper):
    """Check that ``ann.quantize`` returns a Python int inside the dtype range.

    Despite its name, the test is parametrized over four target datatypes
    (uint8, int8, int16, int32), each paired with the inclusive bounds the
    quantized result must fall within.

    Args:
        dt: Target datatype name passed to ``ann.quantize``.
        lower: Smallest value accepted for the result.
        upper: Largest value accepted for the result.
    """
    # The bounds are named lower/upper rather than min/max so the builtins of
    # the same name are not shadowed inside the test body.
    result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt)

    # Exact type check (not isinstance): the result must be a plain int.
    assert type(result) is int
    assert lower <= result <= upper
| 32 | |
| 33 | |
@pytest.mark.parametrize('dt', ['uint8', 'int8', 'int16', 'int32'])
def test_dequantize_uint8_output(dt):
    """Check that ``ann.dequantize`` returns a Python float for datatype `dt`."""
    dequantized = ann.dequantize(3, 0.02620004490017891, 128, dt)

    # Exact type check (not isinstance): the result must be a plain float.
    assert type(dequantized) is float
| 41 | |
| 42 | |
def test_quantize_unsupported_dtype():
    """Check that quantizing to 'uint16' raises ValueError with the expected text."""
    expected_message = 'Unexpected target datatype uint16 given.'

    with pytest.raises(ValueError) as exc_info:
        ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'uint16')

    # Substring comparison against the raised exception's message.
    assert expected_message in str(exc_info.value)
| 48 | |
| 49 | |
def test_dequantize_unsupported_dtype():
    """Check that dequantizing from 'uint16' raises ValueError with the expected text."""
    expected_message = 'Unexpected value datatype uint16 given.'

    with pytest.raises(ValueError) as exc_info:
        ann.dequantize(3, 0.02620004490017891, 128, 'uint16')

    # Substring comparison against the raised exception's message.
    assert expected_message in str(exc_info.value)
| 55 | |
| 56 | |
def test_dequantize_value_range():
    """Check that dequantizing -1 as 'uint8' raises ValueError with the expected text."""
    expected_message = 'Value is not within range of the given datatype uint8'

    with pytest.raises(ValueError) as exc_info:
        ann.dequantize(-1, 0.02620004490017891, 128, 'uint8')

    # Substring comparison against the raised exception's message.
    assert expected_message in str(exc_info.value)
| 62 | |
| 63 | |
@pytest.mark.parametrize('dt, data', [('uint8', np.uint8(255)),
                                      ('int8', np.int8(127)),
                                      ('int16', np.int16(32767)),
                                      ('int32', np.int32(2147483647)),

                                      ('uint8', np.int8(127)),
                                      ('uint8', np.int16(255)),
                                      ('uint8', np.int32(255)),

                                      ('int8', np.uint8(127)),
                                      ('int8', np.int16(127)),
                                      ('int8', np.int32(127)),

                                      ('int16', np.int8(127)),
                                      ('int16', np.uint8(255)),
                                      ('int16', np.int32(32767)),

                                      ('int32', np.uint8(255)),
                                      ('int32', np.int8(127)),
                                      ('int32', np.int16(32767))

                                      ])
def test_dequantize_numpy_dt(dt, data):
    """Check that ``ann.dequantize`` accepts numpy integer scalars.

    The first four cases pass a numpy scalar whose type matches `dt`; the
    remaining groups pass, for each `dt`, scalars of the other three numpy
    integer types. The int32 group previously repeated the
    ``('int16', np.int8(127))`` case, leaving the int32/np.int8 combination
    untested.

    Args:
        dt: Datatype name passed to ``ann.dequantize``.
        data: Numpy integer scalar to dequantize.
    """
    # Called with scale 1 and offset 0; the result is then compared directly
    # with the input value below.
    result = ann.dequantize(data, 1, 0, dt)

    # Exact type check (not isinstance): the result must be a plain float.
    assert type(result) is float

    assert np.float32(data) == result