blob: 5d849d98eea19bce39374d4a41e4f2b6d6356d3b [file] [log] [blame]
erik.andersson@arm.com460c6892021-02-24 14:38:09 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
patrik.gustavssoneeb85152020-12-21 17:10:40 +00002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Defines the class Shape4D.
Tim Hall73e843f2021-02-04 22:47:46 +000018from collections import namedtuple
19
patrik.gustavssoneeb85152020-12-21 17:10:40 +000020from .numeric_util import full_shape
Tim Hall73e843f2021-02-04 22:47:46 +000021from .numeric_util import round_up_divide
patrik.gustavssoneeb85152020-12-21 17:10:40 +000022
23
class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
    """
    4D Shape (in NHWC format)

    Immutable named tuple with the fields ``batch`` (N), ``height`` (H),
    ``width`` (W) and ``depth`` (C). The arithmetic helpers and dunder
    operators work element-wise against another Shape4D and return a new
    Shape4D; they never modify ``self``.
    """

    def __new__(cls, n=1, h=1, w=1, c=1):
        """
        Create a shape from four explicit dimensions or from a sequence.

        :param n: batch size, or a list/tuple (e.g. another Shape4D) of dimensions.
            A sequence is expanded to four dimensions by ``full_shape(4, ..., 1)``.
            This presumably prepends 1s, so ``[H, W, C]`` becomes ``[1, H, W, C]``.
            TODO confirm against numeric_util.full_shape.
        :param h: height; must stay at its default when ``n`` is a sequence
        :param w: width; must stay at its default when ``n`` is a sequence
        :param c: depth; must stay at its default when ``n`` is a sequence
        """
        assert n is not None
        if isinstance(n, (list, tuple)):
            # The sequence form carries every dimension itself; mixing in the
            # scalar arguments would be ambiguous.
            assert h == 1 and w == 1 and c == 1
            # Convert to a list so tuples (including an existing Shape4D) get the
            # same treatment as lists. Otherwise the tuple would be stored as the
            # batch field.
            tmp = full_shape(4, list(n), 1)
            # NOTE(review): only the first four entries are used; a sequence with
            # more than four dims is assumed not to occur - TODO confirm.
            return super().__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
        return super().__new__(cls, n, h, w, c)

    @classmethod
    def from_list(cls, shape, base=1):
        """
        Create a shape from a sequence of dimensions.

        :param shape: list/tuple of dimensions, expanded to four via
            ``full_shape(4, shape, base)`` (presumably padding with ``base``;
            TODO confirm)
        :param base: fill value passed to ``full_shape`` for the missing dimensions
        """
        tmp = full_shape(4, list(shape), base)
        return cls(tmp[0], tmp[1], tmp[2], tmp[3])

    @classmethod
    def from_hwc(cls, h, w, c):
        """Create a shape with batch 1 and the given height, width and depth."""
        return cls(1, h, w, c)

    def with_batch(self, new_batch):
        """Return a copy of this shape with ``batch`` replaced."""
        return self._replace(batch=new_batch)

    def with_height(self, new_height):
        """Return a copy of this shape with ``height`` replaced."""
        return self._replace(height=new_height)

    def with_width(self, new_width):
        """Return a copy of this shape with ``width`` replaced."""
        return self._replace(width=new_width)

    def with_hw(self, new_height, new_width):
        """Return a copy of this shape with ``height`` and ``width`` replaced."""
        return self._replace(height=new_height, width=new_width)

    def with_depth(self, new_depth):
        """Return a copy of this shape with ``depth`` replaced."""
        return self._replace(depth=new_depth)

    def add(self, n, h, w, c):
        """Return a new shape with the given per-dimension deltas added."""
        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)

    def __add__(self, rhs):
        """Element-wise addition with another Shape4D."""
        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)

    def __sub__(self, rhs):
        """Element-wise subtraction of another Shape4D (result may contain negatives)."""
        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)

    def __floordiv__(self, rhs):
        """Element-wise floor division by another Shape4D."""
        return Shape4D(
            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
        )

    def __mod__(self, rhs):
        """Element-wise modulo by another Shape4D."""
        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)

    def __str__(self):
        return f"<Shape4D {list(self)}>"

    def div_round_up(self, rhs):
        """Element-wise division by another Shape4D, each dimension rounded up via round_up_divide."""
        return Shape4D(
            round_up_divide(self.batch, rhs.batch),
            round_up_divide(self.height, rhs.height),
            round_up_divide(self.width, rhs.width),
            round_up_divide(self.depth, rhs.depth),
        )

    def elements(self):
        """Return the total number of elements (N * H * W * C)."""
        return self.batch * self.height * self.width * self.depth

    def elements_wh(self):
        """Return the number of elements in one H x W plane."""
        return self.width * self.height

    def is_empty(self):
        """
        Return True if every dimension is zero.

        Each dimension is tested individually rather than summed. Shapes produced
        by ``__sub__`` may contain negative components that would cancel out in
        a sum.
        """
        return all(dim == 0 for dim in self)

    def as_list(self):
        """Return the dimensions as a list in NHWC order."""
        return list(self)

    def get_hw_as_list(self):
        """Return ``[height, width]`` as a list."""
        return [self.height, self.width]