# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Contains SoftMax
import numpy as np

from . import scaling
from .data_type import DataType
from .operation import Operation
from .tensor import create_const_tensor
from .tensor import create_reshape_tensor
from .tensor import Tensor
from .tensor import TensorPurpose


class SoftMax:
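    """Rewrites a SoftMax operation into a sub-graph of NPU-supported operations.

    Depending on the input/output data types, get_graph() returns the final
    operation of an 8-bit or int16 decomposition; unsupported type combinations
    fall back to the original operation with run_on_npu cleared.
    Minimal usage sketch (assumed call pattern, e.g. from a graph optimisation
    pass; the exact call site may differ):

        softmax = SoftMax(softmax_op)
        softmax_op = softmax.get_graph()
    """
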
    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

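    # EXP_LUT_U8 and EXP_LUT_I8 are the 256-entry exp tables used by the 8-bit
    # path (selected in generate_exp_table based on the input scale), while
    # EXP_LUT and ONE_OVER_ONE_PLUS_X_LUT are the 512-entry tables used as LUT
    # activations by the int16 path.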
    EXP_LUT_U8 = [
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
        0x00000291, 0x000006fa, 0x000012f6, 0x0000338b, 0x00008c1b, 0x00017cd8, 0x00040b3d, 0x000afe11,
        0x001de16c, 0x00513949, 0x00dcca03, 0x02582ac2, 0x065f6c52, 0x1152aaf6, 0x2f16ad4c, 0x7fffffff
    ]

    EXP_LUT_I8 = [
        0x000011c9, 0x000012b8, 0x000013b4, 0x000014bd, 0x000015d4, 0x000016fa, 0x0000182f, 0x00001975,
        0x00001acb, 0x00001c34, 0x00001daf, 0x00001f3f, 0x000020e3, 0x0000229e, 0x00002470, 0x0000265a,
        0x0000285e, 0x00002a7d, 0x00002cb9, 0x00002f13, 0x0000318c, 0x00003427, 0x000036e5, 0x000039c8,
        0x00003cd1, 0x00004004, 0x00004361, 0x000046ec, 0x00004aa6, 0x00004e93, 0x000052b4, 0x0000570d,
        0x00005ba1, 0x00006072, 0x00006583, 0x00006ada, 0x00007077, 0x00007661, 0x00007c9a, 0x00008327,
        0x00008a0c, 0x0000914d, 0x000098f1, 0x0000a0fb, 0x0000a971, 0x0000b259, 0x0000bbb9, 0x0000c597,
        0x0000cffa, 0x0000dae9, 0x0000e66b, 0x0000f288, 0x0000ff48, 0x00010cb3, 0x00011ad3, 0x000129b1,
        0x00013957, 0x000149d0, 0x00015b26, 0x00016d65, 0x0001809b, 0x000194d2, 0x0001aa1a, 0x0001c080,
        0x0001d814, 0x0001f0e4, 0x00020b03, 0x00022681, 0x00024371, 0x000261e7, 0x000281f7, 0x0002a3b5,
        0x0002c73b, 0x0002ec9e, 0x000313f8, 0x00033d64, 0x000368fd, 0x000396e0, 0x0003c72e, 0x0003fa05,
        0x00042f89, 0x000467dd, 0x0004a326, 0x0004e18e, 0x0005233d, 0x00056860, 0x0005b126, 0x0005fdbf,
        0x00064e5f, 0x0006a33b, 0x0006fc8e, 0x00075a93, 0x0007bd89, 0x000825b3, 0x00089356, 0x000906bd,
        0x00098034, 0x000a000f, 0x000a86a2, 0x000b1447, 0x000ba95f, 0x000c464d, 0x000ceb7c, 0x000d9959,
        0x000e505a, 0x000f10f9, 0x000fdbb8, 0x0010b120, 0x001191c0, 0x00127e2f, 0x0013770b, 0x00147cfc,
        0x001590b2, 0x0016b2e6, 0x0017e45c, 0x001925e1, 0x001a784c, 0x001bdc81, 0x001d536f, 0x001ede14,
        0x00207d76, 0x002232af, 0x0023fee3, 0x0025e348, 0x0027e125, 0x0029f9ce, 0x002c2ead, 0x002e813e,
        0x0030f30f, 0x003385c7, 0x00363b1e, 0x003914e9, 0x003c150f, 0x003f3d97, 0x004290a0, 0x00461065,
        0x0049bf40, 0x004d9fac, 0x0051b444, 0x0055ffc2, 0x005a850e, 0x005f472f, 0x00644959, 0x00698eea,
        0x006f1b6b, 0x0074f298, 0x007b185e, 0x008190dd, 0x00886073, 0x008f8bad, 0x00971761, 0x009f08a0,
        0x00a764c0, 0x00b03163, 0x00b9746c, 0x00c3341a, 0x00cd76f8, 0x00d843eb, 0x00e3a23a, 0x00ef9981,
        0x00fc31d0, 0x0109739d, 0x011767cf, 0x012617cd, 0x01358d6e, 0x0145d319, 0x0156f3be, 0x0168fadc,
        0x017bf49d, 0x018fedb3, 0x01a4f391, 0x01bb1457, 0x01d25ede, 0x01eae2e1, 0x0204b0c5, 0x021fd9e9,
        0x023c708e, 0x025a87f5, 0x027a343a, 0x029b8ac1, 0x02bea1ea, 0x02e39148, 0x030a71be, 0x03335d49,
        0x035e6f88, 0x038bc564, 0x03bb7d53, 0x03edb776, 0x0422956d, 0x045a3add, 0x0494cd23, 0x04d27398,
        0x051357c1, 0x0557a511, 0x059f8990, 0x05eb3585, 0x063adbc4, 0x068eb1f7, 0x06e6f042, 0x0743d212,
        0x07a595d0, 0x080c7d1f, 0x0878cd5d, 0x08eacf11, 0x0962cefe, 0x09e11dc0, 0x0a661028, 0x0af1ffdf,
        0x0b854a8e, 0x0c205363, 0x0cc38284, 0x0d6f4577, 0x0e241032, 0x0ee25ba2, 0x0faaa7e6, 0x107d7b92,
        0x115b64b1, 0x1244f774, 0x133ad1b8, 0x143d9876, 0x154df988, 0x166cac69, 0x179a70c9, 0x18d81250,
        0x1a266643, 0x1b864d38, 0x1cf8b430, 0x1e7e9307, 0x2018f0a9, 0x21c8e098, 0x238f850c, 0x256e1033,
        0x2765c273, 0x2977ef40, 0x2ba5faa9, 0x2df15b73, 0x305b9d6b, 0x32e65e8a, 0x3593552c, 0x38644d67,
        0x3b5b2b66, 0x3e79ee87, 0x41c2adcb, 0x45379f4e, 0x48db158a, 0x4caf81e6, 0x50b7797f, 0x54f5af16,
        0x596cfe2f, 0x5e2066d0, 0x631310c8, 0x684852d8, 0x6dc3a909, 0x7388c421, 0x799b84b7, 0x7fffffff,
    ]

    EXP_LUT = [
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
    ]

    ONE_OVER_ONE_PLUS_X_LUT = [
        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
    ]
    # fmt: on

    def __init__(self, op):
        self.op = op

    def generate_exp_table(self, beta, input_scale):
        # TODO: Generate the exp table using the same math as the reference
        return self.EXP_LUT_U8 if input_scale == 1.0 else self.EXP_LUT_I8

    def get_graph(self):
        ifm = self.op.inputs[0]
        ofm = self.op.outputs[0]

        if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
            return self.get_graph_8bit(ifm, ofm)
        elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
            return self.get_graph_int16(ifm, ofm)
        else:
            self.op.run_on_npu = False
            return self.op

    def get_graph_8bit(self, ifm, ofm):
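        # Softmax is computed as exp(x - max(x)) / sum(exp(x - max(x))) with integer
        # operations only: PASS 0 finds max(x) with a depthwise maxpool, PASS 1
        # evaluates exp(x - max(x)) through a LUT activation, PASS 3 sums the
        # exponentials, PASSes 4-28 approximate the reciprocal of that sum with a
        # Newton-Raphson iteration, and PASSes 29-30 apply it and shift into the OFM.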
        exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        one_scale_quant = ifm.quantization.clone()
        one_scale_quant.scale_f32 = 1.0
        one_scale_quant.zero_point = 0
        ifm.quantization.zero_point = 0

        # PASS 0 - Depthwise Maxpool
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        ifm_max.quantization = no_scale_quant
        maxpool_op.set_output_tensor(ifm_max)

        # PASS 1 - Sub+LUT(exp)
        sub_op = Operation("SubAct", self.op.name + "_sub1")
        sub_op.add_input_tensor(ifm)
        sub_op.add_input_tensor(ifm_max)
        sub_op.set_activation_lut(
            create_const_tensor(
                sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
            )
        )
        ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
        ifm_exp.quantization = one_scale_quant.clone()
        ifm_exp.quantization.zero_point = 127
        ifm_exp.quantization.quant_min = -128
        ifm_exp.quantization.quant_max = 127
        sub_op.set_output_tensor(ifm_exp)

        # PASS 2 - SHR
        shr2_op = Operation("SHR", self.op.name + "_shr2")
        shr2_op.add_input_tensor(ifm_exp)
        shr2_op.add_input_tensor(
            create_const_tensor(
                shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
            ),
        )
        rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
        rescaled_exp.quantization = no_scale_quant
        shr2_op.set_output_tensor(rescaled_exp)

        # PASS 3 - Reduce sum
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(rescaled_exp)

        reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 4 - CLZ
        clz_op = Operation("CLZ", self.op.name + "_clz4")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 5 - Sub
        sub5_op = Operation("SubAct", self.op.name + "_sub5")
        sub5_op.add_input_tensor(
            create_const_tensor(
                "headroom_offset_const", [1, 1, 1, 1], DataType.int32, [12 + 31 - 8], np.int32, quantization=no_scale_quant
            ),
        )
        sub5_op.add_input_tensor(headroom_plus_one)
        right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
        right_shift.quantization = no_scale_quant
        sub5_op.set_output_tensor(right_shift)

        # PASS 6 - Sub
        one = create_const_tensor(
            "one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
        )
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(headroom_plus_one)
        sub6_op.add_input_tensor(one)
        headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        headroom.quantization = no_scale_quant
        sub6_op.set_output_tensor(headroom)

        # PASS 7 - SHL
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(sum_of_exp)
        shl7_op.add_input_tensor(headroom)
        shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        shifted_sum.quantization = no_scale_quant
        shl7_op.set_output_tensor(shifted_sum)

        # PASS 8 - Sub
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(shifted_sum)
        sub8_op.add_input_tensor(
            create_const_tensor(
                "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 9 - SHL
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(shifted_sum_minus_one)
        shl9_op.add_input_tensor(one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - Add
        add10_op = Operation("AddAct", self.op.name + "_add10")
        add10_op.add_input_tensor(
            create_const_tensor(
                "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
            ),
        )
        add10_op.add_input_tensor(shifted_sum_minus_one)
        add10_op.attrs["rescale"] = [1, 1]
        half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
        half_denominator.quantization = one_scale_quant
        add10_op.set_output_tensor(half_denominator)

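        # PASSes 11-12 seed the Newton-Raphson reciprocal estimate
        # nr_x = 48/17 - 32/17 * half_denominator; judging by the tensor names, the
        # constants below are round(-32/17 * 2^29) and round(48/17 * 2^29) in Q2.29,
        # i.e. the standard reciprocal seed for a denominator in [0.5, 1).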
        # PASS 11 - Multiply
        mul11_op = Operation("MulAct", self.op.name + "_mul11")
        mul11_op.add_input_tensor(half_denominator)
        mul11_op.add_input_tensor(
            create_const_tensor(
                "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], np.int32, quantization=one_scale_quant
            ),
        )
        rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
        rescaled.quantization = one_scale_quant.clone()
        rescaled.quantization.scale_f32 = 2.0
        mul11_op.set_output_tensor(rescaled)

        # PASS 12 - Add
        add12_op = Operation("AddAct", self.op.name + "_add12")
        add12_op.add_input_tensor(rescaled)
        add12_op.add_input_tensor(
            create_const_tensor(
                "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
            ),
        )
        rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
        rescale_w_offset.quantization = one_scale_quant
        add12_op.set_output_tensor(rescale_w_offset)

        nr_x = rescale_w_offset
        F2_one = create_const_tensor(
            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
        )
        two = create_const_tensor(
            "two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant
        )
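        # Three Newton-Raphson iterations refine nr_x towards the reciprocal of the
        # denominator; each iteration computes nr_x += (nr_x * (F2_one -
        # half_denominator * nr_x)) << 2, a fixed-point form of the reciprocal
        # update x = x * (2 - d * x).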
        for i in range(3):
            # PASS 13, 18, 23 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(half_denominator)
            half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            half_denominator_times_x.quantization = one_scale_quant.clone()
            half_denominator_times_x.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(half_denominator_times_x)
            # PASS 14, 19, 24 - SUB
            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
            sub_op.add_input_tensor(F2_one)
            sub_op.add_input_tensor(half_denominator_times_x)
            one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
            one_minus_half_denominator_times_x.quantization = one_scale_quant
            sub_op.set_output_tensor(one_minus_half_denominator_times_x)
            # PASS 15, 20, 25 - MUL
            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
            mul_op.add_input_tensor(nr_x)
            mul_op.add_input_tensor(one_minus_half_denominator_times_x)
            to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
            to_rescale.quantization = one_scale_quant.clone()
            to_rescale.quantization.scale_f32 = 2.0
            mul_op.set_output_tensor(to_rescale)
            # PASS 16, 21, 26 - SHL
            shl_op = Operation("SHL", self.op.name + "_shl%d" % (16 + i * 5))
            shl_op.add_input_tensor(to_rescale)
            shl_op.add_input_tensor(two)
            to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
            to_add.quantization = no_scale_quant
            shl_op.set_output_tensor(to_add)
            # PASS 17, 22, 27 - ADD
            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
            add_op.add_input_tensor(nr_x)
            add_op.add_input_tensor(to_add)
            nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
            nr_x.quantization = one_scale_quant
            add_op.set_output_tensor(nr_x)

        # PASS 28 - SHL
        shl28_op = Operation("SHL", self.op.name + "_shl28")
        shl28_op.add_input_tensor(nr_x)
        shl28_op.add_input_tensor(one)
        scale_factor = Tensor(reduce_sum_shape, DataType.int32, shl28_op.name + "_0")
        scale_factor.quantization = one_scale_quant
        shl28_op.set_output_tensor(scale_factor)

        # PASS 29 - Multiply
        mul_op = Operation("MulAct", self.op.name + "_mul29")
        mul_op.add_input_tensor(ifm_exp)
        mul_op.add_input_tensor(scale_factor)
        scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
        scaled_exp.quantization = one_scale_quant.clone()
        scaled_exp.quantization.scale_f32 = 2.0
        mul_op.set_output_tensor(scaled_exp)

        # PASS 30 - SHR
        shr30_op = Operation("SHR", self.op.name + "_shr30")
        shr30_op.add_input_tensor(scaled_exp)
        shr30_op.add_input_tensor(right_shift)
        shr30_op.set_output_tensor(ofm)

        return shr30_op

    def get_graph_int16(self, ifm, ofm):
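        # Same decomposition idea as get_graph_8bit, but both exp() and the
        # reciprocal are evaluated through 512-entry LUT activations (EXP_LUT and
        # ONE_OVER_ONE_PLUS_X_LUT) instead of a Newton-Raphson iteration.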
        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None

        # PASS 0 - Depthwise Maxpool
        maxpool_op = self.op.clone("_maxpool0")
        maxpool_op.type = "MaxPool"
        maxpool_h = ifm.shape[1] * ifm.shape[2]
        maxpool_w = ifm.shape[3]
        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
        maxpool_op.attrs["padding"] = b"VALID"
        maxpool_op.attrs["stride_w"] = 1
        maxpool_op.attrs["stride_h"] = 1
        maxpool_op.attrs["filter_width"] = maxpool_w
        maxpool_op.attrs["filter_height"] = 1
        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
        maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
        maxpool_ofm.quantization = no_scale_quant
        maxpool_op.set_output_tensor(maxpool_ofm)

        # PASS 1 - Sub
        sub1_op = Operation("SubAct", self.op.name + "_sub1")
        sub1_op.add_input_tensor(ifm)
        sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
        sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
        sub1_ofm.quantization = ifm.quantization.clone()
        sub1_op.set_output_tensor(sub1_ofm)

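        # Scale (x - max(x)) by beta; the chosen output scale maps the int16 range
        # onto a span of roughly 10.0 before the values are fed to the exp LUT.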
        # PASS 2 - Mul
        beta = self.op.attrs.get("beta", 1.0)
        mul2_out_range = 10.0 / 65535.0
        mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
        mul2_quant = ifm.quantization.clone()
        mul2_quant.scale_f32 = beta
        mul2_op = Operation("MulAct", self.op.name + "_mul2")
        mul2_op.add_input_tensor(sub1_ofm)
        mul2_op.add_input_tensor(
            create_const_tensor(
                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
            ),
        )
        mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
        mul2_ofm.quantization = ofm.quantization.clone()
        mul2_ofm.quantization.scale_f32 = mul2_out_range
        mul2_op.set_output_tensor(mul2_ofm)

        # PASS 3 - Add+LUT(exp)
        add_op = Operation("AddAct", self.op.name + "_add3")
        add_op.add_input_tensor(mul2_ofm)
        add_op.add_input_tensor(
            create_const_tensor(
                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
            ),
        )
        add_op.set_activation_lut(
            create_const_tensor(
                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
            )
        )
        exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
        exp_ofm.quantization = mul2_ofm.quantization.clone()
        add_op.set_output_tensor(exp_ofm)

        # PASS 4 - Reduce sum
        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
        reduce_sum_op.attrs["padding"] = b"VALID"
        reduce_sum_op.attrs["stride_w"] = 1
        reduce_sum_op.attrs["stride_h"] = 1
        reduce_sum_op.attrs["filter_width"] = 1
        reduce_sum_op.attrs["filter_height"] = 1
        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
        reduce_sum_op.add_input_tensor(exp_ofm)

        reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
        sum_of_exp.quantization = no_scale_quant
        reduce_sum_op.set_output_tensor(sum_of_exp)

        # PASS 5 - CLZ
        clz_op = Operation("CLZ", self.op.name + "_clz5")
        clz_op.add_input_tensor(sum_of_exp)
        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
        headroom_plus_one.quantization = no_scale_quant
        clz_op.set_output_tensor(headroom_plus_one)

        # PASS 6 - Sub
        sub6_op = Operation("SubAct", self.op.name + "_sub6")
        sub6_op.add_input_tensor(
            create_const_tensor(
                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
            ),
        )
        sub6_op.add_input_tensor(headroom_plus_one)
        reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
        reciprocal_right_shift.quantization = no_scale_quant
        sub6_op.set_output_tensor(reciprocal_right_shift)

        # PASS 7 - SHL
        shl7_op = Operation("SHL", self.op.name + "_shl7")
        shl7_op.add_input_tensor(
            create_const_tensor(
                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
            ),
        )
        shl7_op.add_input_tensor(reciprocal_right_shift)
        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
        constant_one.quantization = no_scale_quant
        shl7_op.set_output_tensor(constant_one)

        # PASS 8 - Sub
        sub8_op = Operation("SubAct", self.op.name + "_sub8")
        sub8_op.add_input_tensor(sum_of_exp)
        sub8_op.add_input_tensor(constant_one)
        sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
        sum_of_exps_minus_one.quantization = no_scale_quant
        sub8_op.set_output_tensor(sum_of_exps_minus_one)

        # PASS 9 - SHL
        shl9_op = Operation("SHL", self.op.name + "_shl9")
        shl9_op.add_input_tensor(sum_of_exps_minus_one)
        shl9_op.add_input_tensor(headroom_plus_one)
        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
        shifted_sum_minus_one.quantization = no_scale_quant
        shl9_op.set_output_tensor(shifted_sum_minus_one)

        # PASS 10 - SHR
        shr10_op = Operation("SHR", self.op.name + "_shr10")
        shr10_op.add_input_tensor(shifted_sum_minus_one)
        shr10_op.add_input_tensor(
            create_const_tensor(
                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
            ),
        )
        shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
        shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
        shr10_op.set_output_tensor(shifted_sum_minus_one_16)

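        # PASSes 8-11 map the normalised sum onto the input domain of the 1/(1+x)
        # LUT; its output, reciprocal_scale, is the factor applied to the
        # exponentials in PASS 12.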
        # PASS 11 - Sub+LUT(one over one plus x)
        sub11_op = Operation("SubAct", self.op.name + "_sub11")
        sub11_op.add_input_tensor(shifted_sum_minus_one_16)
        sub11_op.add_input_tensor(
            create_const_tensor(
                sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
            ),
        )
        sub11_op.set_activation_lut(
            create_const_tensor(
                sub11_op.name + "_lut",
                [1, 1, 1, 512],
                DataType.int32,
                self.ONE_OVER_ONE_PLUS_X_LUT,
                np.int32,
                TensorPurpose.LUT,
            )
        )
        reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
        reciprocal_scale.quantization = no_scale_quant
        sub11_op.set_output_tensor(reciprocal_scale)

        # PASS 12 - Multiply
        mul_op = Operation("MulAct", self.op.name + "_mul12")
        mul_op.add_input_tensor(exp_ofm)
        mul_op.add_input_tensor(reciprocal_scale)
        mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
        mul_ofm.quantization = no_scale_quant
        mul_op.set_output_tensor(mul_ofm)

        # PASS 13 - SHR
        shr13_op = Operation("SHR", self.op.name + "_shr13")
        shr13_op.add_input_tensor(mul_ofm)
        shr13_op.add_input_tensor(reciprocal_right_shift)
        shr13_op.set_output_tensor(ofm)

        return shr13_op