blob: fd1ee94989b7fc200a6e5c1f56df63daa99b3e0e [file] [log] [blame]
# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Defines the class Shape4D and the VolumeIterator used to traverse 4D volumes in smaller shapes.
Tim Hall73e843f2021-02-04 22:47:46 +000018from collections import namedtuple
Tim Halld8339a72021-05-27 18:49:40 +010019from enum import Enum
Tim Hall73e843f2021-02-04 22:47:46 +000020
patrik.gustavssoneeb85152020-12-21 17:10:40 +000021from .numeric_util import full_shape
Tim Halld8339a72021-05-27 18:49:40 +010022from .numeric_util import round_up
Tim Hall73e843f2021-02-04 22:47:46 +000023from .numeric_util import round_up_divide
patrik.gustavssoneeb85152020-12-21 17:10:40 +000024
25
class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
    """
    4D Shape (in NHWC format)

    An immutable namedtuple with the fields batch, height, width and depth.
    The helpers and operators below work field by field and return a new Shape4D.
    """

    def __new__(cls, n=1, h=1, w=1, c=1):
        # A list passed as the first argument is expanded to four dimensions with
        # full_shape(4, n, 1); the remaining arguments must then keep their defaults.
        assert n is not None
        if isinstance(n, list):
            assert h == 1 and w == 1 and c == 1
            tmp = full_shape(4, n, 1)
            return super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
        return super(Shape4D, cls).__new__(cls, n, h, w, c)

    @classmethod
    def from_list(cls, shape, base=1):
        """
        Build a shape from the list `shape`, expanded to four values by full_shape(4, shape, base).

        NOTE(review): `base` is presumably the fill value for missing dimensions; see numeric_util.full_shape.
        """
        tmp = full_shape(4, shape, base)
        return cls(tmp[0], tmp[1], tmp[2], tmp[3])

    @classmethod
    def min(cls, lhs, rhs):
        """Return the field-wise minimum of two shapes."""
        return cls(
            min(lhs.batch, rhs.batch), min(lhs.height, rhs.height), min(lhs.width, rhs.width), min(lhs.depth, rhs.depth)
        )

    @classmethod
    def max(cls, lhs, rhs):
        """Return the field-wise maximum of two shapes."""
        return cls(
            max(lhs.batch, rhs.batch), max(lhs.height, rhs.height), max(lhs.width, rhs.width), max(lhs.depth, rhs.depth)
        )

    @classmethod
    def round_up(cls, lhs, rhs):
        """Apply numeric_util.round_up field by field, i.e. round_up(lhs.<field>, rhs.<field>)."""
        # Inside the method body `round_up` resolves to the module-level helper, not this classmethod.
        return cls(
            round_up(lhs.batch, rhs.batch),
            round_up(lhs.height, rhs.height),
            round_up(lhs.width, rhs.width),
            round_up(lhs.depth, rhs.depth),
        )

    @classmethod
    def from_hwc(cls, h, w, c):
        """Build a shape with batch 1 from height, width and depth."""
        return cls(1, h, w, c)

    def with_batch(self, new_batch):
        """Return a copy with the batch field replaced."""
        return Shape4D(new_batch, self.height, self.width, self.depth)

    def with_height(self, new_height):
        """Return a copy with the height field replaced."""
        return Shape4D(self.batch, new_height, self.width, self.depth)

    def with_width(self, new_width):
        """Return a copy with the width field replaced."""
        return Shape4D(self.batch, self.height, new_width, self.depth)

    def with_hw(self, new_height, new_width):
        """Return a copy with the height and width fields replaced."""
        return Shape4D(self.batch, new_height, new_width, self.depth)

    def with_depth(self, new_depth):
        """Return a copy with the depth field replaced."""
        return Shape4D(self.batch, self.height, self.width, new_depth)

    def with_axis(self, axis, new_val):
        """Return a copy with the field at index `axis` (0=batch, 1=height, 2=width, 3=depth) replaced."""
        shape_as_list = self.as_list()
        shape_as_list[axis] = new_val
        return Shape4D.from_list(shape_as_list)

    @staticmethod
    def _clip_len(pos, length, size):
        """
        Return the length of the part of the interval [pos, pos + length) that lies inside [0, size).

        The result is never negative: an interval entirely outside [0, size) has length 0.
        """
        if pos < 0:
            # Cut off the part of the interval that lies before index 0.
            length = length + pos
            pos = 0
        # Without the clamp a fully-outside interval yields a negative length, and two such
        # axes would multiply into a misleading positive element count.
        return max(0, min(pos + length, size) - pos)

    def clip(self, offset, sub_shape):
        """Return the part of `sub_shape`, placed at `offset`, that fits inside this shape (per field)."""
        n = Shape4D._clip_len(offset.batch, sub_shape.batch, self.batch)
        h = Shape4D._clip_len(offset.height, sub_shape.height, self.height)
        w = Shape4D._clip_len(offset.width, sub_shape.width, self.width)
        c = Shape4D._clip_len(offset.depth, sub_shape.depth, self.depth)
        return Shape4D(n, h, w, c)

    def add(self, n, h, w, c):
        """Return a copy with the given values added to batch, height, width and depth."""
        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)

    def __add__(self, rhs):
        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)

    def __sub__(self, rhs):
        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)

    def floordiv_const(self, const):
        """Floor-divide every field by the same scalar `const`."""
        return Shape4D(self.batch // const, self.height // const, self.width // const, self.depth // const)

    def __floordiv__(self, rhs):
        return Shape4D(
            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
        )

    def __truediv__(self, rhs):
        # True division: the resulting fields may be floats.
        return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth)

    def __mod__(self, rhs):
        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)

    def __str__(self):
        return f"<Shape4D {list(self)}>"

    def div_round_up(self, rhs):
        """Apply numeric_util.round_up_divide field by field, i.e. round_up_divide(self.<field>, rhs.<field>)."""
        return Shape4D(
            round_up_divide(self.batch, rhs.batch),
            round_up_divide(self.height, rhs.height),
            round_up_divide(self.width, rhs.width),
            round_up_divide(self.depth, rhs.depth),
        )

    def elements(self):
        """Return the product of all four fields."""
        return self.batch * self.width * self.height * self.depth

    def dot_prod(self, rhs):
        """Return the sum of the field-wise products of this shape and `rhs`."""
        return self.batch * rhs.batch + self.width * rhs.width + self.height * rhs.height + self.depth * rhs.depth

    def elements_wh(self):
        """Return width * height."""
        return self.width * self.height

    def is_empty(self):
        """
        Return True when the four fields sum to zero.

        NOTE(review): for non-negative shapes this means "all fields are zero", not "any field is zero";
        use elements() == 0 if the latter is what a caller needs.
        """
        return (self.batch + self.width + self.height + self.depth) == 0

    def as_list(self):
        """Return the fields as a list in [batch, height, width, depth] order."""
        return list(self)

    def get_hw_as_list(self):
        """Return [height, width]."""
        return [self.height, self.width]
Tim Halld8339a72021-05-27 18:49:40 +0100156
157
class VolumeIterator:
    """
    4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes.

    Each step yields (offset, clipped_sub_shape): the current position as a Shape4D, and
    `sub_shape` clipped to the bounds of `shape` at that position. The position advances
    by `delta` with depth innermost, then width, then height, then batch (CWHN order).
    """

    class Direction(Enum):
        CWHN = 0

    def __init__(
        self,
        shape: Shape4D,
        sub_shape: Shape4D,
        start: Shape4D = Shape4D(0, 0, 0, 0),
        delta: Shape4D = None,
        dir=Direction.CWHN,
    ):
        """
        shape: the full volume to traverse.
        sub_shape: the block shape yielded at each position (clipped to `shape`).
        start: initial position only; once an axis wraps it restarts from 0, not from `start`.
        delta: step per axis; defaults to `sub_shape` when None.
        dir: accepted but not read; only CWHN traversal is implemented.

        Raises AssertionError("Iterator will not move") if any delta component is not strictly positive.
        """
        self.b = start.batch
        self.y = start.height
        self.x = start.width
        self.z = start.depth
        self.shape = shape
        self.sub_shape = sub_shape
        self.delta = sub_shape if delta is None else delta
        # Every axis must advance: a zero or negative step means that axis never wraps,
        # so the outer axes never advance and the iterator never terminates.
        # (A product check would accept e.g. two negative components.)
        assert all(step > 0 for step in self.delta), "Iterator will not move"

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next (offset, clipped_sub_shape) pair; stop once the batch position reaches shape.batch."""
        if self.b >= self.shape.batch:
            raise StopIteration()

        offset = Shape4D(self.b, self.y, self.x, self.z)

        # CWHN: step depth; on wrap step width; on wrap step height; on wrap step batch.
        self.z += self.delta.depth
        if self.z >= self.shape.depth:
            self.z = 0
            self.x += self.delta.width
            if self.x >= self.shape.width:
                self.x = 0
                self.y += self.delta.height
                if self.y >= self.shape.height:
                    self.y = 0
                    self.b += self.delta.batch

        return offset, self.shape.clip(offset, self.sub_shape)
204 return offset, self.shape.clip(offset, self.sub_shape)