# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains SoftMax
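#
# The SoftMax class below rewrites a single SoftMax operation into a sub-graph
# of elementwise, pooling and LUT operations. Typical use (a minimal sketch,
# assuming a graph-optimisation pass holds the original operation in `op`):
#
#     softmax = SoftMax(op)
#     op = softmax.get_graph()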
import math

import numpy as np

from . import fp_math
from . import scaling
from .data_type import DataType
from .operation import Operation
from .tensor import create_const_tensor
from .tensor import create_reshape_tensor
from .tensor import Tensor
from .tensor import TensorPurpose


class SoftMax:
    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

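    # The two tables below each hold 512 32-bit entries and are fed to the NPU
    # as int16 activation LUTs by the int16 graph. Each entry packs a 16-bit
    # table value in the low half and the signed step to the next value in the
    # high half (e.g. 0x00010002 = value 2, step 1), presumably so the hardware
    # can interpolate between entries.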
    EXP_LUT = [
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
    ]

    ONE_OVER_ONE_PLUS_X_LUT = [
        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
    ]
    # fmt: on

    def __init__(self, op):
        self.op = op

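    # Builds the 256-entry int32 LUT used as the exp activation LUT by the
    # 8-bit graph: entry x holds a gemmlowp-style fixed-point approximation of
    # exp(beta * input_scale * (x - 255)); inputs below diff_min map to 0.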
    def generate_exp_table(self, beta, input_scale):
        integer_bits = 5
        total_signed_bits = 31
        # Calculate scaling
        real_beta = min(
            np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), np.double((1 << 31) - 1.0)
        )
        scale, shift = scaling.quantise_scale(real_beta)
        shift = 31 - shift
        diff_min = -1.0 * math.floor(
            1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits)) / (1 << shift)
        )
        # Generate the exp LUT
        lut = []
        for x in range(256):
            input_diff = x - 255
            if input_diff >= diff_min:
                rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
                lut.append(fp_math.exp_on_negative_values(rescale))
            else:
                lut.append(0)
        return lut

    def get_graph(self):
        ifm = self.op.inputs[0]
        ofm = self.op.outputs[0]

        if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
            return self.get_graph_8bit(ifm, ofm)
        elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
            return self.get_graph_int16(ifm, ofm)
        else:
            self.op.run_on_npu = False
            return self.op

    def get_graph_8bit(self, ifm, ofm):
        exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        one_scale_quant = ifm.quantization.clone()
        one_scale_quant.scale_f32 = 1.0
        one_scale_quant.zero_point = 0
        ifm.quantization.zero_point = 0

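        # PASS 0 computes the per-position max over the channel axis (as a
        # 1 x C maxpool on the reshaped IFM); subtracting it in PASS 1 keeps
        # the exp argument non-positive, matching the domain of the exp LUT.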
        # PASS 0 - Depthwise Maxpool
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        ifm_max.quantization = no_scale_quant
        maxpool_op.set_output_tensor(ifm_max)

        # PASS 1 - Sub+LUT(exp)
        sub_op = Operation("SubAct", self.op.name + "_sub1")
        sub_op.add_input_tensor(ifm)
        sub_op.add_input_tensor(ifm_max)
        sub_op.set_activation_lut(
            create_const_tensor(
                sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
            )
        )
        ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
        ifm_exp.quantization = one_scale_quant.clone()
        ifm_exp.quantization.zero_point = 127
        ifm_exp.quantization.quant_min = -128
        ifm_exp.quantization.quant_max = 127
        sub_op.set_output_tensor(ifm_exp)

        # PASS 2 - SHR
        shr2_op = Operation("SHR", self.op.name + "_shr2")
        shr2_op.add_input_tensor(ifm_exp)
        shr2_op.add_input_tensor(
            create_const_tensor(
                shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
            ),
        )
        rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
        rescaled_exp.quantization = no_scale_quant
        shr2_op.set_output_tensor(rescaled_exp)

        # PASS 3 - Reduce sum
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(rescaled_exp)

        reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

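        # PASSES 4-7 normalise sum_of_exp: CLZ gives the available headroom,
        # the sum is shifted left so its leading one ends up in bit 30, and
        # PASS 5 derives the right shift that PASS 30 applies to the result.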
        # PASS 4 - CLZ
        clz_op = Operation("CLZ", self.op.name + "_clz4")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 5 - Sub
        sub5_op = Operation("SubAct", self.op.name + "_sub5")
        sub5_op.add_input_tensor(
            create_const_tensor(
                "headroom_offset_const",
                [1, 1, 1, 1],
                DataType.int32,
                [12 + 31 - 8],
                np.int32,
                quantization=no_scale_quant,
            ),
        )
        sub5_op.add_input_tensor(headroom_plus_one)
        right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
        right_shift.quantization = no_scale_quant
        sub5_op.set_output_tensor(right_shift)

        # PASS 6 - Sub
        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(headroom_plus_one)
        sub6_op.add_input_tensor(one)
        headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        headroom.quantization = no_scale_quant
        sub6_op.set_output_tensor(headroom)

        # PASS 7 - SHL
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(sum_of_exp)
        shl7_op.add_input_tensor(headroom)
        shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        shifted_sum.quantization = no_scale_quant
        shl7_op.set_output_tensor(shifted_sum)

        # PASS 8 - Sub
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(shifted_sum)
        sub8_op.add_input_tensor(
            create_const_tensor(
                "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 9 - SHL
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(shifted_sum_minus_one)
        shl9_op.add_input_tensor(one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - Add
        add10_op = Operation("AddAct", self.op.name + "_add10")
        add10_op.add_input_tensor(
            create_const_tensor(
                "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
            ),
        )
        add10_op.add_input_tensor(shifted_sum_minus_one)
        add10_op.attrs["rescale"] = [1, 1]
        half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
        half_denominator.quantization = one_scale_quant
        add10_op.set_output_tensor(half_denominator)

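        # PASSES 11-12 form the initial Newton-Raphson estimate of the
        # reciprocal, 48/17 - 32/17 * half_denominator, with both constants in
        # Q2.29 fixed point (the same starting point as gemmlowp's division).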
        # PASS 11 - Multiply
        mul11_op = Operation("MulAct", self.op.name + "_mul11")
        mul11_op.add_input_tensor(half_denominator)
        mul11_op.add_input_tensor(
            create_const_tensor(
                "neg_32_over_17_const",
                [1, 1, 1, 1],
                DataType.int32,
                [-1010580540],
                np.int32,
                quantization=one_scale_quant,
            ),
        )
        rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
        rescaled.quantization = one_scale_quant.clone()
        rescaled.quantization.scale_f32 = 2.0
        mul11_op.set_output_tensor(rescaled)

        # PASS 12 - Add
        add12_op = Operation("AddAct", self.op.name + "_add12")
        add12_op.add_input_tensor(rescaled)
        add12_op.add_input_tensor(
            create_const_tensor(
                "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
            ),
        )
        rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
        rescale_w_offset.quantization = one_scale_quant
        add12_op.set_output_tensor(rescale_w_offset)

        nr_x = rescale_w_offset
        F2_one = create_const_tensor(
            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
        )
        four = create_const_tensor(
            "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
        )
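        # PASSES 13-27: three Newton-Raphson iterations refining nr_x towards
        # 1 / half_denominator, x <- x + x * (1 - half_denominator * x), as in
        # gemmlowp's one_over_one_plus_x_for_x_in_0_1 routine; the multiply by
        # four compensates for the fixed-point rescaling of the product.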
        for i in range(3):
            # PASS 13, 18, 23 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(half_denominator)
            half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            half_denominator_times_x.quantization = one_scale_quant.clone()
            half_denominator_times_x.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(half_denominator_times_x)
            # PASS 14, 19, 24 - SUB
            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
            sub_op.add_input_tensor(F2_one)
            sub_op.add_input_tensor(half_denominator_times_x)
            one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
            one_minus_half_denominator_times_x.quantization = one_scale_quant
            sub_op.set_output_tensor(one_minus_half_denominator_times_x)
            # PASS 15, 20, 25 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(one_minus_half_denominator_times_x)
            to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            to_rescale.quantization = one_scale_quant.clone()
            to_rescale.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(to_rescale)
            # PASS 16, 21, 26 - MUL
            shl_op = Operation("MulAct", self.op.name + "_mul%d" % (16 + i * 5))
            shl_op.add_input_tensor(to_rescale)
            shl_op.add_input_tensor(four)
            to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
            to_add.quantization = no_scale_quant
            shl_op.set_output_tensor(to_add)
            # PASS 17, 22, 27 - ADD
            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
            add_op.add_input_tensor(nr_x)
            add_op.add_input_tensor(to_add)
            nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
            nr_x.quantization = one_scale_quant
            add_op.set_output_tensor(nr_x)

        # PASS 28 - Multiply
        mul28_op = Operation("MulAct", self.op.name + "_mul28")
        mul28_op.add_input_tensor(nr_x)
        mul28_op.add_input_tensor(
            create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
        )
        scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0")
        scale_factor.quantization = one_scale_quant
        mul28_op.set_output_tensor(scale_factor)

        # PASS 29 - Multiply
        mul_op = Operation("MulAct", self.op.name + "_mul29")
        mul_op.add_input_tensor(ifm_exp)
        mul_op.add_input_tensor(scale_factor)
        scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
        scaled_exp.quantization = one_scale_quant.clone()
        scaled_exp.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(scaled_exp)

        # PASS 30 - SHR
        shr30_op = Operation("SHR", self.op.name + "_shr30")
        shr30_op.add_input_tensor(scaled_exp)
        shr30_op.add_input_tensor(right_shift)
        shr30_op.set_output_tensor(ofm)

        return shr30_op

    def get_graph_int16(self, ifm, ofm):
        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None

        # PASS 0 - Depthwise Maxpool
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        maxpool_ofm.quantization = no_scale_quant
        maxpool_op.set_output_tensor(maxpool_ofm)

        # PASS 1 - Sub
        sub1_op = Operation("SubAct", self.op.name + "_sub1")
        sub1_op.add_input_tensor(ifm)
        sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
        sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
        sub1_ofm.quantization = ifm.quantization.clone()
        sub1_op.set_output_tensor(sub1_ofm)

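        # PASS 2 folds beta into the quantisation: (x - max) is multiplied by a
        # scale derived from beta so that the result lands in the range covered
        # by the EXP_LUT applied in PASS 3.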
        # PASS 2 - Mul
        beta = self.op.attrs.get("beta", 1.0)
        mul2_out_range = 10.0 / 65535.0
        mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
        mul2_quant = ifm.quantization.clone()
        mul2_quant.scale_f32 = beta
        mul2_op = Operation("MulAct", self.op.name + "_mul2")
        mul2_op.add_input_tensor(sub1_ofm)
        mul2_op.add_input_tensor(
            create_const_tensor(
                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
            ),
        )
        mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
        mul2_ofm.quantization = ofm.quantization.clone()
        mul2_ofm.quantization.scale_f32 = mul2_out_range
        mul2_op.set_output_tensor(mul2_ofm)

        # PASS 3 - Add+LUT(exp)
        add_op = Operation("AddAct", self.op.name + "_add3")
        add_op.add_input_tensor(mul2_ofm)
        add_op.add_input_tensor(
            create_const_tensor(
                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
            ),
        )
        add_op.set_activation_lut(
            create_const_tensor(
                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
            )
        )
        exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
        exp_ofm.quantization = mul2_ofm.quantization.clone()
        add_op.set_output_tensor(exp_ofm)

        # PASS 4 - Reduce sum
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(exp_ofm)

        reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 5 - CLZ
        clz_op = Operation("CLZ", self.op.name + "_clz5")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 6 - Sub
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(
            create_const_tensor(
                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
            ),
        )
        sub6_op.add_input_tensor(headroom_plus_one)
        reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        reciprocal_right_shift.quantization = no_scale_quant
        sub6_op.set_output_tensor(reciprocal_right_shift)

        # PASS 7 - SHL
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(
            create_const_tensor(
                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
            ),
        )
        shl7_op.add_input_tensor(reciprocal_right_shift)
        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        constant_one.quantization = no_scale_quant
        shl7_op.set_output_tensor(constant_one)

        # PASS 8 - Sub
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(sum_of_exp)
        sub8_op.add_input_tensor(constant_one)
        sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        sum_of_exps_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(sum_of_exps_minus_one)

        # PASS 9 - SHL
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(sum_of_exps_minus_one)
        shl9_op.add_input_tensor(headroom_plus_one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - SHR
        shr10_op = Operation("SHR", self.op.name + "_shr10")
        shr10_op.add_input_tensor(shifted_sum_minus_one)
        shr10_op.add_input_tensor(
            create_const_tensor(
                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
        shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
        shr10_op.set_output_tensor(shifted_sum_minus_one_16)

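        # PASSES 11-13: use ONE_OVER_ONE_PLUS_X_LUT to look up a reciprocal
        # scale for the normalised sum, multiply each exp value by it, and
        # apply the final right shift to produce the int16 output.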
605 # PASS 11 - Sub+LUT(one over one plus x)
606 sub11_op = Operation("SubAct", self.op.name + "_sub11")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100607 sub11_op.add_input_tensor(shifted_sum_minus_one_16)
608 sub11_op.add_input_tensor(
609 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200610 sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200611 ),
612 )
613 sub11_op.set_activation_lut(
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100614 create_const_tensor(
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200615 sub11_op.name + "_lut",
616 [1, 1, 1, 512],
617 DataType.int32,
618 self.ONE_OVER_ONE_PLUS_X_LUT,
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200619 np.int32,
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200620 TensorPurpose.LUT,
621 )
622 )
623 reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
624 reciprocal_scale.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100625 sub11_op.set_output_tensor(reciprocal_scale)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200626
627 # PASS 12 - Multiply
628 mul_op = Operation("MulAct", self.op.name + "_mul12")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100629 mul_op.add_input_tensor(exp_ofm)
630 mul_op.add_input_tensor(reciprocal_scale)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200631 mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
632 mul_ofm.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100633 mul_op.set_output_tensor(mul_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200634
635 # PASS 13 - SHR
636 shr13_op = Operation("SHR", self.op.name + "_shr13")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100637 shr13_op.add_input_tensor(mul_ofm)
638 shr13_op.add_input_tensor(reciprocal_right_shift)
639 shr13_op.set_output_tensor(ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200640
641 return shr13_op