blob: a1b4feaa9aacf2f8a0ceaad3e32854ac13797908 [file] [log] [blame]
patrik.gustavssoneeb85152020-12-21 17:10:40 +00001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Defines the class Shape4D.
18from .numeric_util import full_shape
19
20
class Shape4D:
    """
    4D Shape (in NHWC format)

    The four dimension sizes are held in an immutable tuple and exposed through
    the ``batch`` (N), ``height`` (H), ``width`` (W) and ``depth`` (C)
    properties. Assigning to one of those properties replaces the stored tuple,
    so instances are mutable (and therefore deliberately unhashable).
    """

    def __init__(self, shape, base=1):
        """
        Create a 4D shape from ``shape``.

        :param shape: sequence of at most 4 dimension sizes; must not be None.
        :param base: forwarded to ``full_shape``; presumably the fill value used
            for the dimensions ``shape`` does not provide (verify against
            numeric_util.full_shape).
        """
        assert shape is not None
        assert len(shape) <= 4
        self._shape4D = tuple(full_shape(4, shape, base))

    def __str__(self):
        return f"<Shape4D {self.as_list()}>"

    def __repr__(self):
        return f"Shape4D({self.as_list()})"

    def __eq__(self, other):
        """
        Compare two shapes dimension by dimension.

        Operands that are not Shape4D yield NotImplemented, so Python falls back
        to its default comparison (False for ``==``) instead of raising an
        AttributeError on ``other._shape4D``.
        """
        if not isinstance(other, Shape4D):
            return NotImplemented
        return self._shape4D == other._shape4D

    # Defining __eq__ already disables hashing; keep it explicit because the
    # dimension setters make instances mutable.
    __hash__ = None

    def clone(self):
        """Return a new Shape4D with the same four dimension sizes."""
        return Shape4D(self.as_list())

    def _replace_dim(self, index, value):
        """Rebuild the stored tuple with position ``index`` set to ``value``."""
        dims = list(self._shape4D)
        dims[index] = value
        self._shape4D = tuple(dims)

    @property
    def batch(self):
        """Size of the N (batch) dimension, index 0."""
        return self._shape4D[0]

    @batch.setter
    def batch(self, new_batch):
        self._replace_dim(0, new_batch)

    @property
    def height(self):
        """Size of the H (height) dimension, index 1."""
        return self._shape4D[1]

    @height.setter
    def height(self, new_height):
        self._replace_dim(1, new_height)

    @property
    def width(self):
        """Size of the W (width) dimension, index 2."""
        return self._shape4D[2]

    @width.setter
    def width(self, new_width):
        self._replace_dim(2, new_width)

    @property
    def depth(self):
        """Size of the C (depth/channel) dimension, index 3."""
        return self._shape4D[3]

    @depth.setter
    def depth(self, new_depth):
        self._replace_dim(3, new_depth)

    def get_dim(self, dim):
        """
        Return the size of dimension ``dim``.

        :param dim: index in [-4, 4); 0=N, 1=H, 2=W, 3=C, and negative values
            count from the end as with normal sequence indexing.
        """
        assert -4 <= dim < 4
        return self._shape4D[dim]

    def as_list(self):
        """Return the four dimension sizes as a new list in NHWC order."""
        return list(self._shape4D)