# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Contains SoftMax
import numpy as np

from . import scaling
from .data_type import DataType
from .operation import Operation
from .tensor import Tensor
from .tensor import TensorPurpose


class TensorUtil:
    # TODO: Move these functions to Tensor/Operation classes
    @staticmethod
    def create_const_tensor(
        name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None
    ):
        const_op = Operation("Const", name)
        const_tensor = Tensor(shape, dtype, name + "_0")
        const_tensor.purpose = purpose
        const_tensor.quantization = quantization
        const_tensor.values = np.array(values, dtype=value_dtype)
        const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
        const_tensor.ops.append(const_op)
        const_op.outputs.append(const_tensor)
        return const_tensor

    @staticmethod
    def add_ifm_tensor(op, tens):
        op.inputs.append(tens)
        tens.consumer_list.append(op)

    @staticmethod
    def set_ofm_tensor(op, tens):
        tens.ops = [op]
        op.outputs = [tens]

    @staticmethod
    def reshape(tens, shape, ifm_reshape=True):
        if shape == tens.shape:
            return tens
        name = tens.name + "_reshape"
        reshape_op = Operation("Reshape", name)
        reshape_op.attrs["new_shape"] = shape
        reshape_ifm = tens
        reshape_ofm = tens.clone("_reshaped")
        reshape_ofm.shape = reshape_ofm.storage_shape = reshape_ofm.bandwidth_shape = shape
        if not ifm_reshape:
            reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
        reshape_op.inputs = [reshape_ifm, TensorUtil.create_const_tensor(name + "_shape", [1], DataType.int32, shape)]
        TensorUtil.set_ofm_tensor(reshape_op, reshape_ofm)
        return reshape_ofm if ifm_reshape else reshape_ifm

    @staticmethod
    def get_full_shape(shape):
        d = len(shape)
        if d in (1, 3):
            return [1] * (4 - d) + shape
        elif d == 2:
            return [shape[0], 1, 1, shape[1]]
        else:
            return shape


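# SoftMax rewrites a softmax operation that cannot be run directly on the NPU
# into a sub-graph of elementwise, pooling and LUT operations that can:
#
#     softmax(x_i) = exp(beta * (x_i - max(x))) / sum_j(exp(beta * (x_j - max(x))))
#
# Subtracting the per-position maximum keeps the exponentials in a safe range,
# exp() and the reciprocal of the sum are evaluated through the LUT tables
# below, and a final multiply/shift produces the result. Only int16 softmax is
# rewritten; other data types fall back to the CPU (run_on_npu = False).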
class SoftMax:
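    # Each 32-bit entry in the tables below is assumed to pack two 16-bit
    # halves for the hardware LUT: the function value at the table point in
    # the low half-word and the delta to the next point in the high half-word
    # (e.g. 0xffc17fff in ONE_OVER_ONE_PLUS_X_LUT reads as value 0x7fff with
    # slope -0x3f).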
    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off
    EXP_LUT = [
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
    ]

    ONE_OVER_ONE_PLUS_X_LUT = [
        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
    ]
    # fmt: on

    def __init__(self, op):
        self.op = op

    def get_graph(self):
        ifm = self.op.inputs[0]
        ofm = self.op.outputs[0]

        if ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
            return self.get_graph_int16(ifm, ofm)
        else:
            self.op.run_on_npu = False
            return self.op

    def get_graph_int16(self, ifm, ofm):
        ifm = TensorUtil.reshape(ifm, TensorUtil.get_full_shape(ifm.shape))
        ofm = TensorUtil.reshape(ofm, TensorUtil.get_full_shape(ofm.shape), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None

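        # Passes 0-1 compute x - max(x) for each position, which keeps the
        # exponentials in a numerically safe range. The IFM is reshaped to
        # [1, H * W, C, 1] so that a 1 x C MaxPool yields the per-position
        # maximum over the channel axis.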
        # PASS 0 - Depthwise Maxpool
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [TensorUtil.reshape(ifm, maxpool_ifm_shape)]
        maxpool_ofm = Tensor([1, maxpool_h, 1, 1], DataType.int16, maxpool_op.name + "_0")
        maxpool_ofm.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(maxpool_op, maxpool_ofm)

        # PASS 1 - Sub
        sub1_op = Operation("SubAct", self.op.name + "_sub1")
        TensorUtil.add_ifm_tensor(sub1_op, ifm)
        TensorUtil.add_ifm_tensor(sub1_op, TensorUtil.reshape(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
        sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
        sub1_ofm.quantization = ifm.quantization.clone()
        TensorUtil.set_ofm_tensor(sub1_op, sub1_ofm)

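        # Passes 2-3 scale the differences by beta and evaluate exp() through
        # the EXP_LUT activation table.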
        # PASS 2 - Mul
        beta = self.op.attrs.get("beta", 1.0)
        mul2_out_range = 10.0 / 65535.0
        mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
        mul2_quant = ifm.quantization.clone()
        mul2_quant.scale_f32 = beta
        mul2_op = Operation("MulAct", self.op.name + "_mul2")
        TensorUtil.add_ifm_tensor(mul2_op, sub1_ofm)
        TensorUtil.add_ifm_tensor(
            mul2_op,
            TensorUtil.create_const_tensor(
                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.uint32, quantization=mul2_quant
            ),
        )
        mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
        mul2_ofm.quantization = ofm.quantization.clone()
        mul2_ofm.quantization.scale_f32 = mul2_out_range
        TensorUtil.set_ofm_tensor(mul2_op, mul2_ofm)

        # PASS 3 - Add+LUT(exp)
        add_op = Operation("AddAct", self.op.name + "_add3")
        TensorUtil.add_ifm_tensor(add_op, mul2_ofm)
        TensorUtil.add_ifm_tensor(
            add_op,
            TensorUtil.create_const_tensor(
                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.uint32, quantization=no_scale_quant
            ),
        )
        add_op.set_activation_lut(
            TensorUtil.create_const_tensor(
                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.uint32, TensorPurpose.LUT
            )
        )
        exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
        exp_ofm.quantization = mul2_ofm.quantization.clone()
        TensorUtil.set_ofm_tensor(add_op, exp_ofm)

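        # Pass 4 sums the exponentials over the channel axis; the OFM shape
        # [1, H, W, 1] holds one 32-bit sum per position.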
        # PASS 4 - Reduce sum
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        TensorUtil.add_ifm_tensor(reduce_sum_op, exp_ofm)

        reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(reduce_sum_op, sum_of_exp)

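        # Passes 5-11 derive an approximation of 1 / sum_of_exp: CLZ measures
        # the headroom of the 32-bit sum, passes 6-10 normalise the sum (shift
        # out the headroom, then shift right by 15 so the result fits in 16
        # bits), and the ONE_OVER_ONE_PLUS_X_LUT in pass 11 maps the normalised
        # value to its reciprocal. The remaining right shift is applied by the
        # SHR in pass 13.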
        # PASS 5 - CLZ
        clz_op = Operation("CLZ", self.op.name + "_clz5")
        TensorUtil.add_ifm_tensor(clz_op, sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(clz_op, headroom_plus_one)

        # PASS 6 - Sub
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        TensorUtil.add_ifm_tensor(sub6_op, headroom_plus_one)
        TensorUtil.add_ifm_tensor(
            sub6_op,
            TensorUtil.create_const_tensor(
                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.uint32, quantization=no_scale_quant
            ),
        )
        # TODO: Adding this attribute to reverse the operand order is not ideal
        # it should be handled automatically by register_command_stream_generator
        # or added as an internal operator.
        sub6_op.attrs["reverse_op_order"] = True
        reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        reciprocal_right_shift.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(sub6_op, reciprocal_right_shift)

        # PASS 7 - SHL
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        TensorUtil.add_ifm_tensor(shl7_op, reciprocal_right_shift)
        TensorUtil.add_ifm_tensor(
            shl7_op,
            TensorUtil.create_const_tensor(
                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.uint32, quantization=no_scale_quant
            ),
        )
        # TODO: See above
        shl7_op.attrs["reverse_op_order"] = True
        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "0")
        constant_one.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(shl7_op, constant_one)

        # PASS 8 - Sub
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        TensorUtil.add_ifm_tensor(sub8_op, sum_of_exp)
        TensorUtil.add_ifm_tensor(sub8_op, constant_one)
        sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        sum_of_exps_minus_one.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(sub8_op, sum_of_exps_minus_one)

        # PASS 9 - SHL
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        TensorUtil.add_ifm_tensor(shl9_op, sum_of_exps_minus_one)
        TensorUtil.add_ifm_tensor(shl9_op, headroom_plus_one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(shl9_op, shifted_sum_minus_one)

        # PASS 10 - SHR
        shr10_op = Operation("SHR", self.op.name + "_shr10")
        TensorUtil.add_ifm_tensor(shr10_op, shifted_sum_minus_one)
        TensorUtil.add_ifm_tensor(
            shr10_op,
            TensorUtil.create_const_tensor(
                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.uint32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
        shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
        TensorUtil.set_ofm_tensor(shr10_op, shifted_sum_minus_one_16)

        # PASS 11 - Sub+LUT(one over one plus x)
        sub11_op = Operation("SubAct", self.op.name + "_sub11")
        TensorUtil.add_ifm_tensor(sub11_op, shifted_sum_minus_one_16)
        TensorUtil.add_ifm_tensor(
            sub11_op,
            TensorUtil.create_const_tensor(
                sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.uint32, quantization=no_scale_quant
            ),
        )
        sub11_op.set_activation_lut(
            TensorUtil.create_const_tensor(
                sub11_op.name + "_lut",
                [1, 1, 1, 512],
                DataType.int32,
                self.ONE_OVER_ONE_PLUS_X_LUT,
                np.uint32,
                TensorPurpose.LUT,
            )
        )
        reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
        reciprocal_scale.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(sub11_op, reciprocal_scale)

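        # Passes 12-13 multiply each exponential by the reciprocal scale and
        # shift right by reciprocal_right_shift to produce the final int16
        # softmax values in the OFM.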
        # PASS 12 - Multiply
        mul_op = Operation("MulAct", self.op.name + "_mul12")
        TensorUtil.add_ifm_tensor(mul_op, exp_ofm)
        TensorUtil.add_ifm_tensor(mul_op, reciprocal_scale)
        mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
        mul_ofm.quantization = no_scale_quant
        TensorUtil.set_ofm_tensor(mul_op, mul_ofm)

        # PASS 13 - SHR
        shr13_op = Operation("SHR", self.op.name + "_shr13")
        TensorUtil.add_ifm_tensor(shr13_op, mul_ofm)
        TensorUtil.add_ifm_tensor(shr13_op, reciprocal_right_shift)
        TensorUtil.set_ofm_tensor(shr13_op, ofm)

        return shr13_op
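
# A minimal usage sketch (hypothetical caller, for illustration only):
#
#     softmax = SoftMax(softmax_op)
#     new_terminal_op = softmax.get_graph()
#
# get_graph() returns the original op with run_on_npu cleared when the data
# type is not int16; otherwise it returns the final SHR op of the rewritten
# sub-graph, whose output tensor is the original OFM.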