blob: f2a16097255764bf39dfaa3955e21a4b28ba3722 [file] [log] [blame]
Louis Verhaarde8a5a782020-11-02 18:04:27 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Contains unit tests for register command stream generator
19from ethosu.vela.api import NpuAddressRange
20from ethosu.vela.api import NpuDataType
21from ethosu.vela.api import NpuFeatureMap
22from ethosu.vela.api import NpuLayout
23from ethosu.vela.api import NpuShape3D
24from ethosu.vela.api import NpuTileBox
25from ethosu.vela.register_command_stream_generator import get_address_ranges
26from ethosu.vela.register_command_stream_generator import get_strides
27
28
def test_get_fm_strides():
    """Checks get_strides on one feature map as its layout and data type are changed"""
    feature_map = NpuFeatureMap()
    feature_map.layout = NpuLayout.NHCWB16
    feature_map.data_type = NpuDataType.INT16
    feature_map.shape = NpuShape3D(height=7, width=10, depth=24)

    # NHCWB16 layout, INT16 elements
    strides = get_strides(feature_map)
    assert strides == NpuShape3D(height=640, width=32, depth=320)

    # Same shape and data type, switched to NHWC layout
    feature_map.layout = NpuLayout.NHWC
    strides = get_strides(feature_map)
    assert strides == NpuShape3D(height=480, width=48, depth=2)

    # Still NHWC, switched to UINT8 elements
    feature_map.data_type = NpuDataType.UINT8
    strides = get_strides(feature_map)
    assert strides == NpuShape3D(height=240, width=24, depth=1)
40
41
def test_get_address_ranges_one_tile():
    """Checks the address ranges reported for a feature map that uses a single tile"""
    feature_map = NpuFeatureMap()
    feature_map.region = 4
    feature_map.layout = NpuLayout.NHWC
    feature_map.data_type = NpuDataType.INT16
    feature_map.shape = NpuShape3D(height=50, width=40, depth=3)
    # Tile dimensions equal the feature map height/width; only the first address is non-zero
    tile_addresses = [8000, 0, 0, 0]
    feature_map.tiles = NpuTileBox(height_0=50, height_1=50, width_0=40, addresses=tile_addresses)

    actual = get_address_ranges(feature_map)

    # One range is expected for the first tile, the other three entries are None
    first_tile = NpuAddressRange(region=4, address=8000, length=12000)
    assert actual == [first_tile, None, None, None]
52
53
def test_get_address_ranges_horizontal_tiles():
    """Checks the address ranges reported for a feature map split into 2 horizontal tiles"""
    feature_map = NpuFeatureMap()
    feature_map.region = 6
    feature_map.layout = NpuLayout.NHWC
    feature_map.data_type = NpuDataType.INT16
    feature_map.shape = NpuShape3D(height=50, width=10, depth=20)
    # height_0 (20) is smaller than the feature map height (50); width_0 equals the full width
    tile_addresses = [256, 0, 16000, 0]
    feature_map.tiles = NpuTileBox(height_0=20, height_1=30, width_0=10, addresses=tile_addresses)

    actual = get_address_ranges(feature_map)

    # Ranges are expected at index 0 and index 2 only
    upper_tile = NpuAddressRange(region=6, address=256, length=8000)
    lower_tile = NpuAddressRange(region=6, address=16000, length=12000)
    assert actual == [upper_tile, None, lower_tile, None]
69
70
def test_get_address_ranges_vertical_tiles():
    """Checks the address ranges reported for a feature map split into 2 vertical tiles"""
    feature_map = NpuFeatureMap()
    feature_map.region = 6
    feature_map.layout = NpuLayout.NHWC
    feature_map.data_type = NpuDataType.INT8
    feature_map.shape = NpuShape3D(height=50, width=10, depth=20)
    # Strides are provided explicitly on the feature map in this test
    feature_map.strides = NpuShape3D(height=100, width=20, depth=1)
    # width_0 (5) is smaller than the feature map width (10); both tile heights equal the full height
    tile_addresses = [16, 32000, 0, 0]
    feature_map.tiles = NpuTileBox(height_0=50, height_1=50, width_0=5, addresses=tile_addresses)

    actual = get_address_ranges(feature_map)

    # Ranges are expected at index 0 and index 1 only
    left_tile = NpuAddressRange(region=6, address=16, length=5000)
    right_tile = NpuAddressRange(region=6, address=32000, length=5000)
    assert actual == [left_tile, right_tile, None, None]
88
89
def test_get_address_ranges_4_tiles():
    """Checks the address ranges reported for a feature map that uses all 4 tiles"""
    feature_map = NpuFeatureMap()
    feature_map.region = 6
    feature_map.layout = NpuLayout.NHCWB16
    feature_map.data_type = NpuDataType.INT16
    feature_map.shape = NpuShape3D(height=50, width=10, depth=20)
    # Tile heights and width are all smaller than the feature map dimensions
    tile_addresses = [16, 32000, 8000, 16000]
    feature_map.tiles = NpuTileBox(height_0=30, height_1=10, width_0=3, addresses=tile_addresses)

    actual = get_address_ranges(feature_map)

    # NOTE(review): the last expected range starts at 28800 rather than at the
    # configured fourth tile address (16000); presumably the accessed area begins
    # part-way into that tile - confirm against get_address_ranges
    expected = [
        NpuAddressRange(region=6, address=16, length=18952),
        NpuAddressRange(region=6, address=32000, length=6280),
        NpuAddressRange(region=6, address=8000, length=12552),
        NpuAddressRange(region=6, address=28800, length=12680),
    ]
    assert actual == expected
104 ]