blob: 7043746d8081047ec68478dcce0d7680d04cd5bb [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001/*
Rickard Bolinbc6ee582022-11-04 08:24:29 +00002 * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01003 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#include <stdio.h>
20#include <stdlib.h>
21#include <stdint.h>
22#include <stdbool.h>
23#include <string.h>
24#include <assert.h>
25#include <math.h>
26#include <stdarg.h>
27#include <math.h>
28#include "mlw_common.h"
29#include "mlw_encode.h"
30
// Debug trace macro; swap in the second definition to log encoder decisions.
#define DPRINTF(...)
//#define DPRINTF(...) printf(__VA_ARGS__)

// Zero runs are only used when zero is more than ZERO_RUN_THRES times as
// frequent as the second most common weight value (see find_palette).
#define ZERO_RUN_THRES 4

// NOTE: classic double-evaluation macros; do not call with side-effecting arguments.
#ifndef min
#define min(a,b) ((a)<(b)?(a):(b))
#endif
#ifndef max
#define max(a,b) ((a)>(b)?(a):(b))
#endif
Tim Hall79d07d22020-04-27 18:20:16 +010042
// Palette state for one section of the weight stream.
// Weight values are stored in sign/magnitude form: (magnitude<<1) | sign
// (see create_palette).
typedef struct palette {
    int16_t lut[32];      // palette entries (sign/magnitude encoded weight values)
    int16_t inv_lut[512]; // maps weight value + 256 to the encoded weight index (see create_inverse_palette)
    int palsize;          // number of palette entries
    int palbits;          // bit width of palette entries
    int use_zero_runs;    // zeros are coded separately
    int only_palette;     // no values outside the palette
    int direct_offset;    // added to the decoded weight index before direct conversion to sign/mag
    int only_zeros;       // special case that the section is all zeros
} palette_t;
53
// True for any exact power of two; note that 0 also satisfies the test,
// matching the original bit-trick behavior (x & (x-1) clears the lowest set bit).
static int is_power_of_two( int x ) {
    return (x & (x - 1)) == 0;
}
57
// Number of index bits needed to address `size` palette entries,
// i.e. the smallest b with size <= (1<<b), capped at 8; 0 when size <= 1.
static int get_palette_index_bits( int size ) {
    int bits = 0;
    while (bits < 8 && size > (1 << bits)) {
        bits++;
    }
    return bits;
}
65
// Search the stream for suitable palette restart positions
// Return the number of restarts.
// On return *palette_restart_positions is a malloc'd array of restart indices
// (first entry is always 0); the caller owns and must free it.
static int search_palette_sections( int16_t *buf, int size, int **palette_restart_positions ) {
    int i,j,got_palette,restart_i,palette_size=0, last_restart_idx, zero_cnt;
    int prev_idx[512]; // For each value, keep track of the index of the previous occurrence
    int *restart_pos;
    int max_palettes = (size+63)/64;

    // Preliminary allocation of sufficient size
    // NOTE(review): malloc/realloc results in this function are not checked before use
    restart_pos = (int*)malloc( max_palettes*sizeof(int) );
    last_restart_idx=0;
    got_palette=0;
    restart_i=1;
    restart_pos[0] = 0;
    zero_cnt=0;
    memset( prev_idx, -1, sizeof(prev_idx));
    for(i=0; i<size; i++) {
        // Guess if zeros should be excluded from the palette
        int exclude_zero = zero_cnt > (i-last_restart_idx)/4;

        if (got_palette) {
            // Check if the next value is not covered by the current palette
            if ( prev_idx[ buf[i]+256 ] < last_restart_idx ) {
                // New value: increase the palette size
                palette_size++;
                DPRINTF("Note: at pos %d extend palette to size %d\n", i, palette_size);
                if ( is_power_of_two(palette_size-1-exclude_zero) ) {
                    if ( (i - last_restart_idx - zero_cnt) > 512 || (palette_size-exclude_zero)>32 ) {
                        // create a new palette because we extend a long lasting palette to require one more index bit
                        DPRINTF("Note: at pos %d create new palette because previous has to increase one more index bit. last_restart_idx %d n %d zero_cnt %d\n", i, last_restart_idx, i - last_restart_idx, zero_cnt );
                        if (restart_i == max_palettes) {
                            // Grow the preliminary array geometrically
                            max_palettes = max_palettes*2;
                            restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                        }
                        DPRINTF("restart %d pos %d\n", restart_i, i);
                        restart_pos[restart_i++] = i;
                        last_restart_idx = i;
                        got_palette=0;
                        zero_cnt=0;
                    }
                }
            }
        }

        prev_idx[ buf[i]+256 ] = i;
        if (buf[i]==0)
            zero_cnt++;

        static const int window_sizes[5][2] = {{32,1}, {64,1}, {128,1}, {256,1}, {512,1}};
        int k;
        // loop over window sizes
        for(k=0; k<5; k++) {
            // Every Nth non-zero value, count what would be the size of a palette covering the last N NZ.
            int N = window_sizes[k][0] * (got_palette?2:1);
            if ( (i - last_restart_idx - zero_cnt) > 0 && ((i - last_restart_idx - zero_cnt) % N)==0 ) {
                // Search backward to the position N nonzero values earlier
                int nzcnt=0;
                for( j=i; j>last_restart_idx; j--) {
                    if ( buf[j]!=0 ) {
                        if (nzcnt==N+1)
                            break;
                        nzcnt++;
                    }
                }
                int restart_idx = j;

                // Calculate the size of a new palette (starting at restart_idx)
                int new_palette_size=0;
                for(j=0; j<512; j++) {
                    if ( prev_idx[j] >= restart_idx ) {
                        new_palette_size++;
                    }
                }

                int create_new_palette=0;
                if (got_palette) {
                    // Compare estimated bit cost of the new palette vs the current one;
                    // restart costs roughly one palette reload (new_palette_size*8 + 20 bits)
                    int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero );
                    int old_size_bits = get_palette_index_bits( palette_size - exclude_zero );
                    int savings = N*(old_size_bits*15-new_size_bits*15)/16 - new_palette_size*8 - 20;
                    if ( savings>0 ) {
                        // Create new palette because it can be smaller than the existing palette
                        create_new_palette=1;
                        DPRINTF("Note: at pos %d restart smaller palette\n", restart_idx);
                    }
                } else {
                    if ( (new_palette_size-exclude_zero) <= 32) {
                        int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero );
                        // estimate if we will make savings by using palette mode
                        int savings = N*(90-new_size_bits*15)/16 - new_palette_size*8 - 20;
                        create_new_palette = savings>0;
                    }
                }
                if (create_new_palette) {
                    palette_size=new_palette_size;
                    got_palette=1;
                    last_restart_idx = restart_idx;
                    DPRINTF("Note: at pos %d create palette of size %d\n", last_restart_idx, new_palette_size);
                    if ( restart_pos[restart_i-1] != last_restart_idx) {
                        if (restart_i == max_palettes) {
                            max_palettes = max_palettes*2;
                            restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                        }
                        restart_pos[restart_i++] = last_restart_idx;
                    }
                    // Recount zeros from the new restart point
                    zero_cnt=0;
                    for( j=last_restart_idx; j<=i; j++)
                        if (buf[j]==0)
                            zero_cnt++;
                }
            }
        }
    }
    // Reallocate to actual size
    *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) );
    return restart_i;
}
182
// Build a 512-entry histogram of the weight stream;
// freq index 0..511 corresponds to weight value -256..255.
static void calc_freq( const int16_t *buf, int size, int freq[512] ) {
    memset(freq, 0, 512*sizeof(int));
    for (int idx = 0; idx < size; idx++) {
        freq[256 + buf[idx]] += 1;
    }
}
191
// qsort comparator: sorts uint64_t values in descending order.
static int cmp_uint64(const void * a, const void * b) {
    const uint64_t lhs = *(const uint64_t*)a;
    const uint64_t rhs = *(const uint64_t*)b;
    if (lhs > rhs)
        return -1;
    if (lhs < rhs)
        return 1;
    return 0;
}
197
// Create palette from the given frequencies
// Freq index 0-511 correspond to weights -256..255
// Fills in p->lut, palsize, palbits, direct_offset, only_palette, only_zeros
// and use_zero_runs.
static void create_palette( int freq[512],
                            int use_zero_runs,
                            palette_t *p ) {
    uint64_t freq64[512];
    int i,all_cnt,all_max_val;

    // Pair the frequency with the value so that
    // the array can be sorted on frequency while keeping
    // track of the corresponding palette value
    memset(freq64, 0, sizeof(freq64));
    all_cnt=0;
    all_max_val=0;
    for(i=-255; i<256; i++) {
        // When zero runs are used, zero is coded separately and kept out of the palette
        if (i==0 && use_zero_runs)
            continue;
        int sign = i<0;
        int mag = abs(i);
        int palval = (mag<<1) | sign;  // sign/magnitude encoding: LSB is the sign

        // Store palette value in 16 LSB bits, which will not affect the sorting
        freq64[palval] = (((uint64_t)freq[i+256])<<16) | palval;
        all_cnt+=freq[i+256];

        if (freq[i+256]>0) {
            all_max_val = max(all_max_val, palval);
        }
    }

    // Count number of non-used weight values around zero (0, -1, +1, -2, +2 etc)
    for(i=0; i<31; i++) {
        if ((freq64[i]>>16)!=0)
            break;
    }
    p->direct_offset = i;

    // Sort in descending frequency order
    qsort(freq64, 512, sizeof(uint64_t), cmp_uint64);

    // Identify special case that there are no weights to code
    // in the weight index stream (i.e. all weights are zeros)
    p->only_zeros = (freq64[0]>>16)==0;
    if (p->only_zeros) {
        p->direct_offset=0;
    }

    // Check if all weights fit into the palette (and the palette is not empty)
    p->only_palette = (freq64[0]>>16)>0 && (freq64[32]>>16)==0;

    int max_palette_size;
    if (p->only_palette) {
        max_palette_size = 32;
    } else {
        // For direct-lut we must make sure that the encoded weight
        // index is not > 511. We do that by limiting the palette size
        // such that the greatest value can be reached after subtracting
        // the palette size.
        max_palette_size = min(32, 511-all_max_val);
        if (max_palette_size==1) {
            max_palette_size=0; // because palette of size 1 is not supported
        }
    }

    // Setup the 32 entry palette
    int palette_max_val = 0, val, cnt, pal_cnt=0;
    for(i=0; i<max_palette_size; i++) {
        cnt = (int)(freq64[i]>>16);
        val = freq64[i]&0xffff;
        if ( cnt==0 )
            break;
        p->lut[i] = val;
        palette_max_val = max(palette_max_val, val);
        pal_cnt+=cnt;
    }
    if (i==1)
        p->lut[i++] = 0; // palette size of 1 is not supported, make it 2

    // Heuristic for when to use the palette. If more than half of the
    // weights are in the palette then we use it. This ensures we don't
    // use palette for e.g. rectangular distributions.
    int palbits_val;
    if (pal_cnt > all_cnt/2) {
        p->palsize = i;
        palbits_val = palette_max_val;
    } else {
        // No palette
        p->palsize = 0;
        // If no palette, then palbits is used to specify the
        // number of bits required for uncompressed mode, i.e.
        // the number of bits for the greatest weight value
        palbits_val = all_max_val;
    }

    // the palette entry bit width
    // minimum 2bits (because PALBITS is in range 2..9)
    int palbits=2;
    while( (1<<palbits) <= palbits_val )
        palbits++;
    assert(palbits<=9);
    p->palbits = palbits;
    p->use_zero_runs = use_zero_runs;
}
301
// Analyse one section of the weight stream and fill in *p via create_palette.
// Return 1 if zero runs should be used
// If palette_size is 512, then palette is not used (in that case the palette is setup
// with the standard alternating unsigned to signed mapping)
static int find_palette( const int16_t *inbuf, int inbuf_size, palette_t *p) {
    int freq[512], i;

    // Calculate frequencies of the given weight stream
    calc_freq( inbuf, inbuf_size, freq);

    // Find two most common values
    int most_common_freq[2]={0}, most_common_val[2]={0};
    for(i=0; i<512; i++) {
        if ( freq[i] > most_common_freq[0] ) {
            // New most common value: demote the previous winner to second place
            most_common_freq[1] = most_common_freq[0];
            most_common_val[1] = most_common_val[0];
            most_common_freq[0] = freq[i];
            most_common_val[0] = i-256;
        } else if ( freq[i] > most_common_freq[1] ) {
            most_common_freq[1] = freq[i];
            most_common_val[1] = i-256;
        }
    }

    // Decide if zero-runs (alternating mode) should be used:
    // * zero should be the most common symbol
    // * zero should be sufficiently more common than the second most common symbol
    int use_zero_runs = most_common_val[0]==0 && most_common_freq[0] > ZERO_RUN_THRES*most_common_freq[1];

    // Create the palette
    create_palette( freq, use_zero_runs, p);

    return use_zero_runs;
}
335
// Build p->inv_lut, which maps a weight value (offset by +256) to the index
// emitted in the bitstream.
static void create_inverse_palette( palette_t *p) {
    int i;
    memset( p->inv_lut, 0, sizeof(p->inv_lut));
    // First map every sign/magnitude code to its direct (out-of-palette) index,
    // which is offset by palsize - direct_offset.
    for(i=0; i<512; i++) {
        int val = i;
        int sign = val&1;
        int mag = val>>1;
        int weight = sign ? -mag : mag;
        if (weight+256 < 512)
            p->inv_lut[ weight+256 ] = i + p->palsize - p->direct_offset;
    }
    // Then let palette entries take precedence: they map to their palette index.
    for(i=0; i<p->palsize; i++) {
        int val = p->lut[i];
        int sign = val&1;
        int mag = val>>1;
        int weight = sign ? -mag : mag;
        assert(weight+256 >= 0 && weight+256 < 512);
        if (weight+256 < 512)
            p->inv_lut[ weight+256 ] = i;
    }
}
357
#define NWCFG 13 // number of weight GRC configs (entries in w_grc_params)
#define NZCFG 4 // restrict search to ZDIV=0..3
#define MAX_ZWCFG (max(NWCFG,NZCFG))

// search state (one per GRC config per input position, see search_grc_params)
typedef struct search_state {
    int bitcnt; // number of bits to reach this state
    uint8_t prev_cfg; // previous grc parameter config
} search_state_t;

// (trunc<<4) | div, 0x20 means uncompressed
static const char w_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20 };
static const char z_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04 };
371
372
373
// An algorithm similar to the Viterbi algorithm is used to search for a
// good GRC parameter sequence for the given input value sequence.
// The inval buffer can contain weights, weight indices or runs.
// The return value is the resulting number of bitstream sections.
// Outputs: grc_param_cfg/grc_param_pos hold the chosen config and its end
// position per section; *bitcnt is the total estimated bit cost.
static int search_grc_params( const int *inval_buf,
                              int n_inval,
                              int zrun_mode,
                              int uncompressed_bits,
                              uint8_t *grc_param_cfg,
                              int *grc_param_pos,
                              int max_grc_param_cfg,
                              int *existing_grc_param_pos,
                              int n_existing_grc_param_pos,
                              int *bitcnt )
{
    int n_cfg = zrun_mode ? NZCFG : NWCFG;
    const char *grc_params = zrun_mode ? z_grc_params : w_grc_params;
    int i,j;

    // state[j][i] is the best known cost of encoding the first i values
    // while ending in GRC config j, plus the config it came from.
    search_state_t *state[MAX_ZWCFG];
    for(i=0; i<n_cfg; i++) {
        // NOTE(review): malloc result is not checked before use
        state[i] = malloc( sizeof(search_state_t) * (n_inval+1) );
        state[i][0].bitcnt=0;
        state[i][0].prev_cfg=i;
    }

    // Loop over inval_buf
    int existing_idx=0;
    for(i=0; i<n_inval; i++) {
        int value = inval_buf[i];

        // Best GRC parameter so far
        int best_bitcnt=0x7fffffff, best_cfg=0;
        for(j=0; j<n_cfg; j++) {
            if (state[j][i].bitcnt < best_bitcnt) {
                best_bitcnt = state[j][i].bitcnt;
                best_cfg = j;
            }
        }

        // Cost (in bits) of switching config at this position
        int cmd_cost = 40;
        if (existing_idx < n_existing_grc_param_pos && existing_grc_param_pos[existing_idx] == (i+1)) {
            // free transition, because the weight stream already inserted a command at this position
            cmd_cost = 0;
            existing_idx++;
        }

        // Loop over GRC parameters, calculate bits to code value, and then update the search state
        for(j=0; j<n_cfg; j++) {
            int div = grc_params[j]&15;
            int trunc = grc_params[j]>>4;
            int q = value>>div;
            int bits = trunc ? min(q+1,2) + div : q+1+div;
            if (!zrun_mode && ((trunc && q>2) || q>31))
                bits=10000; // it's not possible to code the current value; give it a high cost
            if (trunc==2)
                bits=uncompressed_bits;

            if ( best_bitcnt + cmd_cost < state[j][i].bitcnt ) {
                // Change GRC parameters
                state[j][i+1].prev_cfg = best_cfg;
                state[j][i+1].bitcnt = best_bitcnt + cmd_cost + bits;
            } else {
                // Keep same GRC parameters
                state[j][i+1].prev_cfg = j;
                state[j][i+1].bitcnt = state[j][i].bitcnt + bits;
            }
        }
    }


    // Best GRC parameter at the end of the stream
    int best_bitcnt=0x7fffffff, best_cfg=0;
    for(j=0; j<n_cfg; j++) {
        if (state[j][n_inval].bitcnt < best_bitcnt) {
            best_bitcnt = state[j][n_inval].bitcnt;
            best_cfg = j;
        }
    }

    // First backtracking pass: count the config changes (commands) on the best path
    int cfg = best_cfg;
    int n_cmds=0;
    for(i=n_inval; i>=0; i--) {
        if (state[cfg][i].prev_cfg != cfg || i==0) {
            n_cmds++;
            cfg = state[cfg][i].prev_cfg;
        }
    }

    (void)(max_grc_param_cfg);
    assert(n_cmds<=max_grc_param_cfg);

    // Second backtracking pass: fill in cfg/pos arrays in forward order
    cfg = best_cfg;
    j=n_cmds-1;
    int endpos=n_inval;
    for(i=n_inval; i>=0; i--) {
        if (state[cfg][i].prev_cfg != cfg || i==0) {
            grc_param_cfg[j] = cfg;
            grc_param_pos[j] = endpos;
            j--;
            cfg = state[cfg][i].prev_cfg;
            endpos = i-1;
        }
    }
    assert(j==-1);

    for(i=0; i<n_cfg; i++) {
        free(state[i]);
    }

    *bitcnt = best_bitcnt;

    return n_cmds;
}
488
489
/////////////////////////////// Write to bitstream

// LSB-first bit writer over a caller-provided byte buffer.
typedef struct bitbuf {
    uint8_t *buf;
    int buf_size;    // in bytes
    int pos;         // bit pos of next bit
    int log_symbols; // when non-zero, bitbuf_put() printf-logs each symbol
} bitbuf_t;
498
499// size in byte
500static void bitbuf_init( bitbuf_t *bb, uint8_t *buf, int size, int log_symbols ) {
501 bb->buf = buf;
502 bb->pos = 0;
503 bb->buf_size = size;
504 bb->log_symbols = log_symbols;
505}
506
507static void bitbuf_putbit( bitbuf_t *bb, int bit) {
508 int byte_pos = bb->pos>>3;
509 int bit_pos = bb->pos&7;
510 assert( byte_pos >= 0 );
511 assert( byte_pos < bb->buf_size );
512 bb->buf[ byte_pos ] = (bb->buf[ byte_pos ] & ~(1<<bit_pos)) | (bit<<bit_pos);
513 bb->pos += 1;
514}
515
516static void bitbuf_put( bitbuf_t *bb, const char *name, int len, int data) {
517 int i;
518 if (len>0) {
519 if (bb->log_symbols)
520 printf("bitbuf: pos %3d %7s len %d data %x\n", bb->pos, name, len, data);
521 for(i=0; i<len; i++) {
522 bitbuf_putbit(bb, (data>>i)&1);
523 }
524 }
525}
526
// Return new bitpos
// Encode one slice of weight values (and, when zero-runs are enabled,
// interleaved zero-run lengths) into the bitstream with a single GRC config.
static int encode_slice( const int *w_value,
                         const int *z_value,
                         int nvalues,
                         palette_t *p,
                         int new_palette,
                         int uncompressed_bits,
                         int w_cfg,
                         int z_cfg,
                         uint8_t *bitbuf,
                         int bitbuf_size,
                         int bitpos,
                         int verbose )
{
    int i,j;
    bitbuf_t bitbuf_s, *bb=&bitbuf_s;
    bitbuf_init( bb, bitbuf, bitbuf_size, verbose&2?1:0 );
    bb->pos = bitpos;

    assert(nvalues<32768);
    // GRC parameters for this slice
    int w_grc_div = w_grc_params[w_cfg] & 15;
    int w_grc_trunc = (w_grc_params[w_cfg] >> 4)==1;
    int w_uncompressed = (w_grc_params[w_cfg] >> 4)==2;
    int z_grc_div = z_grc_params[z_cfg] & 15;

    if (w_uncompressed) {
        w_grc_div = uncompressed_bits;
    }

    int zdiv = p->use_zero_runs ? z_grc_div : ZDIV_DISABLE;
    int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED;

    if (verbose&1) {
        printf("slice: bitoffset %7d slicelen %5d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %2d\n",
               bb->pos, nvalues, zdiv, wdiv, w_grc_trunc, new_palette, p->palbits, p->palsize);
    }

    // Write slice header
    bitbuf_put( bb, "ZDIV", 3, zdiv);
    bitbuf_put( bb, "SLICELEN", 15, nvalues-1 );
    bitbuf_put( bb, "WDIV", 3, wdiv);
    bitbuf_put( bb, "WTRUNC", 1, w_grc_trunc );
    bitbuf_put( bb, "NEWPAL", 1, new_palette );
    if (new_palette) {
        bitbuf_put( bb, "DIROFS", 5, p->direct_offset );
        bitbuf_put( bb, "PALSIZE", 5, max(0, p->palsize-1));
        bitbuf_put( bb, "PALBITS", 3, p->palbits-2 );
        for(i=0; i<p->palsize; i++) {
            bitbuf_put( bb, "PALETTE", p->palbits, p->lut[i] );
        }
    }

    // The first slice of a section carries one extra zero-run (the leading one)
    int z_nvalues = nvalues + (new_palette?1:0);
    int w_pos=0, z_pos=0;
    int w_unary0=0, w_unary1=0, w_unary1_len=0, w_q=-1, w_r=0;
    int z_unary=0, z_q=-1, z_r=0;
    int w_nsymbols=0, w_remain[12]={0};
    int w_prev_enable=0, w_prev_nsymbols=0, w_prev_remain[12]={0};
    int z_nsymbols=0, z_remain[12]={0};
    int z_prev_enable=0, z_prev_nsymbols=0, z_prev_remain[12]={0};
    int z_unary_len = z_grc_div<3 ? 12 : 8;
    do {
        // Keep the weight and zero-run streams loosely in sync (|balance| bounded)
        int balance = p->use_zero_runs ? w_pos - z_pos : 0;
        int w_enable = balance<8 && w_pos<nvalues;
        int z_enable = balance>=0 && p->use_zero_runs && z_pos<z_nvalues;
        if (w_enable) {
            // Encode chunk (weights)
            j=0;
            w_nsymbols=0;
            w_unary0=0;
            w_unary1=0;
            w_unary1_len=0;
            int max_symbols = w_uncompressed && w_grc_div>5 ? 8 : 12;
            while(j<max_symbols) {
                if (w_q<0) {
                    // Load the next value and split into quotient/remainder
                    if (w_pos<nvalues) {
                        int value = w_value[w_pos];
                        assert(value<512);
                        w_q = value>>w_grc_div;
                        w_r = value&((1<<w_grc_div)-1);
                        assert( w_q<=31 && (!w_grc_trunc || w_q<=2));
                    } else {
                        w_q = 0;
                        w_r = -1; // don't send remainder
                    }
                }
                while( w_q>=0 && j<max_symbols) {
                    w_unary0 |= w_q>0 ? (1<<j) : 0;
                    if (w_q>0) {
                        w_unary1 |= w_q>1 ? (1<<w_unary1_len) : 0;
                        w_unary1_len++;
                    }
                    j++;
                    w_q-=2;
                    if (w_grc_trunc)
                        w_q--;
                }
                if (w_q<0 && w_r>=0) {
                    w_remain[w_nsymbols] = w_r;
                    w_nsymbols++;
                    w_pos++;
                }
            }
        }

        if (z_enable) {
            // Encode chunk (zrun)
            j=0;
            z_nsymbols=0;
            z_unary=0;
            while(j<z_unary_len) {
                if (z_q<0) {
                    if (z_pos<z_nvalues) {
                        int value = z_value[z_pos];
                        z_q = value>>z_grc_div;
                        z_r = value&((1<<z_grc_div)-1);
                    } else {
                        z_q = 0;
                        z_r = -1;
                    }
                }
                while( z_q>=0 && j<z_unary_len) {
                    z_unary |= z_q>0 ? (1<<j) : 0;
                    j++;
                    z_q--;
                }
                if (z_q<0 && z_r>=0) {
                    z_remain[z_nsymbols] = z_r;
                    z_nsymbols++;
                    z_pos++;
                }
            }
        }

        // Write chunk to bitstream; remainders are delayed by one chunk
        if (w_enable && !w_uncompressed) {
            bitbuf_put( bb, "WUNARY0", 12, w_unary0);
        }
        if (z_enable) {
            bitbuf_put( bb, "ZUNARY", z_unary_len, z_unary);
        }
        if (w_enable && !w_uncompressed) {
            bitbuf_put( bb, "WUNARY1", w_unary1_len, w_unary1);
        }
        if (w_prev_enable) {
            for(i=0; i<w_prev_nsymbols; i++) {
                bitbuf_put( bb, "WREMAIN", w_grc_div, w_prev_remain[i]);
            }
        }
        if (z_prev_enable) {
            for(i=0; i<z_prev_nsymbols; i++) {
                bitbuf_put( bb, "ZREMAIN", z_grc_div, z_prev_remain[i]);
            }
        }
        w_prev_enable = w_enable;
        w_prev_nsymbols = w_nsymbols;
        memcpy( w_prev_remain, w_remain, sizeof(w_prev_remain));
        z_prev_enable = z_enable;
        z_prev_nsymbols = z_nsymbols;
        memcpy( z_prev_remain, z_remain, sizeof(z_prev_remain));
    } while( w_prev_enable || z_prev_enable );

    return bb->pos;
}
692
693
// return new bitpos
// Encode one section of the weight stream, i.e. a range that shares a single
// palette (already set up in *p).
static int encode_section( const int16_t *inbuf,
                           int size,
                           palette_t *p,
                           uint8_t *bitbuf,
                           int bitbuf_size,
                           int bitpos,
                           int verbose )
{
    int uncompressed_bits;

    // Uncompressed mode can only be used if either all weights
    // are in the palette OR if the palette is not used.
    if (p->only_palette) {
        // Uncompressed bits derived from palette size
        uncompressed_bits=0;
        while( (1<<uncompressed_bits) < p->palsize )
            uncompressed_bits++;
    } else if (p->palsize==0) {
        // Uncompressed bits is palbits (which is the bitdepth of the greatest weight)
        uncompressed_bits = p->palbits;
    } else {
        // Don't use uncompressed
        uncompressed_bits = 100;
    }

    // NOTE(review): malloc results in this function are not checked before use
    int *weight_values = malloc( size*sizeof(int) );
    int *zrun_values = malloc( size*sizeof(int) );

    // Get weights (or weight indices) AND zero-runs from the input weight stream.
    int i=0, n_weights = 0, zcnt;
    while(1) {
        if (p->use_zero_runs) {
            zcnt=0;
            // Count zero run
            // Special case: if all weights in the section are zero, we must
            // still ensure we have one coded weight so that the slice length
            // doesn't become 0. Therefore we skip the first zero run and code
            // the zero explicitly as a weight value instead
            if (!p->only_zeros || i>0) {
                while( i<size && inbuf[i]==0) {
                    zcnt++;
                    i++;
                }
            }
            zrun_values[n_weights] = zcnt;
        }
        if (i==size)
            break;
        int value = p->inv_lut[inbuf[i]+256];
        weight_values[n_weights] = value;
        n_weights++;
        i++;
    }

    // Search for good GRC parameters for the weight stream
    int n_w_slice, w_bitcnt;
    uint8_t *w_slice_cfg;
    int *w_slice_pos;
    w_slice_cfg = malloc( size );
    w_slice_pos = malloc( size*sizeof(int) );
    n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt);
    if (n_weights==0)
        n_w_slice = 0;

    // Search for good GRC parameters for the zrun stream
    // (there are n_weights+1 zero runs: one leading plus one after each weight)
    int n_z_slice=0, z_bitcnt=0;
    uint8_t *z_slice_cfg=0;
    int *z_slice_pos=0;
    if (p->use_zero_runs) {
        z_slice_cfg = malloc( size );
        z_slice_pos = malloc( size*sizeof(int) );
        n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt);
    }

    // Encode bitstream slice
    int pos=0, i_w_slice=0, i_z_slice=0, new_palette=1;
    while(pos<n_weights || new_palette) {
        int endpos=pos+32767; // max slice length

        // A slice ends at the next GRC config change in either stream
        if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]<endpos) {
            endpos = w_slice_pos[i_w_slice];
        }

        if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]<endpos) {
            endpos = z_slice_pos[i_z_slice];
        }

        if (n_weights < endpos) {
            endpos = n_weights;
        }

        // The first slice (when new_palette is 1) encodes zero runs both at the
        // beginning and end (i.e. number of zero runs are len+1).
        // The following slices only encode zero runs at the end (there cannot be
        // any zeros in the beginning since they are encoded by the previous slice)
        int len = endpos - pos;
        int *zrun_buf = p->use_zero_runs ? zrun_values+pos+(!new_palette) : 0;
        bitpos = encode_slice( weight_values+pos, zrun_buf, len,
                               p, new_palette, uncompressed_bits,
                               w_slice_cfg[i_w_slice], p->use_zero_runs ? z_slice_cfg[i_z_slice] : 0,
                               bitbuf, bitbuf_size, bitpos, verbose );
        new_palette = 0;

        if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]==endpos) {
            i_w_slice++;
        }
        if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]==endpos) {
            i_z_slice++;
        }
        pos = endpos;
    }

    // Free temporary buffers
    free(w_slice_cfg);
    free(w_slice_pos);
    if (p->use_zero_runs) {
        free(z_slice_cfg);
        free(z_slice_pos);
    }
    free(weight_values);
    free(zrun_values);

    return bitpos;
}
819
820// Encode the given weight stream
821// inbuf uncompressed 9bit signed weights
822// inbuf_size number of weights
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200823// outbuf compressed bitstream, buffer is malloced within this function
Tim Hall79d07d22020-04-27 18:20:16 +0100824// verbose if non-zero, printf log
825// Return value is the size in bytes of the compressed output
826// Return -1 if error
827int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
828 int i;
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200829#ifndef NDEBUG
Tim Hall79d07d22020-04-27 18:20:16 +0100830 // Range check
831 for(i=0; i<inbuf_size; i++) {
832 if (inbuf[i]<-255 || inbuf[i]>255) {
833 printf("ERROR: weight out of range at index %d, weight value is %d (valid range is -255..255)\n", i, inbuf[i]);
834 return -1;
835 }
836 }
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200837#endif
Tim Hall79d07d22020-04-27 18:20:16 +0100838
839 int bitbuf_size = inbuf_size*2+1024;
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200840 assert(*outbuf == NULL);
Tim Hall79d07d22020-04-27 18:20:16 +0100841 *outbuf = malloc( bitbuf_size );
842
843 // Analyse input data to find palette re-programming points
844 int n_restarts;
845 int *palette_restart_pos;
846 n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos);
847
848 // Compress each section (using a single palette) separately
849 int bitpos=0;
850 for(i=0; i<n_restarts; i++) {
851 palette_t palette;
852 int pos, size;
853 pos = palette_restart_pos[i];
854 size = (i<n_restarts-1 ? palette_restart_pos[i+1] : inbuf_size) - pos;
855 find_palette( inbuf+pos, size, &palette);
856 create_inverse_palette( &palette);
857 bitpos = encode_section( inbuf+pos, size, &palette,
858 *outbuf, bitbuf_size, bitpos, verbose );
859 }
860
861
862 // Add end of stream marker and align to 128bit
863 {
864 bitbuf_t bitbuf_s, *bb=&bitbuf_s;
865 bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 );
866 bb->pos = bitpos;
867 bitbuf_put( bb, "ZDIV", 3, ZDIV_EOS);
868 bitbuf_put( bb, "BYTEALIGN", (8-(bb->pos&7))&7, 0xff );
869
870 // Pad with 0xff until 64bit aligned
871 while( bb->pos & 127 ) {
872 bitbuf_put( bb, "PAD", 8, 0xff );
873 }
874 bitpos = bb->pos;
875 }
876 assert((bitpos&127)==0);
877 int outbuf_size = bitpos/8;
878 *outbuf = realloc( *outbuf, outbuf_size);
879
880 free(palette_restart_pos);
881
882 return outbuf_size;
883}
884
// Free a bitstream buffer returned by mlw_encode().
// NULL is accepted: free(NULL) is a no-op per the C standard, so the
// previous explicit guard was redundant.
void mlw_free_outbuf( uint8_t *outbuf ) {
    free(outbuf);
}
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200889
// Ceiling division for non-negative operands: ceil(num / den).
static int round_up_divide(int num, int den)
{
    return (num + (den - 1)) / den;
}
894
// Round num up to the next multiple of den (helper inlined; same arithmetic
// as round_up_divide(num, den) * den).
static int round_up(int num, int den)
{
    return ((num + den - 1) / den) * den;
}
899
// Accessor state for reading 16-bit weights out of a strided 4-D brick buffer.
struct brick_buf_s
{
    int16_t* buf;  // base pointer to the weights
    int* strides;  // element (not byte) strides; indexed [0]=ofm_z, [1]=wy, [2]=wx, [3]=ifm_z (see get_brick_weight)
};
typedef struct brick_buf_s brick_buf_t;
906
907static int16_t get_brick_weight(brick_buf_t* buf, int ofm_z, int wy, int wx, int ifm_z)
908{
Mauricio Briceno3e4168d2021-06-09 09:49:05 +0200909 int16_t* p = buf->buf;
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200910
911 p += ofm_z * buf->strides[0];
912 p += wy * buf->strides[1];
913 p += wx * buf->strides[2];
914 p += ifm_z * buf->strides[3];
915
Mauricio Briceno3e4168d2021-06-09 09:49:05 +0200916 return *p;
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200917}
918
// Release a buffer allocated by reorder().
// The previous NULL check was redundant: free(NULL) is a no-op per the C standard.
static void reorder_free(int16_t* buf)
{
    free(buf);
}
926
927static int16_t* reorder(
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200928 int ifm_ublock_depth,
929 int ofm_ublock_depth,
930 int ofm_depth,
931 int kernel_height,
932 int kernel_width,
933 int ifm_depth,
934 int* strides,
Mauricio Briceno3e4168d2021-06-09 09:49:05 +0200935 int16_t* inbuf,
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200936 int ofm_block_depth,
937 int is_depthwise,
938 int is_partkernel,
939 int ifm_bitdepth,
940 int decomp_h,
941 int decomp_w,
Fredrik Svedberg93d5c352021-05-11 13:51:47 +0200942 int64_t* padded_length)
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200943{
Fredrik Svedberg93d5c352021-05-11 13:51:47 +0200944 /* Size unknown. Start with one page at least */
945 *padded_length = round_up(max(1, sizeof(int16_t)*
946 ofm_depth*
947 kernel_height*
948 kernel_width*
949 ifm_depth),
950 4*1024) / sizeof(int16_t);
951 int16_t* weights = (int16_t*)malloc(*padded_length * sizeof(int16_t));
952
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200953 brick_buf_t brick_buf;
954 brick_buf.buf = inbuf;
955 brick_buf.strides = strides;
956
957 int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
Fredrik Svedberg93d5c352021-05-11 13:51:47 +0200958 int64_t weight_cnt = 0;
Mauricio Briceno67e11f72021-05-05 12:47:28 +0200959 for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth)
960 {
961 int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z);
962 // IFM blocks required for the brick
963 for (int ifm_block_z = 0; ifm_block_z < (is_depthwise ? 1 : ifm_depth); ifm_block_z += ifm_block_depth)
964 {
965 int clipped_ifm_block_depth;
966 if (is_depthwise)
967 {
968 clipped_ifm_block_depth = ifm_ublock_depth;
969 }
970 else
971 {
972 clipped_ifm_block_depth = is_partkernel ?
973 min(ifm_block_depth, ifm_depth - ifm_block_z) : ifm_block_depth;
974 }
975 // Weight decomposition
976 // Subkernel Splitting (H)
977 for (int subkernel_y = 0; subkernel_y < kernel_height; subkernel_y += decomp_h)
978 {
979 int sub_height = min(kernel_height - subkernel_y, decomp_h);
980 // Subkernel splitting (W)
981 for (int subkernel_x = 0; subkernel_x < kernel_width; subkernel_x += decomp_w)
982 {
983 int sub_width = min(kernel_width - subkernel_x, decomp_w);
984 int subkernel_elements = sub_width * sub_height;
985 // Part kernel first works across the kernel H/W and needs padding
986 if (is_partkernel)
987 {
988 if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
989 {
990 subkernel_elements = round_up(subkernel_elements, 2);
991 }
992 else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
993 {
994 subkernel_elements = round_up(subkernel_elements, 4);
995 }
996 }
997 else if (is_depthwise)
998 {
999 subkernel_elements = round_up(subkernel_elements, 4);
1000 }
1001 int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
1002 int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;
1003 for (int ifm_ublk_outer = 0; ifm_ublk_outer < ifm_block_depth_outer; ifm_ublk_outer += ifm_ublock_depth)
1004 {
1005 // OFM Ublocks in OFM-block over depth
1006 for (int ofm_ublk = 0; ofm_ublk < clipped_ofm_block_depth; ofm_ublk += ofm_ublock_depth)
1007 {
1008 // HW Kernel element traversal - cannot be a H/W loop due to element
1009 // padding requirement on depthwise/part-kernel configurations
1010 for (int element = 0; element < subkernel_elements; element++)
1011 {
1012 int kx = element % sub_width;
1013 int ky = element / sub_width;
1014 // IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise)
1015 // In case of part-kernel-first IFM Ublock traversal have already been handled
1016 // and this loop is ignored.
1017 for (int ifm_ublk_inner = 0; ifm_ublk_inner < ifm_block_depth_inner; ifm_ublk_inner += ifm_ublock_depth)
1018 {
1019 // Feed OFM ublock elements
1020 for (int ofm_ublock_z = 0; ofm_ublock_z < ofm_ublock_depth; ofm_ublock_z++)
1021 {
1022 // Source IFM ublock elements (only 1 element deep if depthwise)
1023 for (int ifm_ublock_z = 0; ifm_ublock_z < (is_depthwise ? 1 : ifm_ublock_depth); ifm_ublock_z++)
1024 {
1025 // Source position within the current subkernel
1026 int wx = subkernel_x + kx;
1027 int wy = subkernel_y + ky;
1028 // Source IFM/OFM slices
1029 int ifm_ublk = ifm_ublk_inner + ifm_ublk_outer;
1030 int ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z;
1031 int ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z;
1032 if ((ifm_z < ifm_depth) && (ofm_z < ofm_depth) && (ky < sub_height))
1033 {
1034 weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z);
1035 }
Fredrik Svedberg93d5c352021-05-11 13:51:47 +02001036 else
1037 {
1038 weights[weight_cnt] = 0;
1039 }
Mauricio Briceno67e11f72021-05-05 12:47:28 +02001040 weight_cnt++;
Fredrik Svedberg93d5c352021-05-11 13:51:47 +02001041 if (weight_cnt == *padded_length)
1042 {
1043 // Reallocate by doubling the buffer size as needed
1044 *padded_length *= 2;
1045 weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t));
1046 }
Mauricio Briceno67e11f72021-05-05 12:47:28 +02001047 }
1048 }
1049 }
1050 }
1051 }
1052 }
1053 }
1054 }
1055 }
1056 }
1057
Fredrik Svedberg93d5c352021-05-11 13:51:47 +02001058 *padded_length = weight_cnt;
1059 weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t));
1060 return weights;
Mauricio Briceno67e11f72021-05-05 12:47:28 +02001061}
1062
// Reorder the raw weight stream into HW traversal order, then compress it.
// On success returns the compressed size in bytes; returns -1 on encoder error.
// *padded_length receives the number of reordered (int16) weights produced.
int mlw_reorder_encode(
    int ifm_ublock_depth,
    int ofm_ublock_depth,
    int ofm_depth,
    int kernel_height,
    int kernel_width,
    int ifm_depth,
    int* brick_strides,
    int16_t* inbuf,
    int ofm_block_depth,
    int is_depthwise,
    int is_partkernel,
    int ifm_bitdepth,
    int decomp_h,
    int decomp_w,
    uint8_t **outbuf, // *outbuf must be freed by caller
    int64_t* padded_length,
    int verbose)
{
    /* Step 1: lay the weights out in the order the hardware consumes them */
    int16_t* reordered = reorder(
        ifm_ublock_depth,
        ofm_ublock_depth,
        ofm_depth,
        kernel_height,
        kernel_width,
        ifm_depth,
        brick_strides,
        inbuf,
        ofm_block_depth,
        is_depthwise,
        is_partkernel,
        ifm_bitdepth,
        decomp_h,
        decomp_w,
        padded_length);

    /* Step 2: compress, but only when the length fits the encoder's int API */
    int bytes_out = 0;
    const int64_t n = *padded_length;
    if (n > 0 && n <= INT32_MAX)
    {
        bytes_out = mlw_encode(reordered, (int)n, outbuf, verbose);
    }
    reorder_free(reordered);

    return bytes_out;
}