/*
 * Copyright (c) 2020 Arm Limited. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <assert.h>
#include <math.h>
#include <stdarg.h>
#include "mlw_common.h"
#include "mlw_encode.h"

#define DPRINTF(...)
//#define DPRINTF(...) printf(__VA_ARGS__)

#define ZERO_RUN_THRES 4
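// ZERO_RUN_THRES: zero runs (alternating mode) are only used when the zero
// weight is at least this many times more frequent than the second most
// common weight value; see find_palette() below.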

#ifndef min
#define min(a,b) ((a)<(b)?(a):(b))
#endif
#ifndef max
#define max(a,b) ((a)>(b)?(a):(b))
#endif

typedef struct palette {
    int16_t lut[32];
    int16_t inv_lut[512];
    int palsize;        // number of palette entries
    int palbits;        // bit width of palette entries
    int use_zero_runs;  // zeros are coded separately
    int only_palette;   // no values outside the palette
    int direct_offset;  // added to the decoded weight index before direct conversion to sign/mag
    int only_zeros;     // special case that the section is all zeros
} palette_t;

static int is_power_of_two( int x ) {
    return ((x-1) & x)==0;
}

static int get_palette_index_bits( int size ) {
    int i;
    for(i=7; i>=0; i--)
        if (size > (1<<i) )
            return i+1;
    return 0;
}

// Search the stream for suitable palette restart positions
// Return the number of restarts
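// A restart position marks the first weight of a section that is coded with
// its own palette; position 0 is always a restart.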
static int search_palette_sections( int16_t *buf, int size, int **palette_restart_positions ) {
    int i,j,got_palette,restart_i,palette_size=0, last_restart_idx, zero_cnt;
    int prev_idx[512]; // For each value, keep track of the index of the previous occurrence
    int *restart_pos;
    int max_palettes = (size+63)/64;

    // Preliminary allocation of sufficient size
    restart_pos = (int*)malloc( max_palettes*sizeof(int) );
    last_restart_idx=0;
    got_palette=0;
    restart_i=1;
    restart_pos[0] = 0;
    zero_cnt=0;
    memset( prev_idx, -1, sizeof(prev_idx));
    for(i=0; i<size; i++) {
        // Guess if zeros should be excluded from the palette
        int exclude_zero = zero_cnt > (i-last_restart_idx)/4;

        if (got_palette) {
            // Check if the next value is not covered by the current palette
            if ( prev_idx[ buf[i]+256 ] < last_restart_idx ) {
                // New value: increase the palette size
                palette_size++;
                DPRINTF("Note: at pos %d extend palette to size %d\n", i, palette_size);
                if ( is_power_of_two(palette_size-1-exclude_zero) ) {
                    if ( (i - last_restart_idx - zero_cnt) > 512 || (palette_size-exclude_zero)>32 ) {
                        // Start a new palette: extending the current long-running
                        // palette would require one more index bit
                        DPRINTF("Note: at pos %d create new palette because previous has to increase one more index bit. last_restart_idx %d n %d zero_cnt %d\n", i, last_restart_idx, i - last_restart_idx, zero_cnt );
                        if (restart_i == max_palettes) {
                            max_palettes = max_palettes*2;
                            restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                        }
                        DPRINTF("restart %d pos %d\n", restart_i, i);
                        restart_pos[restart_i++] = i;
                        last_restart_idx = i;
                        got_palette=0;
                        zero_cnt=0;
                    }
                }
            }
        }

        prev_idx[ buf[i]+256 ] = i;
        if (buf[i]==0)
            zero_cnt++;

        static const int window_sizes[5][2] = {{32,1}, {64,1}, {128,1}, {256,1}, {512,1}};
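        // Only the first element of each pair is used; the effective window
        // N is doubled while a palette is already active.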
        int k;
        // loop over window sizes
        for(k=0; k<5; k++) {
            // Every Nth non-zero value, count what the size of a palette
            // covering the last N non-zero values would be
            int N = window_sizes[k][0] * (got_palette?2:1);
            if ( (i - last_restart_idx - zero_cnt) > 0 && ((i - last_restart_idx - zero_cnt) % N)==0 ) {
                // Search backward to the position N nonzero values earlier
                int nzcnt=0;
                for( j=i; j>last_restart_idx; j--) {
                    if ( buf[j]!=0 ) {
                        if (nzcnt==N+1)
                            break;
                        nzcnt++;
                    }
                }
                int restart_idx = j;

                // Calculate the size of a new palette (starting at restart_idx)
                int new_palette_size=0;
                for(j=0; j<512; j++) {
                    if ( prev_idx[j] >= restart_idx ) {
                        new_palette_size++;
                    }
                }

                int create_new_palette=0;
                if (got_palette) {
                    int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero );
                    int old_size_bits = get_palette_index_bits( palette_size - exclude_zero );
                    int savings = N*(old_size_bits*15-new_size_bits*15)/16 - new_palette_size*8 - 20;
                    if ( savings>0 ) {
                        // Create a new palette because it can be smaller than the existing palette
                        create_new_palette=1;
                        DPRINTF("Note: at pos %d restart smaller palette\n", restart_idx);
                    }
                } else {
                    if ( (new_palette_size-exclude_zero) <= 32) {
                        int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero );
                        // estimate if we will make savings by using palette mode
                        int savings = N*(90-new_size_bits*15)/16 - new_palette_size*8 - 20;
                        create_new_palette = savings>0;
                    }
                }
                if (create_new_palette) {
                    palette_size=new_palette_size;
                    got_palette=1;
                    last_restart_idx = restart_idx;
                    DPRINTF("Note: at pos %d create palette of size %d\n", last_restart_idx, new_palette_size);
                    if ( restart_pos[restart_i-1] != last_restart_idx) {
                        if (restart_i == max_palettes) {
                            max_palettes = max_palettes*2;
                            restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                        }
                        restart_pos[restart_i++] = last_restart_idx;
                    }
                    zero_cnt=0;
                    for( j=last_restart_idx; j<=i; j++)
                        if (buf[j]==0)
                            zero_cnt++;
                }
            }
        }
    }
    // Reallocate to actual size
    *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) );
    return restart_i;
}

// Calculate frequency table
static void calc_freq( const int16_t *buf, int size, int freq[512] ) {
    int i;
    memset(freq, 0, 512*sizeof(int));
    for(i=0; i<size; i++) {
        freq[buf[i]+256]++;
    }
}

static int cmp_uint64(const void * a, const void * b) {
    uint64_t aa = *(uint64_t*)a;
    uint64_t bb = *(uint64_t*)b;
    return aa>bb ? -1 : aa<bb ? 1 : 0;
}

// Create palette from the given frequencies
// Freq indices 0..511 correspond to weights -256..255
static void create_palette( int freq[512],
                            int use_zero_runs,
                            palette_t *p ) {
    uint64_t freq64[512];
    int i,all_cnt,all_max_val;

    // Pair the frequency with the value so that
    // the array can be sorted on frequency while keeping
    // track of the corresponding palette value
    memset(freq64, 0, sizeof(freq64));
    all_cnt=0;
    all_max_val=0;
    for(i=-255; i<256; i++) {
        if (i==0 && use_zero_runs)
            continue;
        int sign = i<0;
        int mag = abs(i);
        int palval = (mag<<1) | sign;
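        // Sign/magnitude interleaving: weight 0 -> palval 0, +1 -> 2,
        // -1 -> 3, +2 -> 4, -2 -> 5, ... (palval 1, i.e. -0, never occurs)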

        // Store the palette value in the 16 LSBs, which will not affect the sorting
        freq64[palval] = (((uint64_t)freq[i+256])<<16) | palval;
        all_cnt+=freq[i+256];

        if (freq[i+256]>0) {
            all_max_val = max(all_max_val, palval);
        }
    }

    // Count the number of unused weight values around zero (0, -1, +1, -2, +2 etc)
    for(i=0; i<31; i++) {
        if ((freq64[i]>>16)!=0)
            break;
    }
    p->direct_offset = i;

    // Sort in descending frequency order
    qsort(freq64, 512, sizeof(uint64_t), cmp_uint64);

    // Identify the special case that there are no weights to code
    // in the weight index stream (i.e. all weights are zeros)
    p->only_zeros = (freq64[0]>>16)==0;
    if (p->only_zeros) {
        p->direct_offset=0;
    }

    // Check if all weights fit into the palette (and the palette is not empty)
    p->only_palette = (freq64[0]>>16)>0 && (freq64[32]>>16)==0;

    int max_palette_size;
    if (p->only_palette) {
        max_palette_size = 32;
    } else {
        // For direct-lut we must make sure that the encoded weight
        // index is not > 511. We do that by limiting the palette size
        // such that the greatest value can be reached after subtracting
        // the palette size.
        max_palette_size = min(32, 511-all_max_val);
        if (max_palette_size==1) {
            max_palette_size=0; // because a palette of size 1 is not supported
        }
    }

    // Setup the 32 entry palette
    int palette_max_val = 0, val, cnt, pal_cnt=0;
    for(i=0; i<max_palette_size; i++) {
        cnt = (int)(freq64[i]>>16);
        val = freq64[i]&0xffff;
        if ( cnt==0 )
            break;
        p->lut[i] = val;
        palette_max_val = max(palette_max_val, val);
        pal_cnt+=cnt;
    }
    if (i==1)
        i++; // a palette size of 1 is not supported, make it 2

    // Heuristic for when to use the palette. If more than half of the
    // weights are in the palette then we use it. This ensures we don't
    // use the palette for e.g. rectangular distributions.
    int palbits_val;
    if (pal_cnt > all_cnt/2) {
        p->palsize = i;
        palbits_val = palette_max_val;
    } else {
        // No palette
        p->palsize = 0;
        // If no palette, then palbits is used to specify the
        // number of bits required for uncompressed mode, i.e.
        // the number of bits for the greatest weight value
        palbits_val = all_max_val;
    }

    // the palette entry bit width
    // minimum 2 bits (because PALBITS is in range 2..9)
    int palbits=2;
    while( (1<<palbits) <= palbits_val )
        palbits++;
    assert(palbits<=9);
    p->palbits = palbits;
    p->use_zero_runs = use_zero_runs;
}

// Return 1 if zero runs should be used
// If palette_size is 512, then the palette is not used (in that case the
// palette is set up with the standard alternating unsigned-to-signed mapping)
static int find_palette( const int16_t *inbuf, int inbuf_size, palette_t *p) {
    int freq[512], i;

    // Calculate frequencies of the given weight stream
    calc_freq( inbuf, inbuf_size, freq);

    // Find the two most common values
    int most_common_freq[2]={0}, most_common_val[2]={0};
    for(i=0; i<512; i++) {
        if ( freq[i] > most_common_freq[0] ) {
            most_common_freq[1] = most_common_freq[0];
            most_common_val[1] = most_common_val[0];
            most_common_freq[0] = freq[i];
            most_common_val[0] = i-256;
        } else if ( freq[i] > most_common_freq[1] ) {
            most_common_freq[1] = freq[i];
            most_common_val[1] = i-256;
        }
    }

    // Decide if zero-runs (alternating mode) should be used:
    // * zero should be the most common symbol
    // * zero should be sufficiently more common than the second most common symbol
    int use_zero_runs = most_common_val[0]==0 && most_common_freq[0] > ZERO_RUN_THRES*most_common_freq[1];

    // Create the palette
    create_palette( freq, use_zero_runs, p);

    return use_zero_runs;
}

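// Build the inverse palette, mapping weight value -> coded index: weights in
// the palette map to their palette index (0..palsize-1); any other weight
// maps to palsize + palval - direct_offset, where palval is its
// sign/magnitude interleaved value (direct mode).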
static void create_inverse_palette( palette_t *p) {
    int i;
    memset( p->inv_lut, 0, sizeof(p->inv_lut));
    for(i=0; i<512; i++) {
        int val = i;
        int sign = val&1;
        int mag = val>>1;
        int weight = sign ? -mag : mag;
        if (weight+256 < 512)
            p->inv_lut[ weight+256 ] = i + p->palsize - p->direct_offset;
    }
    for(i=0; i<p->palsize; i++) {
        int val = p->lut[i];
        int sign = val&1;
        int mag = val>>1;
        int weight = sign ? -mag : mag;
        if (weight+256 < 512)
            p->inv_lut[ weight+256 ] = i;
    }
}

#define NWCFG 13
#define NZCFG 4 // restrict search to ZDIV=0..3
#define MAX_ZWCFG (max(NWCFG,NZCFG))

// search state
typedef struct search_state {
    int bitcnt;       // number of bits to reach this state
    uint8_t prev_cfg; // previous grc parameter config
} search_state_t;

// (trunc<<4) | div, 0x20 means uncompressed
static const char w_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20 };
static const char z_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04 };
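
// Values are Golomb-Rice coded: a value v is split into a quotient
// q = v>>div and a remainder r = v&((1<<div)-1); q is coded in unary (split
// over the WUNARY0/WUNARY1 fields written in encode_slice()) and r in div
// bits. For example, div=2 codes v=11 as q=2, r=3 at a cost of q+1+div = 5
// bits. When trunc is set, the unary part is capped at 2 bits, so only
// q<=2 can be coded.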

// An algorithm similar to the Viterbi algorithm is used to search for a
// good GRC parameter sequence for the given input value sequence.
// The inval buffer can contain weights, weight indices or runs.
// The return value is the resulting number of bitstream sections.
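// Each state[cfg][i] holds the fewest bits needed to code the first i values
// given that value i-1 was coded with GRC config cfg. Switching config costs
// cmd_cost (40 bits for a new command, or 0 where the weight stream already
// forces a boundary); backtracking through prev_cfg then recovers the chosen
// config sequence and its change positions.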
static int search_grc_params( const int *inval_buf,
                              int n_inval,
                              int zrun_mode,
                              int uncompressed_bits,
                              uint8_t *grc_param_cfg,
                              int *grc_param_pos,
                              int max_grc_param_cfg,
                              int *existing_grc_param_pos,
                              int n_existing_grc_param_pos,
                              int *bitcnt )
{
    int n_cfg = zrun_mode ? NZCFG : NWCFG;
    const char *grc_params = zrun_mode ? z_grc_params : w_grc_params;
    int i,j;

    search_state_t *state[MAX_ZWCFG];
    for(i=0; i<n_cfg; i++) {
        state[i] = malloc( sizeof(search_state_t) * (n_inval+1) );
        state[i][0].bitcnt=0;
        state[i][0].prev_cfg=i;
    }

    // Loop over inval_buf
    int existing_idx=0;
    for(i=0; i<n_inval; i++) {
        int value = inval_buf[i];

        // Best GRC parameter so far
        int best_bitcnt=0x7fffffff, best_cfg=0;
        for(j=0; j<n_cfg; j++) {
            if (state[j][i].bitcnt < best_bitcnt) {
                best_bitcnt = state[j][i].bitcnt;
                best_cfg = j;
            }
        }

        int cmd_cost = 40;
        if (existing_idx < n_existing_grc_param_pos && existing_grc_param_pos[existing_idx] == (i+1)) {
            // free transition, because the weight stream already inserted a command at this position
            cmd_cost = 0;
            existing_idx++;
        }

        // Loop over GRC parameters, calculate bits to code value, and then update the search state
        for(j=0; j<n_cfg; j++) {
            int div = grc_params[j]&15;
            int trunc = grc_params[j]>>4;
            int q = value>>div;
            int bits = trunc ? min(q+1,2) + div : q+1+div;
            if (!zrun_mode && ((trunc && q>2) || q>31))
                bits=10000; // it's not possible to code the current value; give it a high cost
            if (trunc==2)
                bits=uncompressed_bits;

            if ( best_bitcnt + cmd_cost < state[j][i].bitcnt ) {
                // Change GRC parameters
                state[j][i+1].prev_cfg = best_cfg;
                state[j][i+1].bitcnt = best_bitcnt + cmd_cost + bits;
            } else {
                // Keep same GRC parameters
                state[j][i+1].prev_cfg = j;
                state[j][i+1].bitcnt = state[j][i].bitcnt + bits;
            }
        }
    }

    // Best GRC parameter
    int best_bitcnt=0x7fffffff, best_cfg=0;
    for(j=0; j<n_cfg; j++) {
        if (state[j][n_inval].bitcnt < best_bitcnt) {
            best_bitcnt = state[j][n_inval].bitcnt;
            best_cfg = j;
        }
    }

    int cfg = best_cfg;
    int n_cmds=0;
    for(i=n_inval; i>=0; i--) {
        if (state[cfg][i].prev_cfg != cfg || i==0) {
            n_cmds++;
            cfg = state[cfg][i].prev_cfg;
        }
    }

    (void)(max_grc_param_cfg);
    assert(n_cmds<=max_grc_param_cfg);

    cfg = best_cfg;
    j=n_cmds-1;
    int endpos=n_inval;
    for(i=n_inval; i>=0; i--) {
        if (state[cfg][i].prev_cfg != cfg || i==0) {
            grc_param_cfg[j] = cfg;
            grc_param_pos[j] = endpos;
            j--;
            cfg = state[cfg][i].prev_cfg;
            endpos = i-1;
        }
    }
    assert(j==-1);

    for(i=0; i<n_cfg; i++) {
        free(state[i]);
    }

    *bitcnt = best_bitcnt;

    return n_cmds;
}

/////////////////////////////// Write to bitstream

typedef struct bitbuf {
    uint8_t *buf;
    int buf_size;    // in bytes
    int pos;         // bit pos of next bit
    int log_symbols;
} bitbuf_t;

// size in bytes
static void bitbuf_init( bitbuf_t *bb, uint8_t *buf, int size, int log_symbols ) {
    bb->buf = buf;
    bb->pos = 0;
    bb->buf_size = size;
    bb->log_symbols = log_symbols;
}

static void bitbuf_putbit( bitbuf_t *bb, int bit) {
    int byte_pos = bb->pos>>3;
    int bit_pos = bb->pos&7;
    assert( byte_pos >= 0 );
    assert( byte_pos < bb->buf_size );
    bb->buf[ byte_pos ] = (bb->buf[ byte_pos ] & ~(1<<bit_pos)) | (bit<<bit_pos);
    bb->pos += 1;
}

static void bitbuf_put( bitbuf_t *bb, const char *name, int len, int data) {
    int i;
    if (len>0) {
        if (bb->log_symbols)
            printf("bitbuf: pos %3d %7s len %d data %x\n", bb->pos, name, len, data);
        for(i=0; i<len; i++) {
            bitbuf_putbit(bb, (data>>i)&1);
        }
    }
}
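
// Fields are written least significant bit first; e.g. bitbuf_put( bb, "X",
// 3, 6 ) emits the bits 0,1,1 at increasing bit positions.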

// Return new bitpos
static int encode_slice( const int *w_value,
                         const int *z_value,
                         int nvalues,
                         palette_t *p,
                         int new_palette,
                         int uncompressed_bits,
                         int w_cfg,
                         int z_cfg,
                         uint8_t *bitbuf,
                         int bitbuf_size,
                         int bitpos,
                         int verbose )
{
    int i,j;
    bitbuf_t bitbuf_s, *bb=&bitbuf_s;
    bitbuf_init( bb, bitbuf, bitbuf_size, verbose&2?1:0 );
    bb->pos = bitpos;

    assert(nvalues<32768);
    // GRC parameters for this slice
    int w_grc_div = w_grc_params[w_cfg] & 15;
    int w_grc_trunc = (w_grc_params[w_cfg] >> 4)==1;
    int w_uncompressed = (w_grc_params[w_cfg] >> 4)==2;
    int z_grc_div = z_grc_params[z_cfg] & 15;

    if (w_uncompressed) {
        w_grc_div = uncompressed_bits;
    }

    int zdiv = p->use_zero_runs ? z_grc_div : ZDIV_DISABLE;
    int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED;

    if (verbose&1) {
        printf("slice: bitoffset %7d slicelen %5d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %2d\n",
               bb->pos, nvalues, zdiv, wdiv, w_grc_trunc, new_palette, p->palbits, p->palsize);
    }
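
    // Slice header layout (field widths in bits): ZDIV(3), SLICELEN(15),
    // WDIV(3), WTRUNC(1), NEWPAL(1); when NEWPAL is set this is followed by
    // DIROFS(5), PALSIZE(5), PALBITS(3) and palsize palette entries of
    // palbits each.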
    // Write slice header
    bitbuf_put( bb, "ZDIV", 3, zdiv);
    bitbuf_put( bb, "SLICELEN", 15, nvalues-1 );
    bitbuf_put( bb, "WDIV", 3, wdiv);
    bitbuf_put( bb, "WTRUNC", 1, w_grc_trunc );
    bitbuf_put( bb, "NEWPAL", 1, new_palette );
    if (new_palette) {
        bitbuf_put( bb, "DIROFS", 5, p->direct_offset );
        bitbuf_put( bb, "PALSIZE", 5, max(0, p->palsize-1));
        bitbuf_put( bb, "PALBITS", 3, p->palbits-2 );
        for(i=0; i<p->palsize; i++) {
            bitbuf_put( bb, "PALETTE", p->palbits, p->lut[i] );
        }
    }

    int z_nvalues = nvalues + (new_palette?1:0);
    int w_pos=0, z_pos=0;
    int w_unary0=0, w_unary1=0, w_unary1_len=0, w_q=-1, w_r=0;
    int z_unary=0, z_q=-1, z_r=0;
    int w_nsymbols=0, w_remain[12]={0};
    int w_prev_enable=0, w_prev_nsymbols=0, w_prev_remain[12]={0};
    int z_nsymbols=0, z_remain[12]={0};
    int z_prev_enable=0, z_prev_nsymbols=0, z_prev_remain[12]={0};
    int z_unary_len = z_grc_div<3 ? 12 : 8;
    do {
        int balance = p->use_zero_runs ? w_pos - z_pos : 0;
        int w_enable = balance<8 && w_pos<nvalues;
        int z_enable = balance>=0 && p->use_zero_runs && z_pos<z_nvalues;
        if (w_enable) {
            // Encode chunk (weights)
            j=0;
            w_nsymbols=0;
            w_unary0=0;
            w_unary1=0;
            w_unary1_len=0;
            int max_symbols = w_uncompressed && w_grc_div>5 ? 8 : 12;
            while(j<max_symbols) {
                if (w_q<0) {
                    if (w_pos<nvalues) {
                        int value = w_value[w_pos];
                        assert(value<512);
                        w_q = value>>w_grc_div;
                        w_r = value&((1<<w_grc_div)-1);
                        assert( w_q<=31 && (!w_grc_trunc || w_q<=2));
                    } else {
                        w_q = 0;
                        w_r = -1; // don't send remainder
                    }
                }
                while( w_q>=0 && j<max_symbols) {
                    w_unary0 |= w_q>0 ? (1<<j) : 0;
                    if (w_q>0) {
                        w_unary1 |= w_q>1 ? (1<<w_unary1_len) : 0;
                        w_unary1_len++;
                    }
                    j++;
                    w_q-=2;
                    if (w_grc_trunc)
                        w_q--;
                }
                if (w_q<0 && w_r>=0) {
                    w_remain[w_nsymbols] = w_r;
                    w_nsymbols++;
                    w_pos++;
                }
            }
        }

        if (z_enable) {
            // Encode chunk (zrun)
            j=0;
            z_nsymbols=0;
            z_unary=0;
            while(j<z_unary_len) {
                if (z_q<0) {
                    if (z_pos<z_nvalues) {
                        int value = z_value[z_pos];
                        z_q = value>>z_grc_div;
                        z_r = value&((1<<z_grc_div)-1);
                    } else {
                        z_q = 0;
                        z_r = -1;
                    }
                }
                while( z_q>=0 && j<z_unary_len) {
                    z_unary |= z_q>0 ? (1<<j) : 0;
                    j++;
                    z_q--;
                }
                if (z_q<0 && z_r>=0) {
                    z_remain[z_nsymbols] = z_r;
                    z_nsymbols++;
                    z_pos++;
                }
            }
        }

        // Write chunk to bitstream
        if (w_enable && !w_uncompressed) {
            bitbuf_put( bb, "WUNARY0", 12, w_unary0);
        }
        if (z_enable) {
            bitbuf_put( bb, "ZUNARY", z_unary_len, z_unary);
        }
        if (w_enable && !w_uncompressed) {
            bitbuf_put( bb, "WUNARY1", w_unary1_len, w_unary1);
        }
        if (w_prev_enable) {
            for(i=0; i<w_prev_nsymbols; i++) {
                bitbuf_put( bb, "WREMAIN", w_grc_div, w_prev_remain[i]);
            }
        }
        if (z_prev_enable) {
            for(i=0; i<z_prev_nsymbols; i++) {
                bitbuf_put( bb, "ZREMAIN", z_grc_div, z_prev_remain[i]);
            }
        }
        w_prev_enable = w_enable;
        w_prev_nsymbols = w_nsymbols;
        memcpy( w_prev_remain, w_remain, sizeof(w_prev_remain));
        z_prev_enable = z_enable;
        z_prev_nsymbols = z_nsymbols;
        memcpy( z_prev_remain, z_remain, sizeof(z_prev_remain));
    } while( w_prev_enable || z_prev_enable );

    return bb->pos;
}

// Return new bitpos
static int encode_section( const int16_t *inbuf,
                           int size,
                           palette_t *p,
                           uint8_t *bitbuf,
                           int bitbuf_size,
                           int bitpos,
                           int verbose )
{
    int uncompressed_bits;

    // Uncompressed mode can only be used if either all weights
    // are in the palette OR if the palette is not used.
    if (p->only_palette) {
        // Uncompressed bits derived from palette size
        uncompressed_bits=0;
        while( (1<<uncompressed_bits) < p->palsize )
            uncompressed_bits++;
    } else if (p->palsize==0) {
        // Uncompressed bits is palbits (which is the bit depth of the greatest weight)
        uncompressed_bits = p->palbits;
    } else {
        // Don't use uncompressed
        uncompressed_bits = 100;
    }

    int *weight_values = malloc( size*sizeof(int) );
    int *zrun_values = malloc( size*sizeof(int) );

    // Get weights (or weight indices) AND zero-runs from the input weight stream.
    int i=0, n_weights = 0, zcnt;
    while(1) {
        if (p->use_zero_runs) {
            zcnt=0;
            // Count zero run
            // Special case: if all weights in the section are zero, we must
            // still ensure we have one coded weight so that the slice length
            // doesn't become 0. Therefore we skip the first zero run and code
            // the zero explicitly as a weight value instead
            if (!p->only_zeros || i>0) {
                while( i<size && inbuf[i]==0) {
                    zcnt++;
                    i++;
                }
            }
            zrun_values[n_weights] = zcnt;
        }
        if (i==size)
            break;
        int value = p->inv_lut[inbuf[i]+256];
        weight_values[n_weights] = value;
        n_weights++;
        i++;
    }

    // Search for good GRC parameters for the weight stream
    int n_w_slice, w_bitcnt;
    uint8_t *w_slice_cfg;
    int *w_slice_pos;
    w_slice_cfg = malloc( size );
    w_slice_pos = malloc( size*sizeof(int) );
    n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt);
    if (n_weights==0)
        n_w_slice = 0;

    // Search for good GRC parameters for the zrun stream
    int n_z_slice=0, z_bitcnt=0;
    uint8_t *z_slice_cfg=0;
    int *z_slice_pos=0;
    if (p->use_zero_runs) {
        z_slice_cfg = malloc( size );
        z_slice_pos = malloc( size*sizeof(int) );
        n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt);
    }

    // Encode bitstream slices
    int pos=0, i_w_slice=0, i_z_slice=0, new_palette=1;
    while(pos<n_weights || new_palette) {
        int endpos=pos+32767; // max slice length

        if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]<endpos) {
            endpos = w_slice_pos[i_w_slice];
        }

        if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]<endpos) {
            endpos = z_slice_pos[i_z_slice];
        }

        if (n_weights < endpos) {
            endpos = n_weights;
        }

        // The first slice (when new_palette is 1) encodes zero runs both at the
        // beginning and end (i.e. the number of zero runs is len+1).
        // The following slices only encode zero runs at the end (there cannot be
        // any zeros in the beginning since they are encoded by the previous slice)
        int len = endpos - pos;
        int *zrun_buf = p->use_zero_runs ? zrun_values+pos+(!new_palette) : 0;
        bitpos = encode_slice( weight_values+pos, zrun_buf, len,
                               p, new_palette, uncompressed_bits,
                               w_slice_cfg[i_w_slice], p->use_zero_runs ? z_slice_cfg[i_z_slice] : 0,
                               bitbuf, bitbuf_size, bitpos, verbose );
        new_palette = 0;

        if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]==endpos) {
            i_w_slice++;
        }
        if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]==endpos) {
            i_z_slice++;
        }
        pos = endpos;
    }

    // Free temporary buffers
    free(w_slice_cfg);
    free(w_slice_pos);
    if (p->use_zero_runs) {
        free(z_slice_cfg);
        free(z_slice_pos);
    }
    free(weight_values);
    free(zrun_values);

    return bitpos;
}

// Encode the given weight stream
// inbuf       uncompressed 9bit signed weights
// inbuf_size  number of weights
// outbuf      compressed bitstream, buffer is malloced within this function
// verbose     if non-zero, printf log
// Return value is the size in bytes of the compressed output
// Return -1 if error
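//
// Example usage (a sketch; *outbuf must be NULL on entry and is freed with
// mlw_free_outbuf()):
//     uint8_t *buf = NULL;
//     int n_bytes = mlw_encode( weights, n_weights, &buf, 0 );
//     if (n_bytes >= 0) { /* use n_bytes bytes at buf */ }
//     mlw_free_outbuf( buf );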
int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
    int i;
#ifndef NDEBUG
    // Range check
    for(i=0; i<inbuf_size; i++) {
        if (inbuf[i]<-255 || inbuf[i]>255) {
            printf("ERROR: weight out of range at index %d, weight value is %d (valid range is -255..255)\n", i, inbuf[i]);
            return -1;
        }
    }
#endif

    int bitbuf_size = inbuf_size*2+1024;
    assert(*outbuf == NULL);
    *outbuf = malloc( bitbuf_size );

    // Analyse input data to find palette re-programming points
    int n_restarts;
    int *palette_restart_pos;
    n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos);

    // Compress each section (using a single palette) separately
    int bitpos=0;
    for(i=0; i<n_restarts; i++) {
        palette_t palette;
        int pos, size;
        pos = palette_restart_pos[i];
        size = (i<n_restarts-1 ? palette_restart_pos[i+1] : inbuf_size) - pos;
        find_palette( inbuf+pos, size, &palette);
        create_inverse_palette( &palette);
        bitpos = encode_section( inbuf+pos, size, &palette,
                                 *outbuf, bitbuf_size, bitpos, verbose );
    }

    // Add end-of-stream marker and align to 128 bits
    {
        bitbuf_t bitbuf_s, *bb=&bitbuf_s;
        bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 );
        bb->pos = bitpos;
        bitbuf_put( bb, "ZDIV", 3, ZDIV_EOS);
        bitbuf_put( bb, "BYTEALIGN", (8-(bb->pos&7))&7, 0xff );

        // Pad with 0xff until 128-bit aligned
        while( bb->pos & 127 ) {
            bitbuf_put( bb, "PAD", 8, 0xff );
        }
        bitpos = bb->pos;
    }
    assert((bitpos&127)==0);
    int outbuf_size = bitpos/8;
    *outbuf = realloc( *outbuf, outbuf_size);

    free(palette_restart_pos);

    return outbuf_size;
}

void mlw_free_outbuf( uint8_t *outbuf ) {
    if (outbuf)
        free(outbuf);
}

static int round_up_divide(int num, int den)
{
    return (num + den - 1) / den;
}

static int round_up(int num, int den)
{
    return round_up_divide(num, den) * den;
}

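// Computes an upper bound on the number of weights that reorder() below will
// produce, using the same loop bounds rounded up to whole ublocks/subkernels.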
static int get_weight_cnt(
    int ifm_ublock_depth,
    int ofm_ublock_depth,
    int ofm_depth,
    int kernel_height,
    int kernel_width,
    int ifm_depth,
    int ofm_block_depth,
    int is_depthwise,
    int is_partkernel,
    int ifm_bitdepth,
    int decomp_h,
    int decomp_w)
{
    int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
    int subkernel_elements = decomp_w * decomp_h;
    if (is_partkernel)
    {
        if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
        {
            subkernel_elements = round_up(subkernel_elements, 2);
        }
        else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
        {
            subkernel_elements = round_up(subkernel_elements, 4);
        }
    }
    else if (is_depthwise)
    {
        subkernel_elements = round_up(subkernel_elements, 4);
    }
    int clipped_ifm_block_depth = is_depthwise ? ifm_ublock_depth : ifm_block_depth;
    int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
    int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;

    int input_length = 1;
    input_length *= is_depthwise ? 1 : ifm_ublock_depth;
    input_length *= ofm_ublock_depth;
    input_length *= round_up_divide(ifm_block_depth_inner, ifm_ublock_depth);
    input_length *= subkernel_elements;
    input_length *= round_up_divide(ofm_block_depth, ofm_ublock_depth);
    input_length *= round_up_divide(ifm_block_depth_outer, ifm_ublock_depth);
    input_length *= round_up_divide(kernel_width, decomp_w);
    input_length *= round_up_divide(kernel_height, decomp_h);
    input_length *= round_up_divide(is_depthwise ? 1 : ifm_depth, ifm_block_depth);
    input_length *= round_up_divide(ofm_depth, ofm_block_depth);

    return input_length;
}

struct brick_buf_s
{
    uint8_t* buf;
    int* strides;
};
typedef struct brick_buf_s brick_buf_t;

static int16_t get_brick_weight(brick_buf_t* buf, int ofm_z, int wy, int wx, int ifm_z)
{
    uint8_t* p = buf->buf;

    p += ofm_z * buf->strides[0];
    p += wy * buf->strides[1];
    p += wx * buf->strides[2];
    p += ifm_z * buf->strides[3];

    return *(int16_t*)p;
}

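// Flatten a brick-format weight tensor (addressed by OFM depth, kernel y,
// kernel x and IFM depth through byte strides) into the linear order the
// hardware consumes. Positions where the block loops overrun the tensor are
// skipped, leaving the caller's zero-initialised buffer as padding.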
static int reorder(
    int ifm_ublock_depth,
    int ofm_ublock_depth,
    int ofm_depth,
    int kernel_height,
    int kernel_width,
    int ifm_depth,
    int* strides,
    void* inbuf,
    int ofm_block_depth,
    int is_depthwise,
    int is_partkernel,
    int ifm_bitdepth,
    int decomp_h,
    int decomp_w,
    int16_t* weights)
{
    brick_buf_t brick_buf;
    brick_buf.buf = inbuf;
    brick_buf.strides = strides;

    int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
    int weight_cnt = 0;
    for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth)
    {
        int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z);
        // IFM blocks required for the brick
        for (int ifm_block_z = 0; ifm_block_z < (is_depthwise ? 1 : ifm_depth); ifm_block_z += ifm_block_depth)
        {
            int clipped_ifm_block_depth;
            if (is_depthwise)
            {
                clipped_ifm_block_depth = ifm_ublock_depth;
            }
            else
            {
                clipped_ifm_block_depth = is_partkernel ?
                    min(ifm_block_depth, ifm_depth - ifm_block_z) : ifm_block_depth;
            }
            // Weight decomposition
            // Subkernel splitting (H)
            for (int subkernel_y = 0; subkernel_y < kernel_height; subkernel_y += decomp_h)
            {
                int sub_height = min(kernel_height - subkernel_y, decomp_h);
                // Subkernel splitting (W)
                for (int subkernel_x = 0; subkernel_x < kernel_width; subkernel_x += decomp_w)
                {
                    int sub_width = min(kernel_width - subkernel_x, decomp_w);
                    int subkernel_elements = sub_width * sub_height;
                    // Part-kernel-first works across the kernel H/W and needs padding
                    if (is_partkernel)
                    {
                        if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
                        {
                            subkernel_elements = round_up(subkernel_elements, 2);
                        }
                        else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
                        {
                            subkernel_elements = round_up(subkernel_elements, 4);
                        }
                    }
                    else if (is_depthwise)
                    {
                        subkernel_elements = round_up(subkernel_elements, 4);
                    }
                    int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
                    int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;
                    for (int ifm_ublk_outer = 0; ifm_ublk_outer < ifm_block_depth_outer; ifm_ublk_outer += ifm_ublock_depth)
                    {
                        // OFM ublocks in OFM-block over depth
                        for (int ofm_ublk = 0; ofm_ublk < clipped_ofm_block_depth; ofm_ublk += ofm_ublock_depth)
                        {
                            // HW kernel element traversal - cannot be a H/W loop due to element
                            // padding requirement on depthwise/part-kernel configurations
                            for (int element = 0; element < subkernel_elements; element++)
                            {
                                int kx = element % sub_width;
                                int ky = element / sub_width;
                                // IFM ublocks in IFM-block over depth (only 1 ublock if depthwise)
                                // In case of part-kernel-first, IFM ublock traversal has already
                                // been handled and this loop is ignored.
                                for (int ifm_ublk_inner = 0; ifm_ublk_inner < ifm_block_depth_inner; ifm_ublk_inner += ifm_ublock_depth)
                                {
                                    // Feed OFM ublock elements
                                    for (int ofm_ublock_z = 0; ofm_ublock_z < ofm_ublock_depth; ofm_ublock_z++)
                                    {
                                        // Source IFM ublock elements (only 1 element deep if depthwise)
                                        for (int ifm_ublock_z = 0; ifm_ublock_z < (is_depthwise ? 1 : ifm_ublock_depth); ifm_ublock_z++)
                                        {
                                            // Source position within the current subkernel
                                            int wx = subkernel_x + kx;
                                            int wy = subkernel_y + ky;
                                            // Source IFM/OFM slices
                                            int ifm_ublk = ifm_ublk_inner + ifm_ublk_outer;
                                            int ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z;
                                            int ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z;
                                            if ((ifm_z < ifm_depth) && (ofm_z < ofm_depth) && (ky < sub_height))
                                            {
                                                weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z);
                                            }
                                            weight_cnt++;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    return weight_cnt;
}

// Reorder and encode the given weight stream
// Return value is the size in bytes of the compressed output
// Return -1 if error
int mlw_reorder_encode(
    int ifm_ublock_depth,
    int ofm_ublock_depth,
    int ofm_depth,
    int kernel_height,
    int kernel_width,
    int ifm_depth,
    int* brick_strides,
    void* inbuf,
    int ofm_block_depth,
    int is_depthwise,
    int is_partkernel,
    int ifm_bitdepth,
    int decomp_h,
    int decomp_w,
    uint8_t **outbuf, // *outbuf must be freed by caller
    int* padded_length,
    int verbose)
{
    /* Get an upper bound of the weight count */
    int input_length = get_weight_cnt(
        ifm_ublock_depth,
        ofm_ublock_depth,
        ofm_depth,
        kernel_height,
        kernel_width,
        ifm_depth,
        ofm_block_depth,
        is_depthwise,
        is_partkernel,
        ifm_bitdepth,
        decomp_h,
        decomp_w);

    int16_t* weights = (int16_t*)calloc(input_length, sizeof(int16_t));
    if (weights == NULL)
    {
        return -1; // allocation failure is an error, as documented above
    }

    /* Reorder weights and update input_length */
    input_length = reorder(
        ifm_ublock_depth,
        ofm_ublock_depth,
        ofm_depth,
        kernel_height,
        kernel_width,
        ifm_depth,
        brick_strides,
        inbuf,
        ofm_block_depth,
        is_depthwise,
        is_partkernel,
        ifm_bitdepth,
        decomp_h,
        decomp_w,
        weights);

    int output_length = mlw_encode(weights, input_length, outbuf, verbose);
    free(weights);
    *padded_length = input_length;

    return output_length;
}