blob: eb97c7922f657a4d54eed6d352a3150cc3791586 [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains SoftMax
Fredrik Svedberg1575b942020-08-18 13:19:18 +020021import math
22
Fredrik Svedberga0c36242020-06-03 15:43:31 +020023import numpy as np
24
Fredrik Svedberg1575b942020-08-18 13:19:18 +020025from . import fp_math
Fredrik Svedberga0c36242020-06-03 15:43:31 +020026from . import scaling
27from .data_type import DataType
28from .operation import Operation
Michael McGeagh5778ffd2020-08-06 17:31:02 +010029from .tensor import create_const_tensor
30from .tensor import create_reshape_tensor
Fredrik Svedberga0c36242020-06-03 15:43:31 +020031from .tensor import Tensor
32from .tensor import TensorPurpose
33
34
class SoftMax:
    """Decomposes a SoftMax operation into a graph of elementary NPU operations.

    get_graph() dispatches on the input/output data types: the uint8/int8 path
    (get_graph_8bit) generates its exp LUT at runtime via generate_exp_table,
    while the int16 path (get_graph_int16) uses the pre-computed class-level
    LUT tables below. Unsupported type combinations fall back to the CPU.
    """

    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

    # 512-entry int32 exp() LUT used by the int16 softmax graph (it is loaded
    # as a [1, 1, 1, 512] LUT tensor in get_graph_int16, PASS 3). Each 32-bit
    # word appears to pack two 16-bit fields (low half: value, high half:
    # delta to the next entry, presumably for hardware interpolation) —
    # TODO confirm against the NPU LUT format.
    EXP_LUT = [
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
    ]
105
    # 512-entry int32 LUT for 1/(1+x). Not referenced in the visible portion
    # of this file — presumably consumed by the int16 reciprocal step further
    # down in get_graph_int16 (TODO confirm against the full file). Each
    # 32-bit word appears to pack two 16-bit fields (low half: value, high
    # half: signed delta to the next entry) — TODO confirm LUT format.
    ONE_OVER_ONE_PLUS_X_LUT = [
        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
    ]
172 # fmt: on
173
174 def __init__(self, op):
175 self.op = op
176
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200177 def generate_exp_table(self, beta, input_scale):
Fredrik Svedberg1575b942020-08-18 13:19:18 +0200178 integer_bits = 5
179 total_signed_bits = 31
180 # Calculate scaling
181 real_beta = min(
182 np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), np.double((1 << 31) - 1.0)
183 )
184 scale, shift = scaling.quantise_scale(real_beta)
185 shift = 31 - shift
186 diff_min = -1.0 * math.floor(
187 1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits)) / (1 << shift)
188 )
189 # Generate the exp LUT
190 lut = []
191 for x in range(256):
192 input_diff = x - 255
193 if input_diff >= diff_min:
194 rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
195 lut.append(fp_math.exp_on_negative_values(rescale))
196 else:
197 lut.append(0)
198 return lut
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200199
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200200 def get_graph(self):
201 ifm = self.op.inputs[0]
202 ofm = self.op.outputs[0]
203
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200204 if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
205 return self.get_graph_8bit(ifm, ofm)
206 elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200207 return self.get_graph_int16(ifm, ofm)
208 else:
209 self.op.run_on_npu = False
210 return self.op
211
    def get_graph_8bit(self, ifm, ofm):
        """Build the uint8/int8 softmax sub-graph.

        Pipeline: depthwise max over the channel axis, exp of the max-shifted
        input via a runtime-generated LUT, sum of exponentials, then a
        Newton-Raphson approximation of the reciprocal of that sum, and a
        final multiply/shift back to the output quantisation.

        Returns the final op of the sub-graph; its output tensor is ofm.
        """
        exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        one_scale_quant = ifm.quantization.clone()
        one_scale_quant.scale_f32 = 1.0
        one_scale_quant.zero_point = 0
        ifm.quantization.zero_point = 0

        # PASS 0 - Depthwise Maxpool
        # The IFM is reshaped to [1, H*W, C, 1] with filter_width == C, so the
        # pool takes the max across the channel axis for every spatial position.
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        ifm_max.quantization = no_scale_quant
        maxpool_op.set_output_tensor(ifm_max)

        # PASS 1 - Sub+LUT(exp)
        # Subtract the per-position max (so all inputs to exp are <= 0) and map
        # the result through the 256-entry LUT from generate_exp_table.
        sub_op = Operation("SubAct", self.op.name + "_sub1")
        sub_op.add_input_tensor(ifm)
        sub_op.add_input_tensor(ifm_max)
        sub_op.set_activation_lut(
            create_const_tensor(
                sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
            )
        )
        ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
        ifm_exp.quantization = one_scale_quant.clone()
        ifm_exp.quantization.zero_point = 127
        ifm_exp.quantization.quant_min = -128
        ifm_exp.quantization.quant_max = 127
        sub_op.set_output_tensor(ifm_exp)

        # PASS 2 - SHR
        # Drop 12 fractional bits before summing - presumably to keep the
        # 32-bit ReduceSum accumulator from overflowing (TODO confirm headroom
        # math; the 12 reappears in the headroom_offset_const of PASS 5).
        shr2_op = Operation("SHR", self.op.name + "_shr2")
        shr2_op.add_input_tensor(ifm_exp)
        shr2_op.add_input_tensor(
            create_const_tensor(
                shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
            ),
        )
        rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
        rescaled_exp.quantization = no_scale_quant
        shr2_op.set_output_tensor(rescaled_exp)

        # PASS 3 - Reduce sum (sum of exponentials per spatial position)
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(rescaled_exp)

        reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 4 - CLZ (count leading zeros = headroom of the sum, plus one)
        clz_op = Operation("CLZ", self.op.name + "_clz4")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 5 - Sub
        # right_shift = (12 + 31 - 8) - headroom_plus_one; used by the final
        # SHR (PASS 30) to bring the result back into 8-bit range.
        sub5_op = Operation("SubAct", self.op.name + "_sub5")
        sub5_op.add_input_tensor(
            create_const_tensor(
                "headroom_offset_const",
                [1, 1, 1, 1],
                DataType.int32,
                [12 + 31 - 8],
                np.int32,
                quantization=no_scale_quant,
            ),
        )
        sub5_op.add_input_tensor(headroom_plus_one)
        right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
        right_shift.quantization = no_scale_quant
        sub5_op.set_output_tensor(right_shift)

        # PASS 6 - Sub (headroom = headroom_plus_one - 1)
        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(headroom_plus_one)
        sub6_op.add_input_tensor(one)
        headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        headroom.quantization = no_scale_quant
        sub6_op.set_output_tensor(headroom)

        # PASS 7 - SHL (normalise the sum so its MSB is in a fixed position)
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(sum_of_exp)
        shl7_op.add_input_tensor(headroom)
        shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        shifted_sum.quantization = no_scale_quant
        shl7_op.set_output_tensor(shifted_sum)

        # PASS 8 - Sub (remove the implicit leading one: subtract 1 << 30)
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(shifted_sum)
        sub8_op.add_input_tensor(
            create_const_tensor(
                "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 9 - SHL (rebinds shifted_sum_minus_one to the doubled value)
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(shifted_sum_minus_one)
        shl9_op.add_input_tensor(one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - Add
        # half_denominator = shifted_sum_minus_one + (2^31 - 1); the [1, 1]
        # "rescale" attr presumably keeps the add unscaled - TODO confirm.
        add10_op = Operation("AddAct", self.op.name + "_add10")
        add10_op.add_input_tensor(
            create_const_tensor(
                "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
            ),
        )
        add10_op.add_input_tensor(shifted_sum_minus_one)
        add10_op.attrs["rescale"] = [1, 1]
        half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
        half_denominator.quantization = one_scale_quant
        add10_op.set_output_tensor(half_denominator)

        # PASS 11 - Multiply (half_denominator * -32/17 in Q-format; the
        # constant -1010580540 is the fixed-point encoding of -32/17)
        mul11_op = Operation("MulAct", self.op.name + "_mul11")
        mul11_op.add_input_tensor(half_denominator)
        mul11_op.add_input_tensor(
            create_const_tensor(
                "neg_32_over_17_const",
                [1, 1, 1, 1],
                DataType.int32,
                [-1010580540],
                np.int32,
                quantization=one_scale_quant,
            ),
        )
        rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
        rescaled.quantization = one_scale_quant.clone()
        rescaled.quantization.scale_f32 = 2.0
        mul11_op.set_output_tensor(rescaled)

        # PASS 12 - Add (+ 48/17, fixed-point 1515870810): together with PASS
        # 11 this forms the standard initial estimate 48/17 - 32/17 * D for
        # Newton-Raphson reciprocal iteration.
        add12_op = Operation("AddAct", self.op.name + "_add12")
        add12_op.add_input_tensor(rescaled)
        add12_op.add_input_tensor(
            create_const_tensor(
                "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
            ),
        )
        rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
        rescale_w_offset.quantization = one_scale_quant
        add12_op.set_output_tensor(rescale_w_offset)

        nr_x = rescale_w_offset
        F2_one = create_const_tensor(
            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
        )
        two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
        # Three Newton-Raphson iterations refining nr_x towards the reciprocal
        # of the (half) denominator: x' = x + x * (1 - D*x) (scaled variant).
        for i in range(3):
            # PASS 13, 18, 23 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(half_denominator)
            # NOTE(review): this tensor (and to_rescale below) is declared with
            # ifm_exp.shape while the other loop tensors use reduce_sum_shape -
            # verify that this shape mismatch is intentional.
            half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            half_denominator_times_x.quantization = one_scale_quant.clone()
            half_denominator_times_x.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(half_denominator_times_x)
            # PASS 14, 19, 24 - SUB
            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
            sub_op.add_input_tensor(F2_one)
            sub_op.add_input_tensor(half_denominator_times_x)
            one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
            one_minus_half_denominator_times_x.quantization = one_scale_quant
            sub_op.set_output_tensor(one_minus_half_denominator_times_x)
            # PASS 15, 20, 25 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(one_minus_half_denominator_times_x)
            to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            to_rescale.quantization = one_scale_quant.clone()
            to_rescale.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(to_rescale)
            # PASS 16, 21, 26 - SHL (scale the correction term up by 4)
            shl_op = Operation("SHL", self.op.name + "_shl%d" % (16 + i * 5))
            shl_op.add_input_tensor(to_rescale)
            shl_op.add_input_tensor(two)
            to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
            to_add.quantization = no_scale_quant
            shl_op.set_output_tensor(to_add)
            # PASS 17, 22, 27 - ADD (apply the correction to nr_x)
            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
            add_op.add_input_tensor(nr_x)
            add_op.add_input_tensor(to_add)
            nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
            nr_x.quantization = one_scale_quant
            add_op.set_output_tensor(nr_x)

        # PASS 28 - SHL (scale_factor = 2 * nr_x)
        shl28_op = Operation("SHL", self.op.name + "_shl28")
        shl28_op.add_input_tensor(nr_x)
        shl28_op.add_input_tensor(one)
        scale_factor = Tensor(reduce_sum_shape, DataType.int32, shl28_op.name + "_0")
        scale_factor.quantization = one_scale_quant
        shl28_op.set_output_tensor(scale_factor)

        # PASS 29 - Multiply (exp(x) * 1/sum(exp))
        mul_op = Operation("MulAct", self.op.name + "_mul29")
        mul_op.add_input_tensor(ifm_exp)
        mul_op.add_input_tensor(scale_factor)
        scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
        scaled_exp.quantization = one_scale_quant.clone()
        scaled_exp.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(scaled_exp)

        # PASS 30 - SHR (shift down by right_shift from PASS 5 into ofm)
        shr30_op = Operation("SHR", self.op.name + "_shr30")
        shr30_op.add_input_tensor(scaled_exp)
        shr30_op.add_input_tensor(right_shift)
        shr30_op.set_output_tensor(ofm)

        return shr30_op
458
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200459 def get_graph_int16(self, ifm, ofm):
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100460 ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
461 ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200462 no_scale_quant = ifm.quantization.clone()
463 no_scale_quant.scale_f32 = None
464
465 # PASS 0 - Depthwise Maxpool
466 maxpool_op = self.op.clone("_maxpool0")
467 maxpool_op.type = "MaxPool"
468 maxpool_h = ifm.shape[1] * ifm.shape[2]
469 maxpool_w = ifm.shape[3]
470 maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
471 maxpool_op.attrs["padding"] = b"VALID"
472 maxpool_op.attrs["stride_w"] = 1
473 maxpool_op.attrs["stride_h"] = 1
474 maxpool_op.attrs["filter_width"] = maxpool_w
475 maxpool_op.attrs["filter_height"] = 1
476 maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
477 maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100478 maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200479 maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200480 maxpool_ofm.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100481 maxpool_op.set_output_tensor(maxpool_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200482
483 # PASS 1 - Sub
484 sub1_op = Operation("SubAct", self.op.name + "_sub1")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100485 sub1_op.add_input_tensor(ifm)
486 sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200487 sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
488 sub1_ofm.quantization = ifm.quantization.clone()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100489 sub1_op.set_output_tensor(sub1_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200490
491 # PASS 2 - Mul
492 beta = self.op.attrs.get("beta", 1.0)
493 mul2_out_range = 10.0 / 65535.0
494 mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
495 mul2_quant = ifm.quantization.clone()
496 mul2_quant.scale_f32 = beta
497 mul2_op = Operation("MulAct", self.op.name + "_mul2")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100498 mul2_op.add_input_tensor(sub1_ofm)
499 mul2_op.add_input_tensor(
500 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200501 mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200502 ),
503 )
504 mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
505 mul2_ofm.quantization = ofm.quantization.clone()
506 mul2_ofm.quantization.scale_f32 = mul2_out_range
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100507 mul2_op.set_output_tensor(mul2_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200508
509 # PASS 3 - Add+LUT(exp)
510 add_op = Operation("AddAct", self.op.name + "_add3")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100511 add_op.add_input_tensor(mul2_ofm)
512 add_op.add_input_tensor(
513 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200514 add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200515 ),
516 )
517 add_op.set_activation_lut(
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100518 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200519 add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200520 )
521 )
522 exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
523 exp_ofm.quantization = mul2_ofm.quantization.clone()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100524 add_op.set_output_tensor(exp_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200525
526 # PASS 4 - Reduce sum
527 reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
528 reduce_sum_op.attrs["padding"] = b"VALID"
529 reduce_sum_op.attrs["stride_w"] = 1
530 reduce_sum_op.attrs["stride_h"] = 1
531 reduce_sum_op.attrs["filter_width"] = 1
532 reduce_sum_op.attrs["filter_height"] = 1
533 reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
534 reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100535 reduce_sum_op.add_input_tensor(exp_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200536
537 reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
538 sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
539 sum_of_exp.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100540 reduce_sum_op.set_output_tensor(sum_of_exp)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200541
542 # PASS 5 - CLZ
543 clz_op = Operation("CLZ", self.op.name + "_clz5")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100544 clz_op.add_input_tensor(sum_of_exp)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200545 headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
546 headroom_plus_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100547 clz_op.set_output_tensor(headroom_plus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200548
549 # PASS 6 - Sub
550 sub6_op = Operation("SubAct", self.op.name + "_sub6")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100551 sub6_op.add_input_tensor(
552 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200553 sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200554 ),
555 )
Jacob Bohlinbe733cf2020-08-13 10:21:34 +0200556 sub6_op.add_input_tensor(headroom_plus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200557 reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
558 reciprocal_right_shift.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100559 sub6_op.set_output_tensor(reciprocal_right_shift)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200560
561 # PASS 7 - SHL
562 shl7_op = Operation("SHL", self.op.name + "_shl7")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100563 shl7_op.add_input_tensor(
564 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200565 shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200566 ),
567 )
Jacob Bohlinbe733cf2020-08-13 10:21:34 +0200568 shl7_op.add_input_tensor(reciprocal_right_shift)
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200569 constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200570 constant_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100571 shl7_op.set_output_tensor(constant_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200572
573 # PASS 8 - Sub
574 sub8_op = Operation("SubAct", self.op.name + "_sub8")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100575 sub8_op.add_input_tensor(sum_of_exp)
576 sub8_op.add_input_tensor(constant_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200577 sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
578 sum_of_exps_minus_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100579 sub8_op.set_output_tensor(sum_of_exps_minus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200580
581 # PASS 9 - SHL
582 shl9_op = Operation("SHL", self.op.name + "_shl9")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100583 shl9_op.add_input_tensor(sum_of_exps_minus_one)
584 shl9_op.add_input_tensor(headroom_plus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200585 shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
586 shifted_sum_minus_one.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100587 shl9_op.set_output_tensor(shifted_sum_minus_one)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200588
589 # PASS 10 - SHR
590 shr10_op = Operation("SHR", self.op.name + "_shr10")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100591 shr10_op.add_input_tensor(shifted_sum_minus_one)
592 shr10_op.add_input_tensor(
593 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200594 shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200595 ),
596 )
597 shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
598 shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100599 shr10_op.set_output_tensor(shifted_sum_minus_one_16)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200600
601 # PASS 11 - Sub+LUT(one over one plus x)
602 sub11_op = Operation("SubAct", self.op.name + "_sub11")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100603 sub11_op.add_input_tensor(shifted_sum_minus_one_16)
604 sub11_op.add_input_tensor(
605 create_const_tensor(
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200606 sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200607 ),
608 )
609 sub11_op.set_activation_lut(
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100610 create_const_tensor(
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200611 sub11_op.name + "_lut",
612 [1, 1, 1, 512],
613 DataType.int32,
614 self.ONE_OVER_ONE_PLUS_X_LUT,
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200615 np.int32,
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200616 TensorPurpose.LUT,
617 )
618 )
619 reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
620 reciprocal_scale.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100621 sub11_op.set_output_tensor(reciprocal_scale)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200622
623 # PASS 12 - Multiply
624 mul_op = Operation("MulAct", self.op.name + "_mul12")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100625 mul_op.add_input_tensor(exp_ofm)
626 mul_op.add_input_tensor(reciprocal_scale)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200627 mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
628 mul_ofm.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100629 mul_op.set_output_tensor(mul_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200630
631 # PASS 13 - SHR
632 shr13_op = Operation("SHR", self.op.name + "_shr13")
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100633 shr13_op.add_input_tensor(mul_ofm)
634 shr13_op.add_input_tensor(reciprocal_right_shift)
635 shr13_op.set_output_tensor(ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200636
637 return shr13_op