blob: 7a01b3994c42cd843e6d642ecf15c7716acc4b7d [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
patrik.gustavssoneeb85152020-12-21 17:10:40 +00002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
# Description:
# Defines the classes Shape4D and VolumeIterator.
Tim Hall73e843f2021-02-04 22:47:46 +000019from collections import namedtuple
Tim Halld8339a72021-05-27 18:49:40 +010020from enum import Enum
Tim Hall73e843f2021-02-04 22:47:46 +000021
patrik.gustavssoneeb85152020-12-21 17:10:40 +000022from .numeric_util import full_shape
Tim Halld8339a72021-05-27 18:49:40 +010023from .numeric_util import round_up
Tim Hall73e843f2021-02-04 22:47:46 +000024from .numeric_util import round_up_divide
patrik.gustavssoneeb85152020-12-21 17:10:40 +000025
26
class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
    """
    4D Shape (in NHWC format)

    An immutable namedtuple with the fields batch, height, width and depth. The operators +, -, //, / and %
    are overridden to act element-wise on another Shape4D instead of performing tuple concatenation etc.
    """

    def __new__(cls, n=1, h=1, w=1, c=1):
        """
        Create a shape from up to four dimensions, or from a list passed as n.

        When n is a list, h, w and c must keep their defaults and the list is expanded to four dimensions
        with full_shape(4, n, 1). Any other n is stored as the batch dimension as given.
        """
        assert n is not None
        if isinstance(n, list):
            assert h == 1 and w == 1 and c == 1
            tmp = full_shape(4, n, 1)
            return super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
        return super(Shape4D, cls).__new__(cls, n, h, w, c)

    @classmethod
    def from_list(cls, shape, base=1):
        """Create a shape from a list, expanded to four dimensions with full_shape(4, shape, base)."""
        tmp = full_shape(4, shape, base)
        return cls(tmp[0], tmp[1], tmp[2], tmp[3])

    # Note: inside the bodies of min, max and round_up below, those names resolve to the builtins and to the
    # module-level round_up imported from numeric_util, not to these classmethods (class scope is not visible
    # from method bodies).
    @classmethod
    def min(cls, lhs, rhs):
        """Return the element-wise minimum of two shapes."""
        return cls(
            min(lhs.batch, rhs.batch), min(lhs.height, rhs.height), min(lhs.width, rhs.width), min(lhs.depth, rhs.depth)
        )

    @classmethod
    def max(cls, lhs, rhs):
        """Return the element-wise maximum of two shapes."""
        return cls(
            max(lhs.batch, rhs.batch), max(lhs.height, rhs.height), max(lhs.width, rhs.width), max(lhs.depth, rhs.depth)
        )

    @classmethod
    def round_up(cls, lhs, rhs):
        """Apply numeric_util.round_up element-wise (presumably each lhs dim rounded up to a multiple of rhs)."""
        return cls(
            round_up(lhs.batch, rhs.batch),
            round_up(lhs.height, rhs.height),
            round_up(lhs.width, rhs.width),
            round_up(lhs.depth, rhs.depth),
        )

    @classmethod
    def from_hwc(cls, h, w, c):
        """Create a shape with batch 1 from height, width and depth."""
        return cls(1, h, w, c)

    def with_batch(self, new_batch):
        """Return a copy with the batch dimension replaced."""
        return Shape4D(new_batch, self.height, self.width, self.depth)

    def with_height(self, new_height):
        """Return a copy with the height dimension replaced."""
        return Shape4D(self.batch, new_height, self.width, self.depth)

    def with_width(self, new_width):
        """Return a copy with the width dimension replaced."""
        return Shape4D(self.batch, self.height, new_width, self.depth)

    def with_hw(self, new_height, new_width):
        """Return a copy with the height and width dimensions replaced."""
        return Shape4D(self.batch, new_height, new_width, self.depth)

    def with_depth(self, new_depth):
        """Return a copy with the depth dimension replaced."""
        return Shape4D(self.batch, self.height, self.width, new_depth)

    def with_axis(self, axis, new_val):
        """Return a copy with the dimension at list index axis (0=batch, 1=height, 2=width, 3=depth) replaced."""
        shape_as_list = self.as_list()
        shape_as_list[axis] = new_val
        return Shape4D.from_list(shape_as_list)

    @staticmethod
    def _clip_len(pos, length, size):
        """Return how much of the span [pos, pos + length) lies inside [0, size); never negative."""
        if pos < 0:
            # Drop the part of the span that lies before the start of the axis.
            length = length + pos
            pos = 0
        # Clamp at 0: a span that lies completely outside the axis has no overlap, not a negative one.
        return max(0, min(pos + length, size) - pos)

    def clip(self, offset, sub_shape):
        """
        Return the part of sub_shape, placed at offset, that lies inside this shape.

        Each dimension is clipped independently; a dimension whose span lies completely outside gives 0.
        """
        n = Shape4D._clip_len(offset.batch, sub_shape.batch, self.batch)
        h = Shape4D._clip_len(offset.height, sub_shape.height, self.height)
        w = Shape4D._clip_len(offset.width, sub_shape.width, self.width)
        c = Shape4D._clip_len(offset.depth, sub_shape.depth, self.depth)
        return Shape4D(n, h, w, c)

    def add(self, n, h, w, c):
        """Return a copy with n, h, w and c added to the respective dimensions."""
        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)

    def __add__(self, rhs):
        """Element-wise addition (replaces tuple concatenation)."""
        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)

    def __sub__(self, rhs):
        """Element-wise subtraction."""
        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)

    def floordiv_const(self, const):
        """Floor-divide every dimension by the same scalar."""
        return Shape4D(self.batch // const, self.height // const, self.width // const, self.depth // const)

    def __floordiv__(self, rhs):
        """Element-wise floor division."""
        return Shape4D(
            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
        )

    def __truediv__(self, rhs):
        """Element-wise true division; the resulting dimensions are floats."""
        return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth)

    def __mod__(self, rhs):
        """Element-wise modulo."""
        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)

    def __str__(self):
        return f"<Shape4D {list(self)}>"

    def div_round_up(self, rhs):
        """Element-wise division rounded up, via numeric_util.round_up_divide."""
        return Shape4D(
            round_up_divide(self.batch, rhs.batch),
            round_up_divide(self.height, rhs.height),
            round_up_divide(self.width, rhs.width),
            round_up_divide(self.depth, rhs.depth),
        )

    def elements(self):
        """Return the total number of elements (product of all four dimensions)."""
        return self.batch * self.width * self.height * self.depth

    def dot_prod(self, rhs):
        """Return the dot product of the two shapes viewed as 4-vectors."""
        return self.batch * rhs.batch + self.width * rhs.width + self.height * rhs.height + self.depth * rhs.depth

    def elements_wh(self):
        """Return width * height."""
        return self.width * self.height

    def is_empty(self):
        """
        Return True when the four dimensions sum to zero, e.g. Shape4D(0, 0, 0, 0).

        A shape with only some zero dimensions is NOT reported as empty; use elements() == 0 for that.
        """
        return (self.batch + self.width + self.height + self.depth) == 0

    def as_list(self):
        """Return the dimensions as a new list [batch, height, width, depth]."""
        return list(self)

    def get_hw_as_list(self):
        """Return [height, width] as a new list."""
        return [self.height, self.width]
Tim Halld8339a72021-05-27 18:49:40 +0100157
158
class VolumeIterator:
    """
    4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes.

    Yields (offset, clipped_sub_shape) pairs: offset is the current Shape4D position and clipped_sub_shape is
    shape.clip(offset, sub_shape). Positions advance in CWHN order: depth fastest, then width, then height,
    then batch. The iterator is single-use; once exhausted it keeps raising StopIteration.
    """

    class Direction(Enum):
        # Traversal order; only CWHN is defined.
        CWHN = 0

    def __init__(
        self,
        shape: Shape4D,
        sub_shape: Shape4D,
        start: Shape4D = Shape4D(0, 0, 0, 0),
        delta: Shape4D = None,
        dir=Direction.CWHN,
    ):
        """
        :param shape: the full volume to traverse
        :param sub_shape: size of each yielded sub-volume (clipped to the volume)
        :param start: first position; when a dimension wraps it restarts from 0, not from start
        :param delta: step between positions in each dimension; defaults to sub_shape
        :param dir: traversal order; currently not read (CWHN is always used)
        """
        self.b = start.batch
        self.y = start.height
        self.x = start.width
        self.z = start.depth
        self.shape = shape
        self.sub_shape = sub_shape
        self.delta = sub_shape if delta is None else delta
        # Every step must be strictly positive: with a zero or negative step a dimension may never reach its
        # limit, so iteration would not terminate. Checking only delta.elements() > 0 would still accept a
        # step such as (1, -1, -1, 1).
        assert all(step > 0 for step in self.delta), "Iterator will not move"

    def __iter__(self):
        return self

    def __next__(self):
        # Batch is the outermost dimension, so passing its end means the whole volume was visited.
        if self.b >= self.shape.batch:
            raise StopIteration()

        offset = Shape4D(self.b, self.y, self.x, self.z)

        # Advance in CWHN order, carrying into the next-outer dimension on wrap-around.
        self.z += self.delta.depth
        if self.z >= self.shape.depth:
            self.z = 0
            self.x += self.delta.width
            if self.x >= self.shape.width:
                self.x = 0
                self.y += self.delta.height
                if self.y >= self.shape.height:
                    self.y = 0
                    self.b += self.delta.batch

        return offset, self.shape.clip(offset, self.sub_shape)
205 return offset, self.shape.clip(offset, self.sub_shape)