erik.andersson@arm.com | 460c689 | 2021-02-24 14:38:09 +0100 | [diff] [blame] | 1 | # Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. |
patrik.gustavsson | eeb8515 | 2020-12-21 17:10:40 +0000 | [diff] [blame] | 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | # Description: |
| 17 | # Defines the classes Shape4D and VolumeIterator. |
Tim Hall | 73e843f | 2021-02-04 22:47:46 +0000 | [diff] [blame] | 18 | from collections import namedtuple |
Tim Hall | d8339a7 | 2021-05-27 18:49:40 +0100 | [diff] [blame] | 19 | from enum import Enum |
Tim Hall | 73e843f | 2021-02-04 22:47:46 +0000 | [diff] [blame] | 20 | |
patrik.gustavsson | eeb8515 | 2020-12-21 17:10:40 +0000 | [diff] [blame] | 21 | from .numeric_util import full_shape |
Tim Hall | d8339a7 | 2021-05-27 18:49:40 +0100 | [diff] [blame] | 22 | from .numeric_util import round_up |
Tim Hall | 73e843f | 2021-02-04 22:47:46 +0000 | [diff] [blame] | 23 | from .numeric_util import round_up_divide |
patrik.gustavsson | eeb8515 | 2020-12-21 17:10:40 +0000 | [diff] [blame] | 24 | |
| 25 | |
class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
    """
    4D Shape (in NHWC format)

    An immutable tuple of (batch, height, width, depth). The binary operators
    (+, -, //, /, %) work element-wise on two Shape4D operands and return a new Shape4D.
    """

    def __new__(cls, n=1, h=1, w=1, c=1):
        """
        Create a shape from four scalars, or from a single list of dimensions.

        When n is a list it is expanded to four dimensions by numeric_util.full_shape()
        with fill value 1 (presumably padding the leading dims - see numeric_util), and
        h, w and c must then be left at their defaults.
        """
        assert n is not None
        if isinstance(n, list):
            assert h == 1 and w == 1 and c == 1
            tmp = full_shape(4, n, 1)
            self = super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
        else:
            self = super(Shape4D, cls).__new__(cls, n, h, w, c)
        return self

    @classmethod
    def from_list(cls, shape, base=1):
        """Create a shape from a list, expanding it to 4 dims with full_shape() using fill value `base`."""
        tmp = full_shape(4, shape, base)
        return cls(tmp[0], tmp[1], tmp[2], tmp[3])

    @classmethod
    def min(cls, lhs, rhs):
        """Return the element-wise minimum of two shapes."""
        # `min` below resolves to the builtin: class scope is not visible inside methods.
        return cls(
            min(lhs.batch, rhs.batch), min(lhs.height, rhs.height), min(lhs.width, rhs.width), min(lhs.depth, rhs.depth)
        )

    @classmethod
    def max(cls, lhs, rhs):
        """Return the element-wise maximum of two shapes."""
        return cls(
            max(lhs.batch, rhs.batch), max(lhs.height, rhs.height), max(lhs.width, rhs.width), max(lhs.depth, rhs.depth)
        )

    @classmethod
    def round_up(cls, lhs, rhs):
        """
        Apply numeric_util.round_up element-wise to (lhs, rhs).

        Presumably this rounds each lhs dim up to a multiple of the matching rhs dim (see numeric_util).
        """
        return cls(
            round_up(lhs.batch, rhs.batch),
            round_up(lhs.height, rhs.height),
            round_up(lhs.width, rhs.width),
            round_up(lhs.depth, rhs.depth),
        )

    @classmethod
    def from_hwc(cls, h, w, c):
        """Create a shape with batch 1 from height, width and depth."""
        return cls(1, h, w, c)

    def with_batch(self, new_batch):
        """Return a copy with the batch dimension replaced."""
        return Shape4D(new_batch, self.height, self.width, self.depth)

    def with_height(self, new_height):
        """Return a copy with the height dimension replaced."""
        return Shape4D(self.batch, new_height, self.width, self.depth)

    def with_width(self, new_width):
        """Return a copy with the width dimension replaced."""
        return Shape4D(self.batch, self.height, new_width, self.depth)

    def with_hw(self, new_height, new_width):
        """Return a copy with both height and width replaced."""
        return Shape4D(self.batch, new_height, new_width, self.depth)

    def with_depth(self, new_depth):
        """Return a copy with the depth dimension replaced."""
        return Shape4D(self.batch, self.height, self.width, new_depth)

    def with_axis(self, axis, new_val):
        """Return a copy with dimension `axis` (an index into [batch, height, width, depth]) replaced."""
        shape_as_list = self.as_list()
        shape_as_list[axis] = new_val
        return Shape4D.from_list(shape_as_list)

    @staticmethod
    def _clip_len(pos, length, size):
        """
        Return how much of the 1D span [pos, pos + length) lies inside [0, size).

        The result is never negative: a span entirely outside the range yields 0.
        """
        if pos < 0:
            # Drop the part of the span that lies before the start of the range.
            length = length + pos
            pos = 0
        # Clamp at zero so out-of-range spans do not produce negative extents.
        return max(0, min(pos + length, size) - pos)

    def clip(self, offset, sub_shape):
        """Return the extent of `sub_shape`, placed at `offset`, that lies inside this shape (per axis, >= 0)."""
        n = Shape4D._clip_len(offset.batch, sub_shape.batch, self.batch)
        h = Shape4D._clip_len(offset.height, sub_shape.height, self.height)
        w = Shape4D._clip_len(offset.width, sub_shape.width, self.width)
        c = Shape4D._clip_len(offset.depth, sub_shape.depth, self.depth)
        return Shape4D(n, h, w, c)

    def add(self, n, h, w, c):
        """Return a new shape with the given scalar amounts added to each dimension."""
        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)

    def __add__(self, rhs):
        """Element-wise addition (overrides tuple concatenation)."""
        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)

    def __sub__(self, rhs):
        """Element-wise subtraction; the result may contain negative dimensions."""
        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)

    def __floordiv__(self, rhs):
        """Element-wise floor division."""
        return Shape4D(
            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
        )

    def __truediv__(self, rhs):
        """Element-wise true division (dimensions become floats)."""
        return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth)

    def __mod__(self, rhs):
        """Element-wise modulo."""
        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)

    def __str__(self):
        return f"<Shape4D {list(self)}>"

    def div_round_up(self, rhs):
        """Apply numeric_util.round_up_divide element-wise (presumably a ceiling division - see numeric_util)."""
        return Shape4D(
            round_up_divide(self.batch, rhs.batch),
            round_up_divide(self.height, rhs.height),
            round_up_divide(self.width, rhs.width),
            round_up_divide(self.depth, rhs.depth),
        )

    def elements(self):
        """Return the total number of elements (product of all four dimensions)."""
        return self.batch * self.width * self.height * self.depth

    def elements_wh(self):
        """Return width * height."""
        return self.width * self.height

    def is_empty(self):
        """
        Return True if every dimension is zero.

        This differs from elements() == 0: e.g. (1, 0, 4, 4) holds no elements but is
        not considered empty here. Mixed-sign shapes such as (1, -1, 0, 0) are not empty.
        """
        return all(dim == 0 for dim in self)

    def as_list(self):
        """Return the dimensions as a list [batch, height, width, depth]."""
        return list(self)

    def get_hw_as_list(self):
        """Return [height, width]."""
        return [self.height, self.width]
Tim Hall | d8339a7 | 2021-05-27 18:49:40 +0100 | [diff] [blame] | 150 | |
| 151 | |
class VolumeIterator:
    """
    4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes.

    Each step yields a tuple (offset, sub_volume). offset is the Shape4D position of
    the current sub-volume within `shape`. sub_volume is `sub_shape` clipped with
    Shape4D.clip so that it does not extend past `shape`.
    """

    class Direction(Enum):
        # Traversal order, named from the fastest- to the slowest-varying axis.
        CWHN = 0

    def __init__(
        self,
        shape: Shape4D,
        sub_shape: Shape4D,
        start: Shape4D = Shape4D(0, 0, 0, 0),
        delta: Shape4D = None,
        dir=Direction.CWHN,
    ):
        """
        shape: the full volume being traversed.
        sub_shape: size of each yielded sub-volume, before clipping.
        start: position of the first sub-volume. When an axis wraps it restarts at 0,
            not at the corresponding start coordinate.
        delta: step taken along each axis; defaults to sub_shape, so consecutive
            sub-volumes are adjacent.
        dir: traversal order; only CWHN exists and the value is not read.
        """
        # Cursor position, one attribute per axis (b=batch, y=height, x=width, z=depth).
        self.b, self.y, self.x, self.z = start.batch, start.height, start.width, start.depth
        self.shape = shape
        self.sub_shape = sub_shape
        if delta is None:
            delta = sub_shape
        self.delta = delta
        assert self.delta.elements() > 0, "Iterator will not move"

    def __iter__(self):
        return self

    def __next__(self):
        # Iteration ends once the batch cursor has run past the volume.
        if self.b >= self.shape.batch:
            raise StopIteration()

        position = Shape4D(self.b, self.y, self.x, self.z)
        self._advance()
        return position, self.shape.clip(position, self.sub_shape)

    def _advance(self):
        """Step the cursor once in CWHN order, carrying into the next axis on overflow."""
        self.z += self.delta.depth
        if self.z < self.shape.depth:
            return
        self.z = 0

        self.x += self.delta.width
        if self.x < self.shape.width:
            return
        self.x = 0

        self.y += self.delta.height
        if self.y < self.shape.height:
            return
        self.y = 0

        self.b += self.delta.batch