blob: efd91a3510e0e1db4488e03977b807e9cfe63b8e [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains SoftMax
Fredrik Svedberg1575b942020-08-18 13:19:18 +020021import math
22
Fredrik Svedberga0c36242020-06-03 15:43:31 +020023import numpy as np
24
Fredrik Svedberg1575b942020-08-18 13:19:18 +020025from . import fp_math
Fredrik Svedberga0c36242020-06-03 15:43:31 +020026from . import scaling
27from .data_type import DataType
Tim Halle6ccd872020-11-09 16:46:37 +000028from .debug_database import DebugDatabase
Louis Verhaardaee5d752020-09-30 09:01:52 +020029from .operation import Op
Fredrik Svedberga0c36242020-06-03 15:43:31 +020030from .operation import Operation
Michael McGeagh5778ffd2020-08-06 17:31:02 +010031from .tensor import create_const_tensor
32from .tensor import create_reshape_tensor
Fredrik Svedberga0c36242020-06-03 15:43:31 +020033from .tensor import Tensor
34from .tensor import TensorPurpose
35
36
Fredrik Svedberga0c36242020-06-03 15:43:31 +020037class SoftMax:
38 # Turn off black formatting for the LUT tables to keep them compact
39 # fmt: off
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +020040
Fredrik Svedberga0c36242020-06-03 15:43:31 +020041 EXP_LUT = [
42 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
43 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
44 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
45 0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
46 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
47 0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
48 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
49 0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
50 0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
51 0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
52 0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
53 0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
54 0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
55 0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
56 0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
57 0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
58 0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
59 0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
60 0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
61 0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
62 0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
63 0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
64 0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
65 0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
66 0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
67 0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
68 0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
69 0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
70 0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
71 0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
72 0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
73 0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
74 0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
75 0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
76 0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
77 0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
78 0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
79 0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
80 0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
81 0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
82 0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
83 0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
84 0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
85 0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
86 0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
87 0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
88 0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
89 0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
90 0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
91 0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
92 0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
93 0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
94 0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
95 0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
96 0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
97 0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
98 0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
99 0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
100 0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
101 0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
102 0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
103 0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
104 0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
105 0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
106 ]
107
108 ONE_OVER_ONE_PLUS_X_LUT = [
109 0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
110 0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
111 0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
112 0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
113 0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
114 0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
115 0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
116 0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
117 0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
118 0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
119 0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
120 0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
121 0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
122 0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
123 0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
124 0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
125 0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
126 0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
127 0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
128 0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
129 0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
130 0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
131 0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
132 0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
133 0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
134 0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
135 0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
136 0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
137 0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
138 0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
139 0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
140 0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
141 0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
142 0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
143 0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
144 0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
145 0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
146 0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
147 0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
148 0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
149 0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
150 0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
151 0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
152 0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
153 0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
154 0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
155 0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
156 0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
157 0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
158 0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
159 0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
160 0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
161 0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
162 0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
163 0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
164 0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
165 0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
166 0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
167 0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
168 0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
169 0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
170 0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
171 0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
172 0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
173 ]
174 # fmt: on
175
176 def __init__(self, op):
177 self.op = op
178
def generate_exp_table(self, beta, input_scale):
    """Build the 256-entry exp lookup table used by the 8-bit softmax graph.

    Entry ``x`` of the table corresponds to the input difference
    ``x - 255`` (so the differences run from -255 to 0). Differences
    smaller than ``diff_min`` map to 0; all others are rescaled with
    ``fp_math.saturating_rounding_mul`` and passed through
    ``fp_math.exp_on_negative_values``.

    Args:
        beta: softmax beta attribute.
        input_scale: quantisation scale of the input feature map.

    Returns:
        A list of 256 table values.
    """
    integer_bits = 5
    total_signed_bits = 31

    # Calculate scaling: beta * input_scale in fixed point, clamped to the
    # largest positive 32-bit signed value.
    unclamped_beta = np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits))
    real_beta = min(unclamped_beta, np.double((1 << 31) - 1.0))
    scale, shift = scaling.quantise_scale(real_beta)
    shift = 31 - shift

    # Smallest input difference that still produces a non-zero table entry.
    numerator = 1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits))
    diff_min = -1.0 * math.floor(numerator / (1 << shift))

    def lut_entry(input_diff):
        # Differences below diff_min are flushed to zero.
        if not input_diff >= diff_min:
            return 0
        rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
        return fp_math.exp_on_negative_values(rescale)

    # Generate the exp LUT
    return [lut_entry(x - 255) for x in range(256)]
Fredrik Svedberg597fd3f2020-08-13 10:02:53 +0200201
def get_graph(self):
    """Select and build the softmax lowering that matches the tensor dtypes.

    The first input and first output of ``self.op`` are used as ifm/ofm.
    When the leading dimension of the ifm's full shape is greater than 1,
    it is folded into the second dimension and both tensors are replaced
    by reshaped tensors before the lowering is chosen.

    Returns:
        The result of ``get_graph_8bit`` for matching uint8/int8 tensors,
        the result of ``get_graph_int16`` for int16 tensors, otherwise the
        original operation with ``run_on_npu`` set to ``False``.
    """
    op = self.op
    ifm = op.inputs[0]
    ofm = op.outputs[0]

    # Reshape ifm/ofm (if needed)
    full_shape = ifm.get_full_shape()
    if full_shape[0] > 1:
        full_shape[1] *= full_shape[0]
        full_shape[0] = 1
        ifm = create_reshape_tensor(ifm, full_shape)
        ofm = create_reshape_tensor(ofm, full_shape, False)

    # 8-bit path: ifm and ofm must share the same 8-bit type.
    if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
        return self.get_graph_8bit(ifm, ofm)

    # 16-bit path: both tensors must be int16.
    if ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
        return self.get_graph_int16(ifm, ofm)

    # Unsupported type combination: leave the op as it is, off the NPU.
    op.run_on_npu = False
    return op
221
def get_graph_8bit(self, ifm, ofm):
    """Build the chain of operations (passes 0-30) for the 8-bit softmax.

    Each pass creates an operation, wires its inputs, creates and attaches
    its output tensor, and registers the new operation with
    ``DebugDatabase.add_optimised`` against ``self.op``. The statement
    order is significant because later passes consume tensors produced by
    earlier ones.

    Args:
        ifm: input feature map tensor (its ``quantization.zero_point`` is
            set to 0 on the reshaped tensor created here).
        ofm: output feature map tensor; becomes the output of the last pass.

    Returns:
        The final SHR operation (pass 30), whose output tensor is ``ofm``.
    """
    # The exp LUT depends on the op's "beta" attribute (default 1.0) and the ifm scale.
    exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
    ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
    DebugDatabase.add_optimised(self.op, ifm.ops[0])
    ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
    # Two quantisation templates derived from the ifm: one without a scale
    # and one with scale 1.0, both with zero point 0.
    no_scale_quant = ifm.quantization.clone()
    no_scale_quant.scale_f32 = None
    no_scale_quant.zero_point = 0
    one_scale_quant = ifm.quantization.clone()
    one_scale_quant.scale_f32 = 1.0
    one_scale_quant.zero_point = 0
    ifm.quantization.zero_point = 0

    # PASS 0 - Depthwise Maxpool
    maxpool_op = self.op.clone("_maxpool0")
    maxpool_op.type = Op.MaxPool
    # View the ifm as [1, H*W, C, 1] so a 1 x C pooling window reduces over
    # the last ifm dimension.
    maxpool_h = ifm.shape[1] * ifm.shape[2]
    maxpool_w = ifm.shape[3]
    maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
    maxpool_op.attrs["padding"] = b"VALID"
    maxpool_op.attrs["stride_w"] = 1
    maxpool_op.attrs["stride_h"] = 1
    maxpool_op.attrs["filter_width"] = maxpool_w
    maxpool_op.attrs["filter_height"] = 1
    maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
    maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
    maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
    ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
    ifm_max.quantization = no_scale_quant
    maxpool_op.set_output_tensor(ifm_max)
    DebugDatabase.add_optimised(self.op, maxpool_op)

    # PASS 1 - Sub+LUT(exp)
    sub_op = Operation(Op.Sub, self.op.name + "_sub1")
    sub_op.add_input_tensor(ifm)
    # The per-position maximum is reshaped back to [1, H, W, 1] before the subtraction.
    sub_op.add_input_tensor(create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]))
    sub_op.set_activation_lut(
        create_const_tensor(
            sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
        )
    )
    ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
    ifm_exp.quantization = one_scale_quant.clone()
    # NOTE(review): zero point 127 with range [-128, 127] presumably maps the
    # (ifm - max) difference onto the 256 LUT indices - confirm against the LUT handling.
    ifm_exp.quantization.zero_point = 127
    ifm_exp.quantization.quant_min = -128
    ifm_exp.quantization.quant_max = 127
    sub_op.set_output_tensor(ifm_exp)
    DebugDatabase.add_optimised(self.op, sub_op)

    # PASS 2 - SHR
    shr2_op = Operation(Op.SHR, self.op.name + "_shr2")
    shr2_op.attrs["rounding_mode"] = b"NATURAL"
    shr2_op.add_input_tensor(ifm_exp)
    # Shift the exp values right by the constant 12 before they are summed.
    shr2_op.add_input_tensor(
        create_const_tensor(
            shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
        ),
    )
    rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
    rescaled_exp.quantization = no_scale_quant
    shr2_op.set_output_tensor(rescaled_exp)
    DebugDatabase.add_optimised(self.op, shr2_op)

    # PASS 3 - Reduce sum
    reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum3")
    reduce_sum_op.attrs["padding"] = b"VALID"
    reduce_sum_op.attrs["stride_w"] = 1
    reduce_sum_op.attrs["stride_h"] = 1
    reduce_sum_op.attrs["filter_width"] = 1
    reduce_sum_op.attrs["filter_height"] = 1
    reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
    reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
    reduce_sum_op.add_input_tensor(rescaled_exp)

    # All tensors from here up to pass 28 share this [1, H, W, 1] shape.
    reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
    sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
    sum_of_exp.quantization = no_scale_quant
    reduce_sum_op.set_output_tensor(sum_of_exp)
    DebugDatabase.add_optimised(self.op, reduce_sum_op)

    # PASS 4 - CLZ
    clz_op = Operation(Op.CLZ, self.op.name + "_clz4")
    clz_op.add_input_tensor(sum_of_exp)
    headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
    headroom_plus_one.quantization = no_scale_quant
    clz_op.set_output_tensor(headroom_plus_one)
    DebugDatabase.add_optimised(self.op, clz_op)

    # PASS 5 - Sub
    # right_shift = (12 + 31 - 8) - headroom_plus_one; consumed by the final SHR in pass 30.
    sub5_op = Operation(Op.Sub, self.op.name + "_sub5")
    sub5_op.add_input_tensor(
        create_const_tensor(
            "headroom_offset_const",
            [1, 1, 1, 1],
            DataType.int32,
            [12 + 31 - 8],
            np.int32,
            quantization=no_scale_quant,
        ),
    )
    sub5_op.add_input_tensor(headroom_plus_one)
    right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
    right_shift.quantization = no_scale_quant
    sub5_op.set_output_tensor(right_shift)
    DebugDatabase.add_optimised(self.op, sub5_op)

    # PASS 6 - Sub
    # headroom = headroom_plus_one - 1; the "one" constant is reused in pass 9.
    one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
    sub6_op = Operation(Op.Sub, self.op.name + "_sub6")
    sub6_op.add_input_tensor(headroom_plus_one)
    sub6_op.add_input_tensor(one)
    headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
    headroom.quantization = no_scale_quant
    sub6_op.set_output_tensor(headroom)
    DebugDatabase.add_optimised(self.op, sub6_op)

    # PASS 7 - SHL
    # Shift the sum left by its headroom.
    shl7_op = Operation(Op.SHL, self.op.name + "_shl7")
    shl7_op.add_input_tensor(sum_of_exp)
    shl7_op.add_input_tensor(headroom)
    shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
    shifted_sum.quantization = no_scale_quant
    shl7_op.set_output_tensor(shifted_sum)
    DebugDatabase.add_optimised(self.op, shl7_op)

    # PASS 8 - Sub
    # Subtract the constant 1 << 30 from the shifted sum.
    sub8_op = Operation(Op.Sub, self.op.name + "_sub8")
    sub8_op.add_input_tensor(shifted_sum)
    sub8_op.add_input_tensor(
        create_const_tensor(
            "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
        ),
    )
    shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
    shifted_sum_minus_one.quantization = no_scale_quant
    sub8_op.set_output_tensor(shifted_sum_minus_one)
    DebugDatabase.add_optimised(self.op, sub8_op)

    # PASS 9 - SHL
    shl9_op = Operation(Op.SHL, self.op.name + "_shl9")
    shl9_op.add_input_tensor(shifted_sum_minus_one)
    shl9_op.add_input_tensor(one)
    # The name shifted_sum_minus_one is rebound to this pass's output tensor.
    shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
    shifted_sum_minus_one.quantization = no_scale_quant
    shl9_op.set_output_tensor(shifted_sum_minus_one)
    DebugDatabase.add_optimised(self.op, shl9_op)

    # PASS 10 - Add
    add10_op = Operation(Op.Add, self.op.name + "_add10")
    add10_op.add_input_tensor(
        create_const_tensor(
            "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
        ),
    )
    add10_op.add_input_tensor(shifted_sum_minus_one)
    add10_op.attrs["rescale"] = [1, 1]
    half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
    half_denominator.quantization = one_scale_quant
    add10_op.set_output_tensor(half_denominator)
    DebugDatabase.add_optimised(self.op, add10_op)

    # PASS 11 - Multiply
    # NOTE(review): passes 11-27 look like a Newton-Raphson reciprocal of
    # half_denominator (names nr_x, "neg_32_over_17", "48_over_17") - confirm
    # against the reference kernel this was derived from.
    mul11_op = Operation(Op.Mul, self.op.name + "_mul11")
    mul11_op.add_input_tensor(half_denominator)
    mul11_op.add_input_tensor(
        create_const_tensor(
            "neg_32_over_17_const",
            [1, 1, 1, 1],
            DataType.int32,
            [-1010580540],
            np.int32,
            quantization=one_scale_quant,
        ),
    )
    rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0")
    rescaled.quantization = one_scale_quant.clone()
    rescaled.quantization.scale_f32 = 2.0
    mul11_op.set_output_tensor(rescaled)
    DebugDatabase.add_optimised(self.op, mul11_op)

    # PASS 12 - Add
    add12_op = Operation(Op.Add, self.op.name + "_add12")
    add12_op.add_input_tensor(rescaled)
    add12_op.add_input_tensor(
        create_const_tensor(
            "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
        ),
    )
    rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
    rescale_w_offset.quantization = one_scale_quant
    add12_op.set_output_tensor(rescale_w_offset)
    DebugDatabase.add_optimised(self.op, add12_op)

    # nr_x is the running estimate refined by the three iterations below;
    # F2_one and four are constants shared by every iteration.
    nr_x = rescale_w_offset
    F2_one = create_const_tensor(
        "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
    )
    four = create_const_tensor(
        "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
    )
    for i in range(3):
        # Each iteration adds five passes, numbered 13 + i * 5 .. 17 + i * 5.
        # PASS 13, 18, 23 - MUL
        mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (13 + i * 5))
        mul_op.add_input_tensor(nr_x)
        mul_op.add_input_tensor(half_denominator)
        half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
        half_denominator_times_x.quantization = one_scale_quant.clone()
        half_denominator_times_x.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(half_denominator_times_x)
        DebugDatabase.add_optimised(self.op, mul_op)
        # PASS 14, 19, 24 - SUB
        sub_op = Operation(Op.Sub, self.op.name + "_sub%d" % (14 + i * 5))
        sub_op.add_input_tensor(F2_one)
        sub_op.add_input_tensor(half_denominator_times_x)
        one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
        one_minus_half_denominator_times_x.quantization = one_scale_quant
        sub_op.set_output_tensor(one_minus_half_denominator_times_x)
        DebugDatabase.add_optimised(self.op, sub_op)
        # PASS 15, 20, 25 - MUL
        mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (15 + i * 5))
        mul_op.add_input_tensor(nr_x)
        mul_op.add_input_tensor(one_minus_half_denominator_times_x)
        to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
        to_rescale.quantization = one_scale_quant.clone()
        to_rescale.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(to_rescale)
        DebugDatabase.add_optimised(self.op, mul_op)
        # PASS 16, 21, 26 - MUL
        # The variable is called shl_op, but the operation created is a Mul by the constant 4.
        shl_op = Operation(Op.Mul, self.op.name + "_mul%d" % (16 + i * 5))
        shl_op.add_input_tensor(to_rescale)
        shl_op.add_input_tensor(four)
        to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
        to_add.quantization = no_scale_quant
        shl_op.set_output_tensor(to_add)
        DebugDatabase.add_optimised(self.op, shl_op)
        # PASS 17, 22, 27 - ADD
        add_op = Operation(Op.Add, self.op.name + "_add%d" % (17 + i * 5))
        add_op.add_input_tensor(nr_x)
        add_op.add_input_tensor(to_add)
        # The updated estimate feeds the next iteration (and pass 28 after the loop).
        nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
        nr_x.quantization = one_scale_quant
        add_op.set_output_tensor(nr_x)
        DebugDatabase.add_optimised(self.op, add_op)

    # PASS 28 - Multiply
    # Multiply the final estimate by the constant 2 to obtain scale_factor.
    mul28_op = Operation(Op.Mul, self.op.name + "_mul28")
    mul28_op.add_input_tensor(nr_x)
    mul28_op.add_input_tensor(
        create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
    )
    scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0")
    scale_factor.quantization = one_scale_quant
    mul28_op.set_output_tensor(scale_factor)
    DebugDatabase.add_optimised(self.op, mul28_op)

    # PASS 29 - Multiply
    # Scale the (unshifted) exp values from pass 1 by scale_factor.
    mul_op = Operation(Op.Mul, self.op.name + "_mul29")
    mul_op.add_input_tensor(ifm_exp)
    mul_op.add_input_tensor(scale_factor)
    scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
    scaled_exp.quantization = one_scale_quant.clone()
    scaled_exp.quantization.scale_f32 = 2.0
    mul_op.set_output_tensor(scaled_exp)
    DebugDatabase.add_optimised(self.op, mul_op)

    # PASS 30 - SHR
    # Final shift by right_shift (from pass 5) writes directly into the ofm.
    shr30_op = Operation(Op.SHR, self.op.name + "_shr30")
    shr30_op.attrs["rounding_mode"] = b"NATURAL"
    shr30_op.add_input_tensor(scaled_exp)
    shr30_op.add_input_tensor(right_shift)
    shr30_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(self.op, shr30_op)

    return shr30_op
496
def get_graph_int16(self, ifm, ofm):
    """Build the int16 softmax as a chain of 14 primitive operations (PASS 0-13).

    Each generated operation is registered with ``DebugDatabase`` as an
    optimised form of ``self.op``. The chain is, in order: max over the last
    axis, subtract the max, scale by beta, exp via LUT, sum of the exps,
    a LUT-based reciprocal of that sum built from CLZ/shift passes, multiply
    the exps by the reciprocal and a final right shift into ``ofm``.

    Args:
        ifm: Input feature map tensor. Its ``shape`` is indexed at positions
            1, 2 and 3, and its ``quantization`` is cloned for intermediates.
        ofm: Output feature map tensor; set as the output of the last pass.

    Returns:
        The final operation of the chain (the PASS 13 SHR) that writes ``ofm``.
    """
    # Quantisation shared by every intermediate that must not be rescaled.
    no_scale_quant = ifm.quantization.clone()
    no_scale_quant.scale_f32 = None

    # PASS 0 - Depthwise Maxpool
    maxpool_op = self.op.clone("_maxpool0")
    maxpool_op.type = Op.MaxPool
    # NOTE(review): maxpool_op is registered here and again once its output is
    # set below; the second call looks redundant - confirm how add_optimised
    # treats a repeated op before removing either call.
    DebugDatabase.add_optimised(self.op, maxpool_op)
    # View the IFM as [1, H*W, C, 1] so that a 1 x C pooling window covers the
    # whole last axis of the original tensor for every spatial position.
    maxpool_h = ifm.shape[1] * ifm.shape[2]
    maxpool_w = ifm.shape[3]
    maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
    maxpool_op.attrs["padding"] = b"VALID"
    maxpool_op.attrs["stride_w"] = 1
    maxpool_op.attrs["stride_h"] = 1
    maxpool_op.attrs["filter_width"] = maxpool_w
    maxpool_op.attrs["filter_height"] = 1
    maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
    maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
    maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
    maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
    maxpool_ofm.quantization = no_scale_quant
    maxpool_op.set_output_tensor(maxpool_ofm)
    DebugDatabase.add_optimised(self.op, maxpool_op)

    # PASS 1 - Sub
    # The max is reshaped back to [1, H, W, 1] before being subtracted from the IFM.
    sub1_op = Operation(Op.Sub, self.op.name + "_sub1")
    sub1_op.add_input_tensor(ifm)
    sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
    sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
    sub1_ofm.quantization = ifm.quantization.clone()
    sub1_op.set_output_tensor(sub1_ofm)
    DebugDatabase.add_optimised(self.op, sub1_op)

    # PASS 2 - Mul
    beta = self.op.attrs.get("beta", 1.0)
    # NOTE(review): 10.0 / 65535.0 presumably maps a real range of 10 onto the
    # 16-bit index range of the exp LUT - confirm against the LUT generation.
    mul2_out_range = 10.0 / 65535.0
    mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
    mul2_quant = ifm.quantization.clone()
    mul2_quant.scale_f32 = beta
    mul2_op = Operation(Op.Mul, self.op.name + "_mul2")
    mul2_op.add_input_tensor(sub1_ofm)
    mul2_op.add_input_tensor(
        create_const_tensor(
            mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
        ),
    )
    mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
    mul2_ofm.quantization = ofm.quantization.clone()
    mul2_ofm.quantization.scale_f32 = mul2_out_range
    mul2_op.set_output_tensor(mul2_ofm)
    DebugDatabase.add_optimised(self.op, mul2_op)

    # PASS 3 - Add+LUT(exp)
    # Adds the constant 32767, then applies self.EXP_LUT as the activation LUT.
    add_op = Operation(Op.Add, self.op.name + "_add3")
    add_op.add_input_tensor(mul2_ofm)
    add_op.add_input_tensor(
        create_const_tensor(
            add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
        ),
    )
    add_op.set_activation_lut(
        create_const_tensor(
            add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
        )
    )
    exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
    exp_ofm.quantization = mul2_ofm.quantization.clone()
    add_op.set_output_tensor(exp_ofm)
    DebugDatabase.add_optimised(self.op, add_op)

    # PASS 4 - Reduce sum
    reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum4")
    reduce_sum_op.attrs["padding"] = b"VALID"
    reduce_sum_op.attrs["stride_w"] = 1
    reduce_sum_op.attrs["stride_h"] = 1
    reduce_sum_op.attrs["filter_width"] = 1
    reduce_sum_op.attrs["filter_height"] = 1
    reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
    reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
    reduce_sum_op.add_input_tensor(exp_ofm)

    # All per-position scalars below (passes 4-11) share this [1, H, W, 1] shape.
    reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
    sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
    sum_of_exp.quantization = no_scale_quant
    reduce_sum_op.set_output_tensor(sum_of_exp)
    DebugDatabase.add_optimised(self.op, reduce_sum_op)

    # PASS 5 - CLZ
    clz_op = Operation(Op.CLZ, self.op.name + "_clz5")
    clz_op.add_input_tensor(sum_of_exp)
    headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
    headroom_plus_one.quantization = no_scale_quant
    clz_op.set_output_tensor(headroom_plus_one)
    DebugDatabase.add_optimised(self.op, clz_op)

    # PASS 6 - Sub
    # Operands are (31, headroom_plus_one); the result is reused as the shift
    # amount of PASS 7 and of the final SHR in PASS 13.
    sub6_op = Operation(Op.Sub, self.op.name + "_sub6")
    sub6_op.add_input_tensor(
        create_const_tensor(
            sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
        ),
    )
    sub6_op.add_input_tensor(headroom_plus_one)
    reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
    reciprocal_right_shift.quantization = no_scale_quant
    sub6_op.set_output_tensor(reciprocal_right_shift)
    DebugDatabase.add_optimised(self.op, sub6_op)

    # PASS 7 - SHL
    # Operands are (1, reciprocal_right_shift).
    shl7_op = Operation(Op.SHL, self.op.name + "_shl7")
    shl7_op.add_input_tensor(
        create_const_tensor(
            shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
        ),
    )
    shl7_op.add_input_tensor(reciprocal_right_shift)
    constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
    constant_one.quantization = no_scale_quant
    shl7_op.set_output_tensor(constant_one)
    DebugDatabase.add_optimised(self.op, shl7_op)

    # PASS 8 - Sub
    sub8_op = Operation(Op.Sub, self.op.name + "_sub8")
    sub8_op.add_input_tensor(sum_of_exp)
    sub8_op.add_input_tensor(constant_one)
    sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
    sum_of_exps_minus_one.quantization = no_scale_quant
    sub8_op.set_output_tensor(sum_of_exps_minus_one)
    DebugDatabase.add_optimised(self.op, sub8_op)

    # PASS 9 - SHL
    shl9_op = Operation(Op.SHL, self.op.name + "_shl9")
    shl9_op.add_input_tensor(sum_of_exps_minus_one)
    shl9_op.add_input_tensor(headroom_plus_one)
    shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
    shifted_sum_minus_one.quantization = no_scale_quant
    shl9_op.set_output_tensor(shifted_sum_minus_one)
    DebugDatabase.add_optimised(self.op, shl9_op)

    # PASS 10 - SHR
    # Shift right by the constant 15 before the value is fed to the reciprocal LUT pass.
    shr10_op = Operation(Op.SHR, self.op.name + "_shr10")
    shr10_op.add_input_tensor(shifted_sum_minus_one)
    shr10_op.add_input_tensor(
        create_const_tensor(
            shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
        ),
    )
    shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
    shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
    shr10_op.set_output_tensor(shifted_sum_minus_one_16)
    DebugDatabase.add_optimised(self.op, shr10_op)

    # PASS 11 - Sub+LUT(one over one plus x)
    # Subtracts the constant 32768, then applies self.ONE_OVER_ONE_PLUS_X_LUT.
    sub11_op = Operation(Op.Sub, self.op.name + "_sub11")
    sub11_op.add_input_tensor(shifted_sum_minus_one_16)
    sub11_op.add_input_tensor(
        create_const_tensor(
            sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
        ),
    )
    sub11_op.set_activation_lut(
        create_const_tensor(
            sub11_op.name + "_lut",
            [1, 1, 1, 512],
            DataType.int32,
            self.ONE_OVER_ONE_PLUS_X_LUT,
            np.int32,
            TensorPurpose.LUT,
        )
    )
    reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
    reciprocal_scale.quantization = no_scale_quant
    sub11_op.set_output_tensor(reciprocal_scale)
    DebugDatabase.add_optimised(self.op, sub11_op)

    # PASS 12 - Multiply
    # exp values (IFM shape) times the per-position reciprocal ([1, H, W, 1]).
    mul_op = Operation(Op.Mul, self.op.name + "_mul12")
    mul_op.add_input_tensor(exp_ofm)
    mul_op.add_input_tensor(reciprocal_scale)
    mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
    mul_ofm.quantization = no_scale_quant
    mul_op.set_output_tensor(mul_ofm)
    DebugDatabase.add_optimised(self.op, mul_op)

    # PASS 13 - SHR
    # Final shift by the amount computed in PASS 6; writes directly into the OFM.
    shr13_op = Operation(Op.SHR, self.op.name + "_shr13")
    shr13_op.add_input_tensor(mul_ofm)
    shr13_op.add_input_tensor(reciprocal_right_shift)
    shr13_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(self.op, shr13_op)

    return shr13_op