blob: 08fea39eda9bbeeb9d22d3fed143ff24ba8dcce3 [file] [log] [blame]
# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest
import numpy as np
import pyarmnn as ann
# import generated so we can test for Dequantize_* and Quantize_*
# functions not available in the public API.
import pyarmnn._generated.pyarmnn as gen_ann
@pytest.mark.parametrize('method', ['Quantize_int8_t',
'Quantize_uint8_t',
'Quantize_int16_t',
'Quantize_int32_t',
'Dequantize_int8_t',
'Dequantize_uint8_t',
'Dequantize_int16_t',
'Dequantize_int32_t'])
def test_quantize_exists(method):
assert method in dir(gen_ann) and callable(getattr(gen_ann, method))
@pytest.mark.parametrize('dt, min, max', [('uint8', 0, 255),
('int8', -128, 127),
('int16', -32768, 32767),
('int32', -2147483648, 2147483647)])
def test_quantize_uint8_output(dt, min, max):
result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt)
assert type(result) is int and min <= result <= max
@pytest.mark.parametrize('dt', ['uint8',
'int8',
'int16',
'int32'])
def test_dequantize_uint8_output(dt):
result = ann.dequantize(3, 0.02620004490017891, 128, dt)
assert type(result) is float
def test_quantize_unsupported_dtype():
with pytest.raises(ValueError) as err:
ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'uint16')
assert 'Unexpected target datatype uint16 given.' in str(err.value)
def test_dequantize_unsupported_dtype():
with pytest.raises(ValueError) as err:
ann.dequantize(3, 0.02620004490017891, 128, 'uint16')
assert 'Unexpected value datatype uint16 given.' in str(err.value)
def test_dequantize_value_range():
with pytest.raises(ValueError) as err:
ann.dequantize(-1, 0.02620004490017891, 128, 'uint8')
assert 'Value is not within range of the given datatype uint8' in str(err.value)
@pytest.mark.parametrize('dt, data', [('uint8', np.uint8(255)),
('int8', np.int8(127)),
('int16', np.int16(32767)),
('int32', np.int32(2147483647)),
('uint8', np.int8(127)),
('uint8', np.int16(255)),
('uint8', np.int32(255)),
('int8', np.uint8(127)),
('int8', np.int16(127)),
('int8', np.int32(127)),
('int16', np.int8(127)),
('int16', np.uint8(255)),
('int16', np.int32(32767)),
('int32', np.uint8(255)),
('int16', np.int8(127)),
('int32', np.int16(32767))
])
def test_dequantize_numpy_dt(dt, data):
result = ann.dequantize(data, 1, 0, dt)
assert type(result) is float
assert np.float32(data) == result