blob: 9e8b846dff6632f5fcaccd26902a0f9d2020b691 [file] [log] [blame]
Fredrik Svedberga0c36242020-06-03 15:43:31 +02001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
Fredrik Svedberg1575b942020-08-18 13:19:18 +02003# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4#
Fredrik Svedberga0c36242020-06-03 15:43:31 +02005# SPDX-License-Identifier: Apache-2.0
6#
Fredrik Svedberg1575b942020-08-18 13:19:18 +02007# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
Fredrik Svedberga0c36242020-06-03 15:43:31 +02009# You may obtain a copy of the License at
10#
Fredrik Svedberg1575b942020-08-18 13:19:18 +020011# http://www.apache.org/licenses/LICENSE-2.0
Fredrik Svedberga0c36242020-06-03 15:43:31 +020012#
13# Unless required by applicable law or agreed to in writing, software
Fredrik Svedberg1575b942020-08-18 13:19:18 +020014# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Fredrik Svedberga0c36242020-06-03 15:43:31 +020016# See the License for the specific language governing permissions and
17# limitations under the License.
Fredrik Svedberg1575b942020-08-18 13:19:18 +020018#
Fredrik Svedberga0c36242020-06-03 15:43:31 +020019# Description:
20# Contains SoftMax
Fredrik Svedberg1575b942020-08-18 13:19:18 +020021import math
22
Fredrik Svedberga0c36242020-06-03 15:43:31 +020023import numpy as np
24
Fredrik Svedberg1575b942020-08-18 13:19:18 +020025from . import fp_math
Fredrik Svedberga0c36242020-06-03 15:43:31 +020026from . import scaling
27from .data_type import DataType
28from .operation import Operation
Michael McGeagh5778ffd2020-08-06 17:31:02 +010029from .tensor import create_const_tensor
30from .tensor import create_reshape_tensor
Fredrik Svedberga0c36242020-06-03 15:43:31 +020031from .tensor import Tensor
32from .tensor import TensorPurpose
33
34
class SoftMax:
    """Rewrites a SoftMax operation into a sequence of NPU-executable
    operations (elementwise ops, pooling, LUT activations).

    Entry point is get_graph(), which dispatches on the input/output dtype
    to the 8-bit or int16 lowering.
    """

    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

    # 512-entry table used as the int32 activation LUT for the exp step of
    # the int16 softmax lowering (see get_graph_int16, "PASS 3 - Add+LUT(exp)").
    # Each entry is one 32-bit word; judging by the data, the low 16 bits hold
    # a base value and the high 16 bits a small delta to the next entry
    # (interpolation slope).  NOTE(review): packing layout inferred from the
    # values only — confirm against the consumer of TensorPurpose.LUT tensors.
    EXP_LUT = [
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
    ]

    # 512-entry companion table, same packed 16+16-bit layout as EXP_LUT,
    # presumably the 1/(1+x) reciprocal LUT for the int16 softmax
    # normalisation step.  NOTE(review): it is not referenced anywhere in the
    # code visible in this view — confirm its consumer in get_graph_int16.
    ONE_OVER_ONE_PLUS_X_LUT = [
        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
    ]
    # fmt: on
173
    def __init__(self, op):
        # The SoftMax operation to be decomposed; get_graph() reads its
        # inputs/outputs/attrs and clones it when building the new graph.
        self.op = op
176
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200177 def generate_exp_table(self, beta, input_scale):
Fredrik Svedberg1575b942020-08-18 13:19:18 +0200178 integer_bits = 5
179 total_signed_bits = 31
180 # Calculate scaling
181 real_beta = min(
182 np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), np.double((1 << 31) - 1.0)
183 )
184 scale, shift = scaling.quantise_scale(real_beta)
185 shift = 31 - shift
186 diff_min = -1.0 * math.floor(
187 1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits)) / (1 << shift)
188 )
189 # Generate the exp LUT
190 lut = []
191 for x in range(256):
192 input_diff = x - 255
193 if input_diff >= diff_min:
194 rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
195 lut.append(fp_math.exp_on_negative_values(rescale))
196 else:
197 lut.append(0)
198 return lut
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200199
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200200 def get_graph(self):
201 ifm = self.op.inputs[0]
202 ofm = self.op.outputs[0]
203
Fredrik Svedberg835d8e12020-09-04 09:46:17 +0200204 # Reshape ifm/ofm (if needed)
205 full_shape = ifm.get_full_shape()
206 if full_shape[0] > 1:
207 full_shape[1] *= full_shape[0]
208 full_shape[0] = 1
209 ifm = create_reshape_tensor(ifm, full_shape)
210 ofm = create_reshape_tensor(ofm, full_shape, False)
211
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200212 if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
213 return self.get_graph_8bit(ifm, ofm)
214 elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200215 return self.get_graph_int16(ifm, ofm)
216 else:
217 self.op.run_on_npu = False
218 return self.op
219
    def get_graph_8bit(self, ifm, ofm):
        """Lower an 8-bit (u)int softmax to a chain of NPU operations.

        The chain computes, per spatial position: channel max -> exp of the
        difference via a 256-entry LUT -> channel sum -> fixed-point
        reciprocal of the sum -> final scaling and shift into ofm.  The
        reciprocal uses three refinement iterations whose constants
        (-32/17, 48/17, shifts) appear to follow gemmlowp's
        one_over_one_plus_x_for_x_in_0_1 Newton-Raphson scheme —
        NOTE(review): confirm against the gemmlowp reference.
        Returns the last op of the chain (its output tensor is ofm).
        """
        exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        one_scale_quant = ifm.quantization.clone()
        one_scale_quant.scale_f32 = 1.0
        one_scale_quant.zero_point = 0
        ifm.quantization.zero_point = 0

        # PASS 0 - Depthwise Maxpool
        # Reshape to [1, h*w, c, 1] and max-pool over the full width, i.e.
        # take the maximum over the channel axis at each spatial position.
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        ifm_max.quantization = no_scale_quant
        maxpool_op.set_output_tensor(ifm_max)

        # PASS 1 - Sub+LUT(exp)
        # Subtract the per-position max (broadcast over channels) and apply
        # exp through the 256-entry LUT generated above.
        sub_op = Operation("SubAct", self.op.name + "_sub1")
        sub_op.add_input_tensor(ifm)
        sub_op.add_input_tensor(create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]))
        sub_op.set_activation_lut(
            create_const_tensor(
                sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
            )
        )
        ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
        ifm_exp.quantization = one_scale_quant.clone()
        ifm_exp.quantization.zero_point = 127
        ifm_exp.quantization.quant_min = -128
        ifm_exp.quantization.quant_max = 127
        sub_op.set_output_tensor(ifm_exp)

        # PASS 2 - SHR
        # Shift exp values right by 12 (with natural rounding) before the
        # reduce-sum so the sum fits in int32.
        shr2_op = Operation("SHR", self.op.name + "_shr2")
        shr2_op.attrs["rounding_mode"] = b"NATURAL"
        shr2_op.add_input_tensor(ifm_exp)
        shr2_op.add_input_tensor(
            create_const_tensor(
                shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
            ),
        )
        rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
        rescaled_exp.quantization = no_scale_quant
        shr2_op.set_output_tensor(rescaled_exp)

        # PASS 3 - Reduce sum
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(rescaled_exp)

        reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 4 - CLZ
        # Count leading zeros of the sum = headroom + 1, used to normalise
        # the denominator into a fixed range.
        clz_op = Operation("CLZ", self.op.name + "_clz4")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 5 - Sub
        # right_shift = (12 + 31 - 8) - headroom_plus_one; the final shift
        # applied in PASS 30 to produce the 8-bit output.
        sub5_op = Operation("SubAct", self.op.name + "_sub5")
        sub5_op.add_input_tensor(
            create_const_tensor(
                "headroom_offset_const",
                [1, 1, 1, 1],
                DataType.int32,
                [12 + 31 - 8],
                np.int32,
                quantization=no_scale_quant,
            ),
        )
        sub5_op.add_input_tensor(headroom_plus_one)
        right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
        right_shift.quantization = no_scale_quant
        sub5_op.set_output_tensor(right_shift)

        # PASS 6 - Sub
        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(headroom_plus_one)
        sub6_op.add_input_tensor(one)
        headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        headroom.quantization = no_scale_quant
        sub6_op.set_output_tensor(headroom)

        # PASS 7 - SHL
        # Normalise: shift the sum left by its headroom.
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(sum_of_exp)
        shl7_op.add_input_tensor(headroom)
        shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        shifted_sum.quantization = no_scale_quant
        shl7_op.set_output_tensor(shifted_sum)

        # PASS 8 - Sub
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(shifted_sum)
        sub8_op.add_input_tensor(
            create_const_tensor(
                "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 9 - SHL
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(shifted_sum_minus_one)
        shl9_op.add_input_tensor(one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - Add
        # half_denominator = INT32_MAX + shifted_sum_minus_one (with rescale),
        # i.e. the normalised denominator mapped into the range the
        # reciprocal iteration expects.
        add10_op = Operation("AddAct", self.op.name + "_add10")
        add10_op.add_input_tensor(
            create_const_tensor(
                "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
            ),
        )
        add10_op.add_input_tensor(shifted_sum_minus_one)
        add10_op.attrs["rescale"] = [1, 1]
        half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
        half_denominator.quantization = one_scale_quant
        add10_op.set_output_tensor(half_denominator)

        # PASS 11 - Multiply
        # Initial reciprocal estimate: 48/17 - 32/17 * half_denominator
        # (constants in fixed point; -1010580540 ~ -32/17, 1515870810 ~ 48/17).
        mul11_op = Operation("MulAct", self.op.name + "_mul11")
        mul11_op.add_input_tensor(half_denominator)
        mul11_op.add_input_tensor(
            create_const_tensor(
                "neg_32_over_17_const",
                [1, 1, 1, 1],
                DataType.int32,
                [-1010580540],
                np.int32,
                quantization=one_scale_quant,
            ),
        )
        rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
        rescaled.quantization = one_scale_quant.clone()
        rescaled.quantization.scale_f32 = 2.0
        mul11_op.set_output_tensor(rescaled)

        # PASS 12 - Add
        add12_op = Operation("AddAct", self.op.name + "_add12")
        add12_op.add_input_tensor(rescaled)
        add12_op.add_input_tensor(
            create_const_tensor(
                "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
            ),
        )
        rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
        rescale_w_offset.quantization = one_scale_quant
        add12_op.set_output_tensor(rescale_w_offset)

        # Three refinement iterations of the reciprocal estimate:
        # nr_x = nr_x + 4 * nr_x * (F2_one - nr_x * half_denominator)
        nr_x = rescale_w_offset
        F2_one = create_const_tensor(
            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
        )
        four = create_const_tensor(
            "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
        )
        for i in range(3):
            # PASS 13, 18, 23 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(half_denominator)
            half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            half_denominator_times_x.quantization = one_scale_quant.clone()
            half_denominator_times_x.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(half_denominator_times_x)
            # PASS 14, 19, 24 - SUB
            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
            sub_op.add_input_tensor(F2_one)
            sub_op.add_input_tensor(half_denominator_times_x)
            one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
            one_minus_half_denominator_times_x.quantization = one_scale_quant
            sub_op.set_output_tensor(one_minus_half_denominator_times_x)
            # PASS 15, 20, 25 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(one_minus_half_denominator_times_x)
            to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            to_rescale.quantization = one_scale_quant.clone()
            to_rescale.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(to_rescale)
            # PASS 16, 21, 26 - MUL
            # Multiply by 4 (implemented as MulAct rather than a shift).
            shl_op = Operation("MulAct", self.op.name + "_mul%d" % (16 + i * 5))
            shl_op.add_input_tensor(to_rescale)
            shl_op.add_input_tensor(four)
            to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
            to_add.quantization = no_scale_quant
            shl_op.set_output_tensor(to_add)
            # PASS 17, 22, 27 - ADD
            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
            add_op.add_input_tensor(nr_x)
            add_op.add_input_tensor(to_add)
            nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
            nr_x.quantization = one_scale_quant
            add_op.set_output_tensor(nr_x)

        # PASS 28 - Multiply
        # Double the converged reciprocal to get the final scale factor.
        mul28_op = Operation("MulAct", self.op.name + "_mul28")
        mul28_op.add_input_tensor(nr_x)
        mul28_op.add_input_tensor(
            create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
        )
        scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0")
        scale_factor.quantization = one_scale_quant
        mul28_op.set_output_tensor(scale_factor)

        # PASS 29 - Multiply
        # Multiply each exp value by the reciprocal of the sum.
        mul_op = Operation("MulAct", self.op.name + "_mul29")
        mul_op.add_input_tensor(ifm_exp)
        mul_op.add_input_tensor(scale_factor)
        scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
        scaled_exp.quantization = one_scale_quant.clone()
        scaled_exp.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(scaled_exp)

        # PASS 30 - SHR
        # Final right shift (computed in PASS 5) down to the 8-bit ofm.
        shr30_op = Operation("SHR", self.op.name + "_shr30")
        shr30_op.attrs["rounding_mode"] = b"NATURAL"
        shr30_op.add_input_tensor(scaled_exp)
        shr30_op.add_input_tensor(right_shift)
        shr30_op.set_output_tensor(ofm)

        return shr30_op
470
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200471 def get_graph_int16(self, ifm, ofm):
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200472 no_scale_quant = ifm.quantization.clone()
473 no_scale_quant.scale_f32 = None
474
475 # PASS 0 - Depthwise Maxpool
476 maxpool_op = self.op.clone("_maxpool0")
477 maxpool_op.type = "MaxPool"
478 maxpool_h = ifm.shape[1] * ifm.shape[2]
479 maxpool_w = ifm.shape[3]
480 maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
481 maxpool_op.attrs["padding"] = b"VALID"
482 maxpool_op.attrs["stride_w"] = 1
483 maxpool_op.attrs["stride_h"] = 1
484 maxpool_op.attrs["filter_width"] = maxpool_w
485 maxpool_op.attrs["filter_height"] = 1
486 maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
487 maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100488 maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200489 maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200490 maxpool_ofm.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100491 maxpool_op.set_output_tensor(maxpool_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200492
493 # PASS 1 - Sub
494 sub1_op = Operation("SubAct", self.op.name + "_sub1")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100495 sub1_op.add_input_tensor(ifm)
496 sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200497 sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
498 sub1_ofm.quantization = ifm.quantization.clone()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100499 sub1_op.set_output_tensor(sub1_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200500
501 # PASS 2 - Mul
502 beta = self.op.attrs.get("beta", 1.0)
503 mul2_out_range = 10.0 / 65535.0
504 mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
505 mul2_quant = ifm.quantization.clone()
506 mul2_quant.scale_f32 = beta
507 mul2_op = Operation("MulAct", self.op.name + "_mul2")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100508 mul2_op.add_input_tensor(sub1_ofm)
509 mul2_op.add_input_tensor(
510 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200511 mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200512 ),
513 )
514 mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
515 mul2_ofm.quantization = ofm.quantization.clone()
516 mul2_ofm.quantization.scale_f32 = mul2_out_range
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100517 mul2_op.set_output_tensor(mul2_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200518
519 # PASS 3 - Add+LUT(exp)
520 add_op = Operation("AddAct", self.op.name + "_add3")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100521 add_op.add_input_tensor(mul2_ofm)
522 add_op.add_input_tensor(
523 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200524 add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200525 ),
526 )
527 add_op.set_activation_lut(
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100528 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200529 add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200530 )
531 )
532 exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
533 exp_ofm.quantization = mul2_ofm.quantization.clone()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100534 add_op.set_output_tensor(exp_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200535
536 # PASS 4 - Reduce sum
537 reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
538 reduce_sum_op.attrs["padding"] = b"VALID"
539 reduce_sum_op.attrs["stride_w"] = 1
540 reduce_sum_op.attrs["stride_h"] = 1
541 reduce_sum_op.attrs["filter_width"] = 1
542 reduce_sum_op.attrs["filter_height"] = 1
543 reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
544 reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100545 reduce_sum_op.add_input_tensor(exp_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200546
547 reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
548 sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
549 sum_of_exp.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100550 reduce_sum_op.set_output_tensor(sum_of_exp)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200551
552 # PASS 5 - CLZ
553 clz_op = Operation("CLZ", self.op.name + "_clz5")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100554 clz_op.add_input_tensor(sum_of_exp)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200555 headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
556 headroom_plus_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100557 clz_op.set_output_tensor(headroom_plus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200558
559 # PASS 6 - Sub
560 sub6_op = Operation("SubAct", self.op.name + "_sub6")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100561 sub6_op.add_input_tensor(
562 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200563 sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200564 ),
565 )
Jacob Bohlinbe733cf2020-08-13 10:21:34 +0200566 sub6_op.add_input_tensor(headroom_plus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200567 reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
568 reciprocal_right_shift.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100569 sub6_op.set_output_tensor(reciprocal_right_shift)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200570
571 # PASS 7 - SHL
572 shl7_op = Operation("SHL", self.op.name + "_shl7")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100573 shl7_op.add_input_tensor(
574 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200575 shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200576 ),
577 )
Jacob Bohlinbe733cf2020-08-13 10:21:34 +0200578 shl7_op.add_input_tensor(reciprocal_right_shift)
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200579 constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200580 constant_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100581 shl7_op.set_output_tensor(constant_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200582
583 # PASS 8 - Sub
584 sub8_op = Operation("SubAct", self.op.name + "_sub8")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100585 sub8_op.add_input_tensor(sum_of_exp)
586 sub8_op.add_input_tensor(constant_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200587 sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
588 sum_of_exps_minus_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100589 sub8_op.set_output_tensor(sum_of_exps_minus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200590
591 # PASS 9 - SHL
592 shl9_op = Operation("SHL", self.op.name + "_shl9")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100593 shl9_op.add_input_tensor(sum_of_exps_minus_one)
594 shl9_op.add_input_tensor(headroom_plus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200595 shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
596 shifted_sum_minus_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100597 shl9_op.set_output_tensor(shifted_sum_minus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200598
599 # PASS 10 - SHR
600 shr10_op = Operation("SHR", self.op.name + "_shr10")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100601 shr10_op.add_input_tensor(shifted_sum_minus_one)
602 shr10_op.add_input_tensor(
603 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200604 shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200605 ),
606 )
607 shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
608 shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100609 shr10_op.set_output_tensor(shifted_sum_minus_one_16)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200610
611 # PASS 11 - Sub+LUT(one over one plus x)
612 sub11_op = Operation("SubAct", self.op.name + "_sub11")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100613 sub11_op.add_input_tensor(shifted_sum_minus_one_16)
614 sub11_op.add_input_tensor(
615 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200616 sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200617 ),
618 )
619 sub11_op.set_activation_lut(
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100620 create_const_tensor(
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200621 sub11_op.name + "_lut",
622 [1, 1, 1, 512],
623 DataType.int32,
624 self.ONE_OVER_ONE_PLUS_X_LUT,
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200625 np.int32,
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200626 TensorPurpose.LUT,
627 )
628 )
629 reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
630 reciprocal_scale.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100631 sub11_op.set_output_tensor(reciprocal_scale)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200632
633 # PASS 12 - Multiply
634 mul_op = Operation("MulAct", self.op.name + "_mul12")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100635 mul_op.add_input_tensor(exp_ofm)
636 mul_op.add_input_tensor(reciprocal_scale)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200637 mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
638 mul_ofm.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100639 mul_op.set_output_tensor(mul_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200640
641 # PASS 13 - SHR
642 shr13_op = Operation("SHR", self.op.name + "_shr13")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100643 shr13_op.add_input_tensor(mul_ofm)
644 shr13_op.add_input_tensor(reciprocal_right_shift)
645 shr13_op.set_output_tensor(ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200646
647 return shr13_op