blob: 2834f8c253c4387f4573cb98b3a392cb96aa1e95 [file] [log] [blame]
Fredrik Svedberga0c36242020-06-03 15:43:31 +02001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
Fredrik Svedberg1575b942020-08-18 13:19:18 +02003# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4#
Fredrik Svedberga0c36242020-06-03 15:43:31 +02005# SPDX-License-Identifier: Apache-2.0
6#
Fredrik Svedberg1575b942020-08-18 13:19:18 +02007# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
Fredrik Svedberga0c36242020-06-03 15:43:31 +02009# You may obtain a copy of the License at
10#
Fredrik Svedberg1575b942020-08-18 13:19:18 +020011# http://www.apache.org/licenses/LICENSE-2.0
Fredrik Svedberga0c36242020-06-03 15:43:31 +020012#
13# Unless required by applicable law or agreed to in writing, software
Fredrik Svedberg1575b942020-08-18 13:19:18 +020014# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Fredrik Svedberga0c36242020-06-03 15:43:31 +020016# See the License for the specific language governing permissions and
17# limitations under the License.
Fredrik Svedberg1575b942020-08-18 13:19:18 +020018#
Fredrik Svedberga0c36242020-06-03 15:43:31 +020019# Description:
20# Contains SoftMax
Fredrik Svedberg1575b942020-08-18 13:19:18 +020021import math
22
Fredrik Svedberga0c36242020-06-03 15:43:31 +020023import numpy as np
24
Fredrik Svedberg1575b942020-08-18 13:19:18 +020025from . import fp_math
Fredrik Svedberga0c36242020-06-03 15:43:31 +020026from . import scaling
27from .data_type import DataType
28from .operation import Operation
Michael McGeagh5778ffd2020-08-06 17:31:02 +010029from .tensor import create_const_tensor
30from .tensor import create_reshape_tensor
Fredrik Svedberga0c36242020-06-03 15:43:31 +020031from .tensor import Tensor
32from .tensor import TensorPurpose
33
34
class SoftMax:
    """Lowers a SoftMax operation into a graph of primitive NPU operations.

    get_graph() dispatches on the input/output dtype: 8-bit and int16
    inputs get dedicated lowerings; any other combination is left to run
    on the CPU.
    """

    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

    # 512-entry int32 exp table used by the int16 softmax graph; it is
    # uploaded as a [1, 1, 1, 512] LUT tensor in get_graph_int16.
    # NOTE(review): each word appears to pack two 16-bit half-words
    # (low half: table value, high half: delta to the next entry) -
    # TODO confirm against the NPU LUT hardware format.
    EXP_LUT = [
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
    ]

    # 512-entry int32 table; not referenced in the visible part of this
    # file. NOTE(review): presumably a 1/(1+x) LUT for the reciprocal
    # step of the int16 softmax graph, packed like EXP_LUT - verify
    # against the code that consumes it.
    ONE_OVER_ONE_PLUS_X_LUT = [
        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
    ]
    # fmt: on
173
    def __init__(self, op):
        """Wrap *op*, the SoftMax operation that get_graph() will lower."""
        self.op = op
176
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200177 def generate_exp_table(self, beta, input_scale):
Fredrik Svedberg1575b942020-08-18 13:19:18 +0200178 integer_bits = 5
179 total_signed_bits = 31
180 # Calculate scaling
181 real_beta = min(
182 np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), np.double((1 << 31) - 1.0)
183 )
184 scale, shift = scaling.quantise_scale(real_beta)
185 shift = 31 - shift
186 diff_min = -1.0 * math.floor(
187 1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits)) / (1 << shift)
188 )
189 # Generate the exp LUT
190 lut = []
191 for x in range(256):
192 input_diff = x - 255
193 if input_diff >= diff_min:
194 rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
195 lut.append(fp_math.exp_on_negative_values(rescale))
196 else:
197 lut.append(0)
198 return lut
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200199
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200200 def get_graph(self):
201 ifm = self.op.inputs[0]
202 ofm = self.op.outputs[0]
203
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200204 if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
205 return self.get_graph_8bit(ifm, ofm)
206 elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200207 return self.get_graph_int16(ifm, ofm)
208 else:
209 self.op.run_on_npu = False
210 return self.op
211
    def get_graph_8bit(self, ifm, ofm):
        """Lower an 8-bit (int8/uint8) SoftMax into a sub-graph of NPU ops.

        The classic quantised softmax recipe: subtract the row max, look up
        exp via a LUT, sum the exps, approximate 1/sum with Newton-Raphson
        in 32-bit fixed point, multiply and shift back down to 8 bits.
        Returns the final op of the generated graph (its output is *ofm*).
        """
        exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
        # Work on full 4D shapes throughout.
        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        one_scale_quant = ifm.quantization.clone()
        one_scale_quant.scale_f32 = 1.0
        one_scale_quant.zero_point = 0
        ifm.quantization.zero_point = 0

        # PASS 0 - Depthwise Maxpool
        # Reshape to [1, H*W, C, 1] and maxpool each row of C elements,
        # producing the per-row maximum.
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        ifm_max.quantization = no_scale_quant
        maxpool_op.set_output_tensor(ifm_max)

        # PASS 1 - Sub+LUT(exp)
        # (x - max) is <= 0; the fused LUT maps it to exp(beta * scale * diff).
        sub_op = Operation("SubAct", self.op.name + "_sub1")
        sub_op.add_input_tensor(ifm)
        sub_op.add_input_tensor(ifm_max)
        sub_op.set_activation_lut(
            create_const_tensor(
                sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
            )
        )
        ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
        ifm_exp.quantization = one_scale_quant.clone()
        ifm_exp.quantization.zero_point = 127
        ifm_exp.quantization.quant_min = -128
        ifm_exp.quantization.quant_max = 127
        sub_op.set_output_tensor(ifm_exp)

        # PASS 2 - SHR
        # Shift the exp values down by 12 bits (with rounding) so the row sum
        # below cannot overflow int32.
        shr2_op = Operation("SHR", self.op.name + "_shr2")
        shr2_op.attrs["rounding_mode"] = b"NATURAL"
        shr2_op.add_input_tensor(ifm_exp)
        shr2_op.add_input_tensor(
            create_const_tensor(
                shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
            ),
        )
        rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
        rescaled_exp.quantization = no_scale_quant
        shr2_op.set_output_tensor(rescaled_exp)

        # PASS 3 - Reduce sum (pooling-style op summing along the channel axis)
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(rescaled_exp)

        reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 4 - CLZ: leading-zero count of the sum, i.e. headroom + 1
        clz_op = Operation("CLZ", self.op.name + "_clz4")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 5 - Sub
        # right_shift = (12 + 31 - 8) - headroom_plus_one: the final shift that
        # brings the scaled result back into 8 bits. Note the constant is the
        # FIRST input, so the op computes const - headroom_plus_one.
        sub5_op = Operation("SubAct", self.op.name + "_sub5")
        sub5_op.add_input_tensor(
            create_const_tensor(
                "headroom_offset_const",
                [1, 1, 1, 1],
                DataType.int32,
                [12 + 31 - 8],
                np.int32,
                quantization=no_scale_quant,
            ),
        )
        sub5_op.add_input_tensor(headroom_plus_one)
        right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
        right_shift.quantization = no_scale_quant
        sub5_op.set_output_tensor(right_shift)

        # PASS 6 - Sub: headroom = headroom_plus_one - 1
        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(headroom_plus_one)
        sub6_op.add_input_tensor(one)
        headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        headroom.quantization = no_scale_quant
        sub6_op.set_output_tensor(headroom)

        # PASS 7 - SHL: normalise the sum so its leading one sits in bit 30
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(sum_of_exp)
        shl7_op.add_input_tensor(headroom)
        shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        shifted_sum.quantization = no_scale_quant
        shl7_op.set_output_tensor(shifted_sum)

        # PASS 8 - Sub: strip the implicit leading one (1 << 30)
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(shifted_sum)
        sub8_op.add_input_tensor(
            create_const_tensor(
                "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 9 - SHL: double the fractional part (shift left by one)
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(shifted_sum_minus_one)
        shl9_op.add_input_tensor(one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - Add: half_denominator in [0.5, 1) as a Q0.31 value
        add10_op = Operation("AddAct", self.op.name + "_add10")
        add10_op.add_input_tensor(
            create_const_tensor(
                "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
            ),
        )
        add10_op.add_input_tensor(shifted_sum_minus_one)
        add10_op.attrs["rescale"] = [1, 1]
        half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
        half_denominator.quantization = one_scale_quant
        add10_op.set_output_tensor(half_denominator)

        # PASS 11 - Multiply
        # Newton-Raphson seed: the constants below match gemmlowp's
        # one_over_one_plus_x_for_x_in_0_1 (-32/17 and 48/17 in fixed point).
        mul11_op = Operation("MulAct", self.op.name + "_mul11")
        mul11_op.add_input_tensor(half_denominator)
        mul11_op.add_input_tensor(
            create_const_tensor(
                "neg_32_over_17_const",
                [1, 1, 1, 1],
                DataType.int32,
                [-1010580540],
                np.int32,
                quantization=one_scale_quant,
            ),
        )
        rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
        rescaled.quantization = one_scale_quant.clone()
        rescaled.quantization.scale_f32 = 2.0
        mul11_op.set_output_tensor(rescaled)

        # PASS 12 - Add: nr_x = 48/17 - (32/17) * half_denominator
        add12_op = Operation("AddAct", self.op.name + "_add12")
        add12_op.add_input_tensor(rescaled)
        add12_op.add_input_tensor(
            create_const_tensor(
                "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
            ),
        )
        rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
        rescale_w_offset.quantization = one_scale_quant
        add12_op.set_output_tensor(rescale_w_offset)

        nr_x = rescale_w_offset
        F2_one = create_const_tensor(
            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
        )
        four = create_const_tensor(
            "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
        )
        # Three Newton-Raphson refinement steps:
        # nr_x += 4 * nr_x * (1 - half_denominator * nr_x)
        for i in range(3):
            # PASS 13, 18, 23 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(half_denominator)
            half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            half_denominator_times_x.quantization = one_scale_quant.clone()
            half_denominator_times_x.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(half_denominator_times_x)
            # PASS 14, 19, 24 - SUB
            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
            sub_op.add_input_tensor(F2_one)
            sub_op.add_input_tensor(half_denominator_times_x)
            one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
            one_minus_half_denominator_times_x.quantization = one_scale_quant
            sub_op.set_output_tensor(one_minus_half_denominator_times_x)
            # PASS 15, 20, 25 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(one_minus_half_denominator_times_x)
            to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            to_rescale.quantization = one_scale_quant.clone()
            to_rescale.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(to_rescale)
            # PASS 16, 21, 26 - MUL (multiply by 4 in place of a left shift)
            shl_op = Operation("MulAct", self.op.name + "_mul%d" % (16 + i * 5))
            shl_op.add_input_tensor(to_rescale)
            shl_op.add_input_tensor(four)
            to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
            to_add.quantization = no_scale_quant
            shl_op.set_output_tensor(to_add)
            # PASS 17, 22, 27 - ADD
            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
            add_op.add_input_tensor(nr_x)
            add_op.add_input_tensor(to_add)
            nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
            nr_x.quantization = one_scale_quant
            add_op.set_output_tensor(nr_x)

        # PASS 28 - Multiply: scale_factor = 2 * nr_x
        mul28_op = Operation("MulAct", self.op.name + "_mul28")
        mul28_op.add_input_tensor(nr_x)
        mul28_op.add_input_tensor(
            create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
        )
        scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0")
        scale_factor.quantization = one_scale_quant
        mul28_op.set_output_tensor(scale_factor)

        # PASS 29 - Multiply: exp(x - max) * (1 / sum_of_exp)
        mul_op = Operation("MulAct", self.op.name + "_mul29")
        mul_op.add_input_tensor(ifm_exp)
        mul_op.add_input_tensor(scale_factor)
        scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
        scaled_exp.quantization = one_scale_quant.clone()
        scaled_exp.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(scaled_exp)

        # PASS 30 - SHR: shift back down to the 8-bit output range
        shr30_op = Operation("SHR", self.op.name + "_shr30")
        shr30_op.attrs["rounding_mode"] = b"NATURAL"
        shr30_op.add_input_tensor(scaled_exp)
        shr30_op.add_input_tensor(right_shift)
        shr30_op.set_output_tensor(ofm)

        return shr30_op
464
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200465 def get_graph_int16(self, ifm, ofm):
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100466 ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
467 ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200468 no_scale_quant = ifm.quantization.clone()
469 no_scale_quant.scale_f32 = None
470
471 # PASS 0 - Depthwise Maxpool
472 maxpool_op = self.op.clone("_maxpool0")
473 maxpool_op.type = "MaxPool"
474 maxpool_h = ifm.shape[1] * ifm.shape[2]
475 maxpool_w = ifm.shape[3]
476 maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
477 maxpool_op.attrs["padding"] = b"VALID"
478 maxpool_op.attrs["stride_w"] = 1
479 maxpool_op.attrs["stride_h"] = 1
480 maxpool_op.attrs["filter_width"] = maxpool_w
481 maxpool_op.attrs["filter_height"] = 1
482 maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
483 maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100484 maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200485 maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200486 maxpool_ofm.quantization = no_scale_quant
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100487 maxpool_op.set_output_tensor(maxpool_ofm)
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200488
489 # PASS 1 - Sub
490 sub1_op = Operation("SubAct", self.op.name + "_sub1")
        # (continuation of PASS 1 - Sub) subtract the per-position max from the input:
        # maxpool_ofm presumably holds the channel-wise max from a preceding maxpool
        # pass — TODO confirm against the code above this window.
        sub1_op.add_input_tensor(ifm)
        # Reshape the max to [1, H, W, 1] so it broadcasts across the depth axis.
        sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
        sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
        sub1_ofm.quantization = ifm.quantization.clone()
        sub1_op.set_output_tensor(sub1_ofm)

        # PASS 2 - Mul: scale (x - max) by beta into a fixed output range.
        beta = self.op.attrs.get("beta", 1.0)
        # NOTE(review): 10.0/65535.0 appears to match the TFLite reference quantized
        # softmax input range — confirm against the reference kernel.
        mul2_out_range = 10.0 / 65535.0
        mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
        mul2_quant = ifm.quantization.clone()
        mul2_quant.scale_f32 = beta
        mul2_op = Operation("MulAct", self.op.name + "_mul2")
        mul2_op.add_input_tensor(sub1_ofm)
        mul2_op.add_input_tensor(
            create_const_tensor(
                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
            ),
        )
        mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
        mul2_ofm.quantization = ofm.quantization.clone()
        mul2_ofm.quantization.scale_f32 = mul2_out_range
        mul2_op.set_output_tensor(mul2_ofm)

        # PASS 3 - Add+LUT(exp): bias by 32767 (top of the int16 range) and apply the
        # 512-entry exponential lookup table as the op's activation.
        add_op = Operation("AddAct", self.op.name + "_add3")
        add_op.add_input_tensor(mul2_ofm)
        add_op.add_input_tensor(
            create_const_tensor(
                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
            ),
        )
        add_op.set_activation_lut(
            create_const_tensor(
                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
            )
        )
        exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
        exp_ofm.quantization = mul2_ofm.quantization.clone()
        add_op.set_output_tensor(exp_ofm)

        # PASS 4 - Reduce sum: sum exp(x) over the depth axis. The 1x1/stride-1
        # pooling-style attrs configure ReduceSum to reduce only across channels,
        # producing a [1, H, W, 1] tensor.
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(exp_ofm)

        reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 5 - CLZ: count leading zeros of the sum, i.e. its normalization
        # headroom. Used below to normalize the sum before the reciprocal LUT.
        clz_op = Operation("CLZ", self.op.name + "_clz5")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 6 - Sub: 31 - clz(sum) = bit index of the sum's MSB; this is also the
        # right-shift applied to the final multiply result in PASS 13.
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(
            create_const_tensor(
                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
            ),
        )
        sub6_op.add_input_tensor(headroom_plus_one)
        reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        reciprocal_right_shift.quantization = no_scale_quant
        sub6_op.set_output_tensor(reciprocal_right_shift)

        # PASS 7 - SHL: 1 << (31 - clz(sum)) — a mask with only the sum's MSB set.
        # NOTE(review): the name "constant_one" is misleading; the value is the MSB
        # mask of sum_of_exp, not the constant 1.
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(
            create_const_tensor(
                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
            ),
        )
        shl7_op.add_input_tensor(reciprocal_right_shift)
        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        constant_one.quantization = no_scale_quant
        shl7_op.set_output_tensor(constant_one)

        # PASS 8 - Sub: strip the MSB from the sum (sum - 2^msb_index).
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(sum_of_exp)
        sub8_op.add_input_tensor(constant_one)
        sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        sum_of_exps_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(sum_of_exps_minus_one)

        # PASS 9 - SHL: shift left by clz(sum) so the stripped sum is normalized to
        # the full 32-bit range — equivalent to (sum << clz) - 2^31.
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(sum_of_exps_minus_one)
        shl9_op.add_input_tensor(headroom_plus_one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - SHR: narrow the normalized value by 15 bits down to a 16-bit
        # magnitude for the LUT-driven pass below.
        shr10_op = Operation("SHR", self.op.name + "_shr10")
        shr10_op.add_input_tensor(shifted_sum_minus_one)
        shr10_op.add_input_tensor(
            create_const_tensor(
                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
        shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
        shr10_op.set_output_tensor(shifted_sum_minus_one_16)

        # PASS 11 - Sub+LUT(one over one plus x): recenter by subtracting 32768 and
        # look up the reciprocal 1/(1+x) of the normalized fractional sum via the
        # 512-entry table, yielding the reciprocal scale as int16.
        sub11_op = Operation("SubAct", self.op.name + "_sub11")
        sub11_op.add_input_tensor(shifted_sum_minus_one_16)
        sub11_op.add_input_tensor(
            create_const_tensor(
                sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
            ),
        )
        sub11_op.set_activation_lut(
            create_const_tensor(
                sub11_op.name + "_lut",
                [1, 1, 1, 512],
                DataType.int32,
                self.ONE_OVER_ONE_PLUS_X_LUT,
                np.int32,
                TensorPurpose.LUT,
            )
        )
        reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
        reciprocal_scale.quantization = no_scale_quant
        sub11_op.set_output_tensor(reciprocal_scale)

        # PASS 12 - Multiply: exp(x) * (1 / sum), still in int32; the [1, H, W, 1]
        # reciprocal broadcasts over the depth axis of exp_ofm.
        mul_op = Operation("MulAct", self.op.name + "_mul12")
        mul_op.add_input_tensor(exp_ofm)
        mul_op.add_input_tensor(reciprocal_scale)
        mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
        mul_ofm.quantization = no_scale_quant
        mul_op.set_output_tensor(mul_ofm)

        # PASS 13 - SHR: undo the normalization (shift right by the MSB index from
        # PASS 6) and write directly into the softmax op's original OFM.
        shr13_op = Operation("SHR", self.op.name + "_shr13")
        shr13_op.add_input_tensor(mul_ofm)
        shr13_op.add_input_tensor(reciprocal_right_shift)
        shr13_op.set_output_tensor(ofm)

        # Return the final op so the caller can splice this subgraph in place of
        # the original softmax.
        return shr13_op