# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite support_operators
import numpy as np

from ethosu.vela.data_type import DataType
from ethosu.vela.operation import ActivationFunction
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators

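# Shared instance of the supported-operators checker used by all tests below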
support = TFLiteSupportedOperators()


def test_constraint_tens_dtype():
    # Tensors can only be of type uint8, int8, int16 and int32
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.float32)
    assert not support.is_operator_supported(op)


def test_constraint_tens_int32_ops():
    # For int32, only select op types are allowed:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert not support.is_operator_supported(op)


def test_constraint_tens_dimension():
    # Tensor dimensions must be in the inclusive range 1-65535
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 0], [1, 8, 8, 65536])
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_per_axis_not_supp():
    # Quantization scale cannot be array-valued for elemwise ops
    qp = QuantizationParameters()
    qp.zero_point = np.zeros((1, 3))
    qp.scale_f32 = np.ones((1, 3))
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
    assert not support.is_operator_supported(op)


def test_constraint_tens_quant_per_axis_is_supp():
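    # Per-axis quantization is supported for the bias tensor of convolution ops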
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 3], [1, 1, 1, 3], weights_shape=[1, 1, 1, 3], bias_shape=[1, 1, 1, 3]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert support.is_operator_supported(op)
    qp = QuantizationParameters()
    qp.zero_point = np.zeros((1, 3))
    qp.scale_f32 = np.ones((1, 3))
    op.bias.quantization = qp
    assert support.is_operator_supported(op)


def test_constraint_fc_output_2d_is_supp():
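    # FullyConnected with a 2D output shape is supported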
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
    assert support.is_operator_supported(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
    assert support.is_operator_supported(op)


def test_constraint_faf():
    # Fused activation functions, if set, must be a valid op type
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    op.activation = ActivationFunction(Op.Conv2D)
    assert not support.is_operator_supported(op)


def test_constraint_faf_ofm_dtype():
    # If fused activation function is present, OFM must be 8 or 16 bit
    shp = [1, 8, 8, 8]
    for dtype in [DataType.int8, DataType.uint8, DataType.int16, DataType.int32]:
        op = testutil.create_elemwise_op(Op.Add, "op", shp, shp, shp, datatype=dtype)
        op.activation = ActivationFunction(Op.Relu)
        expected = dtype.size_in_bytes() <= 2
        assert support.is_operator_supported(op) == expected, f"Data type: {dtype}"


def test_constraint_conv_pass():
    # First test a simple conv passes
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert support.is_operator_supported(op)


def test_constraint_stride_range():
    # Stride width and height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 0, "stride_h": 20}
    assert not support.is_operator_supported(op)


def test_constraint_dilation_range():
    # Dilation width and height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 0, "dilation_h_factor": 20}
    assert not support.is_operator_supported(op)


def test_constraint_dilated_height_range():
    # Dilated kernel height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[65, 64, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_dilated_product_range():
    # Dilated kernel width x height must lie within a certain range
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[64, 65, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_weights_type():
    # Weight tensor must be 8-bit
    op = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1], datatype=DataType.int16
    )
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_weights_const():
    # Weight tensor cannot be non-const
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
    weights.quantization = testutil.default_quant_params()
    op.add_input_tensor(weights)
    assert not support.is_operator_supported(op)


def test_constraint_weights_limit():
    # Sum of weights has a limit
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    op.weights.quantization.zero_point = np.array([[[[(127 * 65536) + 1]]]])
    assert not support.is_operator_supported(op)


def test_constraint_bias_type():
    # Bias must have a certain datatype
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    bias = Tensor([1, 8, 8, 8], DataType.uint8, "bias")
    op.add_input_tensor(bias)
    assert not support.is_operator_supported(op)


def test_constraint_bias_40bit():
    # Bias values must not exceed 40 bits
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    bias = Tensor([1, 1, 1, 1], DataType.int64, "bias")
    bias.values = np.array([0x01FF_FFFF_FFFF])
    op.add_input_tensor(bias)
    assert not support.is_operator_supported(op)


def test_constraint_batch_size():
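    # IFM batch size must be 1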
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert not support.is_operator_supported(op)


def test_constraint_depth_multiplier():
    # Valid. Depth multiplier is 1 so no further constraints
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 1}
    assert support.is_operator_supported(op)
    # Invalid. Depth multiplier doesn't equal ofm channel
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
    assert not support.is_operator_supported(op)
    # Valid. Depth multiplier is equal to ofm channel
    op = testutil.create_op_with_quant_tensors(
        Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
    )
    op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
    assert support.is_operator_supported(op)


def test_constraint_tconv_stride():
    # Strides must be 2
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_tconv_same():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert support.is_operator_supported(op)
    # Invalid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_tconv_valid():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[4, 4, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert support.is_operator_supported(op)
    # Invalid
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[2, 2, 1, 1])
    op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
    ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
    ifm.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm)
    assert not support.is_operator_supported(op)


def test_constraint_filter_range():
    # Avg pool restrictions are dependent on padding:
    # SAME padding restricts both W and H to max 8
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 20, "filter_height": 20, "padding": Padding.SAME}
    assert not support.is_operator_supported(op)
    # VALID padding limits are much larger
    op.attrs["padding"] = Padding.VALID
    assert support.is_operator_supported(op)


def test_constraint_filter_height_range_valid_pad():
    # Avg pool restrictions are dependent on padding:
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.VALID}
    assert support.is_operator_supported(op)
    # VALID padding restricts to 256 in filter height
    op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(op)


def test_constraint_filter_product_height_range_valid_pad():
    # Avg pool restrictions are dependent on padding:
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.VALID}
    assert support.is_operator_supported(op)
    # VALID padding restricts filter W x H to 256x256
    op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(op)


def test_constraint_filter_height_range():
    # Max pool restrictions aren't dependent on padding
    op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Restricts to 256 in filter height
    op.attrs["filter_height"] = 257
    assert not support.is_operator_supported(op)
    # Doesn't matter if SAME or VALID
    op.attrs["padding"] = Padding.VALID
    assert not support.is_operator_supported(op)


def test_constraint_filter_product_height_range():
    # Max pool restrictions aren't dependent on padding
    op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.SAME}
    assert support.is_operator_supported(op)
    # Restricts filter W x H to 256x256
    op.attrs["filter_width"] = 257
    assert not support.is_operator_supported(op)
    # Doesn't matter if SAME or VALID
    op.attrs["padding"] = Padding.VALID
    assert not support.is_operator_supported(op)


def test_constraint_resize():
    # IFM W and H == 1
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 1, 1, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM == OFM
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM x2 == OFM ; align_corners = False
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # IFM x2 -1 == OFM ; align_corners = True
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 7, 7, 8])
    op.attrs["align_corners"] = True
    assert support.is_operator_supported(op)
    # Invalid cases
    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 20, 20, 8])
    assert not support.is_operator_supported(op)
    op.attrs["align_corners"] = True
    assert not support.is_operator_supported(op)


def test_constraint_concat_pass():
    # A working concat
    op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert support.is_operator_supported(op)


def create_pad_op(
    in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
):
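    # Builds a PAD operator with quantized input/output tensors and a constant padding tensor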
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    return op


def test_constraint_padded_dimensions():
    # Incorrect padding dimensions, can only pad width and height
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
    assert not support.is_operator_supported(op)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 0]],)
    assert support.is_operator_supported(op)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 1]],)
    assert not support.is_operator_supported(op)


def test_constraint_pad_shape():
    # The padding tensor must have shape (3, 2) or (4, 2)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 0]])
    assert support.is_operator_supported(op)
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],)
    assert not support.is_operator_supported(op)


def test_constraint_pad_none():
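    # The padding tensor must not be empty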
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[],)
    assert not support.is_operator_supported(op)


def test_constraint_pad_dtype():
    # The padding tensor dtype must be int32 or int64
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],
        pad_dtype=DataType.int16,
    )
    assert not support.is_operator_supported(op)


def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
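    # Builds a STRIDED_SLICE operator with constant begin/end/strides inputs and all masks cleared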
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, DataType.uint8, "in")
    in0.quantization = qp
    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
    out = Tensor(out_shape, DataType.uint8, "out")
    out.quantization = qp
    attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
    return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)


def create_strided_slice():
    # Creates a valid strided slice operator with some valid inputs/outputs
    op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
    op.attrs["begin_mask"] = 1
    op.attrs["end_mask"] = 9
    assert support.is_operator_supported(op)
    return op


def test_constraint_stridedslice_stride_values():
    # Unsupported strides
    op = create_strided_slice()
    op.inputs[3].values = [1, 1, 2, 1]
    assert not support.is_operator_supported(op)


def test_constraint_inputs_int32():
    # Both inputs must be of type int32
    op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op.ifm2.dtype = DataType.int16
    assert not support.is_operator_supported(op)


def test_constraint_output_int32():
    # Output must be of type int32
    op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    op.ofm.dtype = DataType.int16
    assert not support.is_operator_supported(op)


def test_constraint_matching_quantization_parameters():
    qp = QuantizationParameters()
    qp.scale_f32 = np.float32(1.5)
    qp.zero_point = 128
    # valid - all matching (uses default quant params)
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    assert support.is_operator_supported(op)
    # invalid - ifm mismatches ofm
    op.ifm.quantization = qp
    assert not support.is_operator_supported(op)
    # invalid - ifm2 mismatches ofm
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm2.quantization = qp
    assert not support.is_operator_supported(op)
    # invalid - both ifm and ifm2 mismatch ofm
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm.quantization = qp
    op.ifm2.quantization = qp
    assert not support.is_operator_supported(op)
    # valid - all matching
    op.ofm.quantization = qp
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8])
    assert support.is_operator_supported(op)


def test_constraint_elemwise_batch_size():
    # BINARY CASE
    # Batch can be >1 if the tensor is at most 2D
    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2])
    assert support.is_operator_supported(op)
    # For tensors above 2D, batch must be 1
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2])
    assert support.is_operator_supported(op)
    # invalid case
    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
    assert not support.is_operator_supported(op)

    # UNARY CASE
    # Batch can be >1 if the tensor is at most 2D
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    # For tensors above 2D, batch must be 1
    op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32)
    assert support.is_operator_supported(op)
    # invalid case
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
    assert not support.is_operator_supported(op)


def test_constraint_broadcast_shapes():
    # BINARY CASE
    # Broadcast is only allowed from a dim of 1; here at a single rank index
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4], [1, 2, 4], [1, 2, 4])
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 1, 4], [1, 2, 4])
    assert support.is_operator_supported(op)
    # Broadcast from a dim of 1 at three rank indexes
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 1, 1], [1, 4, 8, 16], [1, 4, 8, 16])
    assert support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 8, 16], [1, 1, 1, 1], [1, 4, 8, 16])
    assert support.is_operator_supported(op)
    # One broadcast dim is not 1
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 4], [1, 4, 4], [1, 4, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 4], [1, 2, 4], [1, 4, 4])
    assert not support.is_operator_supported(op)
    # OFM shape dim must match the largest ifm/ifm2 shape dim
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [1, 4])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 1, 16])
    assert not support.is_operator_supported(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 1, 16])
    assert not support.is_operator_supported(op)


def create_mean(input_shape, output_shape, axis, datatype, attrs):
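    # Builds a MEAN operator with a constant axis/indices tensor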
    ifm = Tensor(input_shape, datatype, "in")
    ifm.quantization = testutil.default_quant_params()
    ofm = Tensor(output_shape, datatype, "out")
    ofm.quantization = testutil.default_quant_params()
    if type(axis) is list:
        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
    elif type(axis) is int:
        indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
    op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
    return op


def test_mean_hw_product():
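    # Product of IFM height and width must be within the supported limit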
    op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
    assert support.is_operator_supported(op)
    op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)


def test_mean_hw_product_int8():
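    # For int8 with keep_dims, the height x width product limit is tighter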
    op = create_mean([1, 16, 16, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert support.is_operator_supported(op)
    op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)


def test_mean_hw_product_avgpool():
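    # A larger height x width product is allowed when the mean can be lowered to an average pool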
    op = create_mean([1, 200, 200, 16], [1, 16], [1, 2], DataType.uint8, {"keep_dims": False})
    assert support.is_operator_supported(op)
    op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert not support.is_operator_supported(op)