Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (c) 2020 Arm Limited. All rights reserved. |
| 3 | * |
| 4 | * SPDX-License-Identifier: Apache-2.0 |
| 5 | * |
| 6 | * Licensed under the Apache License, Version 2.0 (the License); you may |
| 7 | * not use this file except in compliance with the License. |
| 8 | * You may obtain a copy of the License at |
| 9 | * |
| 10 | * www.apache.org/licenses/LICENSE-2.0 |
| 11 | * |
| 12 | * Unless required by applicable law or agreed to in writing, software |
| 13 | * distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 14 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | * See the License for the specific language governing permissions and |
| 16 | * limitations under the License. |
| 17 | */ |
| 18 | |
| 19 | #include <stdio.h> |
| 20 | #include <stdlib.h> |
| 21 | #include <stdint.h> |
| 22 | #include <stdbool.h> |
| 23 | #include <string.h> |
| 24 | #include <assert.h> |
| 25 | #include <math.h> |
| 26 | #include <stdarg.h> |
| 27 | #include <math.h> |
| 28 | #include "mlw_common.h" |
| 29 | #include "mlw_encode.h" |
| 30 | |
// Debug trace macro: compiled out by default; switch to the printf
// variant below to log palette/restart decisions during encoding.
#define DPRINTF(...)
//#define DPRINTF(...) printf(__VA_ARGS__)

// Zero runs are enabled when zero is more than ZERO_RUN_THRES times as
// frequent as the second most common symbol (see find_palette).
#define ZERO_RUN_THRES  4

// Guarded so that platform headers providing min/max do not clash.
#ifndef min
#define min(a,b) ((a)<(b)?(a):(b))
#endif
#ifndef max
#define max(a,b) ((a)>(b)?(a):(b))
#endif
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 42 | |
// Per-section palette state. A section of the weight stream is encoded
// with one palette; inv_lut maps raw weights to encoded indices.
typedef struct palette {
    int16_t lut[32];      // palette entries, stored as (magnitude<<1)|sign
    int16_t inv_lut[512]; // maps weight+256 (weights -256..255) to encoded weight index
    int palsize;          // number of palette entries
    int palbits;          // bit width of palette entries
    int use_zero_runs;    // zeros are coded separately
    int only_palette;     // no values outside the palette
    int direct_offset;    // added to the decoded weight index before direct conversion to sign/mag
    int only_zeros;       // special case that the section is all zeros
} palette_t;
| 53 | |
// Non-zero when x has at most one bit set.
// Note: 0 is also reported as a power of two ((0-1)&0 == 0); the palette
// search relies on this when probing size transitions.
static int is_power_of_two( int x ) {
    return (x & (x - 1)) == 0;
}
| 57 | |
// Number of bits needed to index a palette with 'size' entries,
// i.e. ceil(log2(size)), clamped to the range 0..8.
static int get_palette_index_bits( int size ) {
    int bits = 8;
    while (bits > 0 && size <= (1 << (bits - 1)))
        bits--;
    return bits;
}
| 65 | |
// Search the stream for suitable palette restart positions.
// Heuristically splits the weight stream into sections that each use their
// own palette: a restart is inserted either when extending the current
// palette would push it over an index-bit boundary, or when a freshly
// started (smaller) palette is estimated to save bits.
//
//  buf                        weight stream, values in -256..255
//  size                       number of weights in buf
//  palette_restart_positions  out: malloced array of section start indices
//                             (entry 0 is always 0); ownership passes to caller
// Return the number of restarts (= number of sections).
static int search_palette_sections( int16_t *buf, int size, int **palette_restart_positions ) {
    int i,j,got_palette,restart_i,palette_size=0, last_restart_idx, zero_cnt;
    int prev_idx[512]; // For each value, keep track of the index of the previous occurrence
    int *restart_pos;
    int max_palettes = (size+63)/64;

    // Preliminary allocation of sufficient size
    // NOTE(review): malloc/realloc results are not checked here; on OOM this
    // dereferences NULL — consider propagating failure to the caller.
    restart_pos = (int*)malloc( max_palettes*sizeof(int) );
    last_restart_idx=0;
    got_palette=0;
    restart_i=1;
    restart_pos[0] = 0;
    zero_cnt=0;
    memset( prev_idx, -1, sizeof(prev_idx));
    for(i=0; i<size; i++) {
        // Guess if zeros should be excluded from the palette
        // (i.e. zero-run coding is likely: >25% zeros since the last restart)
        int exclude_zero = zero_cnt > (i-last_restart_idx)/4;

        if (got_palette) {
            // Check if the next value is not covered by the current palette
            if ( prev_idx[ buf[i]+256 ] < last_restart_idx ) {
                // New value: increase the palette size
                palette_size++;
                DPRINTF("Note: at pos %d extend palette to size %d\n", i, palette_size);
                if ( is_power_of_two(palette_size-1-exclude_zero) ) {
                    if ( (i - last_restart_idx - zero_cnt) > 512 || (palette_size-exclude_zero)>32 ) {
                        // create a new palette because we extend a long lasting palette to require one more index bit
                        DPRINTF("Note: at pos %d create new palette because previous has to increase one more index bit. last_restart_idx %d n %d zero_cnt %d\n", i, last_restart_idx, i - last_restart_idx, zero_cnt );
                        if (restart_i == max_palettes) {
                            // Grow the restart array geometrically
                            max_palettes = max_palettes*2;
                            restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                        }
                        DPRINTF("restart %d pos %d\n", restart_i, i);
                        restart_pos[restart_i++] = i;
                        last_restart_idx = i;
                        got_palette=0;
                        zero_cnt=0;
                    }
                }
            }
        }

        prev_idx[ buf[i]+256 ] = i;
        if (buf[i]==0)
            zero_cnt++;

        // Candidate window lengths (in non-zero values); the second column
        // is currently unused.
        static const int window_sizes[5][2] = {{32,1}, {64,1}, {128,1}, {256,1}, {512,1}};
        int k;
        // loop over window sizes
        for(k=0; k<5; k++) {
            // Every Nth non-zero value, count what would be the size of a palette covering the last N NZ.
            int N = window_sizes[k][0] * (got_palette?2:1);
            if ( (i - last_restart_idx - zero_cnt) > 0 && ((i - last_restart_idx - zero_cnt) % N)==0 ) {
                // Search backward to the position N nonzero values earlier
                int nzcnt=0;
                for( j=i; j>last_restart_idx; j--) {
                    if ( buf[j]!=0 ) {
                        if (nzcnt==N+1)
                            break;
                        nzcnt++;
                    }
                }
                int restart_idx = j;

                // Calculate the size of a new palette (starting at restart_idx):
                // count distinct values seen since restart_idx
                int new_palette_size=0;
                for(j=0; j<512; j++) {
                    if ( prev_idx[j] >= restart_idx ) {
                        new_palette_size++;
                    }
                }

                int create_new_palette=0;
                if (got_palette) {
                    int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero );
                    int old_size_bits = get_palette_index_bits( palette_size - exclude_zero );
                    // Estimated bit savings: narrower indices for N values,
                    // minus the cost of re-sending the palette (~8 bits/entry + header)
                    int savings = N*(old_size_bits*15-new_size_bits*15)/16 - new_palette_size*8 - 20;
                    if ( savings>0 ) {
                        // Create new palette because it can be smaller than the existing palette
                        create_new_palette=1;
                        DPRINTF("Note: at pos %d restart smaller palette\n", restart_idx);
                    }
                } else {
                    if ( (new_palette_size-exclude_zero) <= 32) {
                        int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero );
                        // estimate if we will make savings by using palette mode
                        int savings = N*(90-new_size_bits*15)/16 - new_palette_size*8 - 20;
                        create_new_palette = savings>0;
                    }
                }
                if (create_new_palette) {
                    palette_size=new_palette_size;
                    got_palette=1;
                    last_restart_idx = restart_idx;
                    DPRINTF("Note: at pos %d create palette of size %d\n", last_restart_idx, new_palette_size);
                    if ( restart_pos[restart_i-1] != last_restart_idx) {
                        if (restart_i == max_palettes) {
                            max_palettes = max_palettes*2;
                            restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                        }
                        restart_pos[restart_i++] = last_restart_idx;
                    }
                    // Recompute the zero count relative to the new section start
                    zero_cnt=0;
                    for( j=last_restart_idx; j<=i; j++)
                        if (buf[j]==0)
                            zero_cnt++;
                }
            }
        }
    }
    // Reallocate to actual size
    *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) );
    return restart_i;
}
| 182 | |
// Build a 512-bin histogram of the weight stream.
// Bin index is weight+256, so weights -256..255 map to bins 0..511.
static void calc_freq( const int16_t *buf, int size, int freq[512] ) {
    memset(freq, 0, 512 * sizeof(freq[0]));
    for (int n = 0; n < size; n++)
        freq[buf[n] + 256]++;
}
| 191 | |
// qsort comparator: sorts uint64_t values in DESCENDING order.
static int cmp_uint64(const void * a, const void * b) {
    const uint64_t va = *(const uint64_t *)a;
    const uint64_t vb = *(const uint64_t *)b;
    if (va > vb)
        return -1;
    if (va < vb)
        return 1;
    return 0;
}
| 197 | |
// Create palette from the given frequencies.
// Freq index 0-511 correspond to weights -256..255 (only -255..255 are
// actually read; mlw_encode range-checks the input to that range).
// Fills in p->lut, palsize, palbits, direct_offset, only_palette,
// only_zeros and use_zero_runs.
static void create_palette( int freq[512],
                            int use_zero_runs,
                            palette_t *p ) {
    uint64_t freq64[512];
    int i,all_cnt,all_max_val;

    // Pair the frequency with the value so that
    // the array can be sorted on frequency while keeping
    // track of the corresponding palette value
    memset(freq64, 0, sizeof(freq64));
    all_cnt=0;
    all_max_val=0;
    for(i=-255; i<256; i++) {
        // When zero runs are used, zero is coded out-of-band and must not
        // occupy a palette slot
        if (i==0 && use_zero_runs)
            continue;
        int sign = i<0;
        int mag = abs(i);
        int palval = (mag<<1) | sign;

        // Store palette value in 16 LSB bits, which will not affect the sorting
        freq64[palval] = (((uint64_t)freq[i+256])<<16) | palval;
        all_cnt+=freq[i+256];

        if (freq[i+256]>0) {
            all_max_val = max(all_max_val, palval);
        }
    }

    // Count number of non-used weight values around zero (0, -1, +1, -2, +2 etc)
    // freq64 is still indexed by palval here (sorted below), so index i walks
    // values in increasing magnitude order
    for(i=0; i<31; i++) {
        if ((freq64[i]>>16)!=0)
            break;
    }
    p->direct_offset = i;

    // Sort in descending frequency order
    qsort(freq64, 512, sizeof(uint64_t), cmp_uint64);

    // Identify special case that there are no weights to code
    // in the weight index stream (i.e. all weights are zeros)
    p->only_zeros = (freq64[0]>>16)==0;
    if (p->only_zeros) {
        p->direct_offset=0;
    }

    // Check if all weights fit into the palette (and the palette is not empty):
    // fewer than 33 distinct non-zero-frequency values
    p->only_palette = (freq64[0]>>16)>0 && (freq64[32]>>16)==0;

    int max_palette_size;
    if (p->only_palette) {
        max_palette_size = 32;
    } else {
        // For direct-lut we must make sure that the encoded weight
        // index is not > 511. We do that by limiting the palette size
        // such that the greatest value can be reached after subtracting
        // the palette size.
        max_palette_size = min(32, 511-all_max_val);
        if (max_palette_size==1) {
            max_palette_size=0; // because palette of size 1 is not supported
        }
    }

    // Setup the 32 entry palette: most frequent values first
    int palette_max_val = 0, val, cnt, pal_cnt=0;
    for(i=0; i<max_palette_size; i++) {
        cnt = (int)(freq64[i]>>16);
        val = freq64[i]&0xffff;
        if ( cnt==0 )
            break;
        p->lut[i] = val;
        palette_max_val = max(palette_max_val, val);
        pal_cnt+=cnt;
    }
    if (i==1)
        i++; // palette size of 1 is not supported, make it 2

    // Heuristic for when to use the palette. If more than half of the
    // weights are in the palette then we use it. This ensures we don't
    // use palette for e.g. rectangular distributions.
    int palbits_val;
    if (pal_cnt > all_cnt/2) {
        p->palsize = i;
        palbits_val = palette_max_val;
    } else {
        // No palette
        p->palsize = 0;
        // If no palette, then palbits is used to specify the
        // number of bits required for uncompressed mode, i.e.
        // the number of bits for the greatest weight value
        palbits_val = all_max_val;
    }

    // the palette entry bit width
    // minimum 2bits (because PALBITS is in range 2..9)
    int palbits=2;
    while( (1<<palbits) <= palbits_val )
        palbits++;
    assert(palbits<=9);
    p->palbits = palbits;
    p->use_zero_runs = use_zero_runs;
}
| 301 | |
| 302 | // Return 1 if zero runs should be used |
| 303 | // If palette_size is 512, then palette is not used (in that case the palette is setup |
| 304 | // with the standard alternating unsigned to signed mapping) |
| 305 | static int find_palette( const int16_t *inbuf, int inbuf_size, palette_t *p) { |
| 306 | int freq[512], i; |
| 307 | |
| 308 | // Calculate frequencies of the given weight stream |
| 309 | calc_freq( inbuf, inbuf_size, freq); |
| 310 | |
| 311 | // Find two most common values |
| 312 | int most_common_freq[2]={0}, most_common_val[2]={0}; |
| 313 | for(i=0; i<512; i++) { |
| 314 | if ( freq[i] > most_common_freq[0] ) { |
| 315 | most_common_freq[1] = most_common_freq[0]; |
| 316 | most_common_val[1] = most_common_val[0]; |
| 317 | most_common_freq[0] = freq[i]; |
| 318 | most_common_val[0] = i-256; |
| 319 | } else if ( freq[i] > most_common_freq[1] ) { |
| 320 | most_common_freq[1] = freq[i]; |
| 321 | most_common_val[1] = i-256; |
| 322 | } |
| 323 | } |
| 324 | |
| 325 | // Decide if zero-runs (alternating mode) should be used: |
| 326 | // * zero should be the most common symbol |
| 327 | // * zero should be sufficiently more common than the second most common symbol |
| 328 | int use_zero_runs = most_common_val[0]==0 && most_common_freq[0] > ZERO_RUN_THRES*most_common_freq[1]; |
| 329 | |
| 330 | // Create the palette |
| 331 | create_palette( freq, use_zero_runs, p); |
| 332 | |
| 333 | return use_zero_runs; |
| 334 | } |
| 335 | |
| 336 | static void create_inverse_palette( palette_t *p) { |
| 337 | int i; |
| 338 | memset( p->inv_lut, 0, sizeof(p->inv_lut)); |
| 339 | for(i=0; i<512; i++) { |
| 340 | int val = i; |
| 341 | int sign = val&1; |
| 342 | int mag = val>>1; |
| 343 | int weight = sign ? -mag : mag; |
| 344 | if (weight+256 < 512) |
| 345 | p->inv_lut[ weight+256 ] = i + p->palsize - p->direct_offset; |
| 346 | } |
| 347 | for(i=0; i<p->palsize; i++) { |
| 348 | int val = p->lut[i]; |
| 349 | int sign = val&1; |
| 350 | int mag = val>>1; |
| 351 | int weight = sign ? -mag : mag; |
| 352 | if (weight+256 < 512) |
| 353 | p->inv_lut[ weight+256 ] = i; |
| 354 | } |
| 355 | } |
| 356 | |
// Number of candidate GRC configurations for weights and zero runs
#define NWCFG 13
#define NZCFG 4 // restrict search to ZDIV=0..3
#define MAX_ZWCFG (max(NWCFG,NZCFG))

// search state (one per configuration per stream position)
typedef struct search_state {
    int bitcnt;       // number of bits to reach this state
    uint8_t prev_cfg; // previous grc parameter config
} search_state_t;

// (trunc<<4) | div, 0x20 means uncompressed
// Weight configs: div 0..5 plain, div 0..5 truncated, plus uncompressed
static const char w_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20 };
// Zero-run configs: div 0..4 (search limited to 0..3 via NZCFG)
static const char z_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04 };
| 370 | |


// An algorithm similar to the Viterbi algorithm is used to search for a
// good GRC parameter sequence for the given input value sequence.
// The inval buffer can contain weights, weight indices or runs.
// The return value is the resulting number of bitstream sections.
//
//  inval_buf                 values to code (weights/indices or run lengths)
//  n_inval                   number of values
//  zrun_mode                 non-zero when coding zero runs (uses z configs)
//  uncompressed_bits         bit width used by the "uncompressed" config
//  grc_param_cfg/pos         out: chosen config and end position per section
//  max_grc_param_cfg         capacity of the out arrays (asserted only)
//  existing_grc_param_pos    positions where the weight stream already
//                            switches config (transitions there are free)
//  bitcnt                    out: total bit count of the best sequence
static int search_grc_params( const int *inval_buf,
                              int n_inval,
                              int zrun_mode,
                              int uncompressed_bits,
                              uint8_t *grc_param_cfg,
                              int *grc_param_pos,
                              int max_grc_param_cfg,
                              int *existing_grc_param_pos,
                              int n_existing_grc_param_pos,
                              int *bitcnt )
{
    int n_cfg = zrun_mode ? NZCFG : NWCFG;
    const char *grc_params = zrun_mode ? z_grc_params : w_grc_params;
    int i,j;

    // One DP lane per candidate config; state[j][i] is the cheapest way to
    // reach position i while currently using config j.
    // NOTE(review): malloc result is unchecked; OOM would crash here.
    search_state_t *state[MAX_ZWCFG];
    for(i=0; i<n_cfg; i++) {
        state[i] = malloc( sizeof(search_state_t) * (n_inval+1) );
        state[i][0].bitcnt=0;
        state[i][0].prev_cfg=i;
    }

    // Loop over inval_buf
    int existing_idx=0;
    for(i=0; i<n_inval; i++) {
        int value = inval_buf[i];

        // Best GRC parameter so far
        int best_bitcnt=0x7fffffff, best_cfg=0;
        for(j=0; j<n_cfg; j++) {
            if (state[j][i].bitcnt < best_bitcnt) {
                best_bitcnt = state[j][i].bitcnt;
                best_cfg = j;
            }
        }

        // Cost (in bits) of switching config: a new slice header
        int cmd_cost = 40;
        if (existing_idx < n_existing_grc_param_pos && existing_grc_param_pos[existing_idx] == (i+1)) {
            // free transition, because the weight stream already inserted a command at this position
            cmd_cost = 0;
            existing_idx++;
        }

        // Loop over GRC parameters, calculate bits to code value, and then update the search state
        for(j=0; j<n_cfg; j++) {
            int div = grc_params[j]&15;
            int trunc = grc_params[j]>>4;
            int q = value>>div;
            // Rice code: unary quotient (capped at 2 when truncated) + div remainder bits
            int bits = trunc ? min(q+1,2) + div : q+1+div;
            if (!zrun_mode && ((trunc && q>2) || q>31))
                bits=10000; // it's not possible to code the current value; give it a high cost
            if (trunc==2)
                bits=uncompressed_bits;

            if ( best_bitcnt + cmd_cost < state[j][i].bitcnt ) {
                // Change GRC parameters
                state[j][i+1].prev_cfg = best_cfg;
                state[j][i+1].bitcnt = best_bitcnt + cmd_cost + bits;
            } else {
                // Keep same GRC parameters
                state[j][i+1].prev_cfg = j;
                state[j][i+1].bitcnt = state[j][i].bitcnt + bits;
            }
        }
    }


    // Best GRC parameter at the end of the stream
    int best_bitcnt=0x7fffffff, best_cfg=0;
    for(j=0; j<n_cfg; j++) {
        if (state[j][n_inval].bitcnt < best_bitcnt) {
            best_bitcnt = state[j][n_inval].bitcnt;
            best_cfg = j;
        }
    }

    // First backward pass: count the number of config switches (commands)
    int cfg = best_cfg;
    int n_cmds=0;
    for(i=n_inval; i>=0; i--) {
        if (state[cfg][i].prev_cfg != cfg || i==0) {
            n_cmds++;
            cfg = state[cfg][i].prev_cfg;
        }
    }

    (void)(max_grc_param_cfg);
    assert(n_cmds<=max_grc_param_cfg);

    // Second backward pass: fill in cfg/pos arrays in forward order
    cfg = best_cfg;
    j=n_cmds-1;
    int endpos=n_inval;
    for(i=n_inval; i>=0; i--) {
        if (state[cfg][i].prev_cfg != cfg || i==0) {
            grc_param_cfg[j] = cfg;
            grc_param_pos[j] = endpos;
            j--;
            cfg = state[cfg][i].prev_cfg;
            endpos = i-1;
        }
    }
    assert(j==-1);

    for(i=0; i<n_cfg; i++) {
        free(state[i]);
    }

    *bitcnt = best_bitcnt;

    return n_cmds;
}
| 487 | |
| 488 | |
| 489 | /////////////////////////////// Write to bitstream |
| 490 | |
// Bit writer state; bits are written LSB-first within each byte.
typedef struct bitbuf {
    uint8_t *buf;    // output byte buffer
    int buf_size;    // in bytes
    int pos;         // bit pos of next bit
    int log_symbols; // if non-zero, printf each symbol as it is written
} bitbuf_t;
| 497 | |
| 498 | // size in byte |
| 499 | static void bitbuf_init( bitbuf_t *bb, uint8_t *buf, int size, int log_symbols ) { |
| 500 | bb->buf = buf; |
| 501 | bb->pos = 0; |
| 502 | bb->buf_size = size; |
| 503 | bb->log_symbols = log_symbols; |
| 504 | } |
| 505 | |
| 506 | static void bitbuf_putbit( bitbuf_t *bb, int bit) { |
| 507 | int byte_pos = bb->pos>>3; |
| 508 | int bit_pos = bb->pos&7; |
| 509 | assert( byte_pos >= 0 ); |
| 510 | assert( byte_pos < bb->buf_size ); |
| 511 | bb->buf[ byte_pos ] = (bb->buf[ byte_pos ] & ~(1<<bit_pos)) | (bit<<bit_pos); |
| 512 | bb->pos += 1; |
| 513 | } |
| 514 | |
| 515 | static void bitbuf_put( bitbuf_t *bb, const char *name, int len, int data) { |
| 516 | int i; |
| 517 | if (len>0) { |
| 518 | if (bb->log_symbols) |
| 519 | printf("bitbuf: pos %3d %7s len %d data %x\n", bb->pos, name, len, data); |
| 520 | for(i=0; i<len; i++) { |
| 521 | bitbuf_putbit(bb, (data>>i)&1); |
| 522 | } |
| 523 | } |
| 524 | } |
| 525 | |
// Encode one slice into the bitstream: a slice header (optionally with a
// new palette) followed by interleaved weight and zero-run chunks.
// Weights are Golomb-Rice coded with divisor derived from w_cfg (or stored
// uncompressed); zero run lengths are GRC coded per z_cfg.
// Return new bitpos.
static int encode_slice( const int *w_value,
                         const int *z_value,
                         int nvalues,
                         palette_t *p,
                         int new_palette,
                         int uncompressed_bits,
                         int w_cfg,
                         int z_cfg,
                         uint8_t *bitbuf,
                         int bitbuf_size,
                         int bitpos,
                         int verbose )
{
    int i,j;
    bitbuf_t bitbuf_s, *bb=&bitbuf_s;
    bitbuf_init( bb, bitbuf, bitbuf_size, verbose&2?1:0 );
    bb->pos = bitpos;

    // SLICELEN is a 15-bit field, so a slice holds at most 32768 values
    assert(nvalues<32768);
    // GRC parameters for this slice
    int w_grc_div = w_grc_params[w_cfg] & 15;
    int w_grc_trunc = (w_grc_params[w_cfg] >> 4)==1;
    int w_uncompressed = (w_grc_params[w_cfg] >> 4)==2;
    int z_grc_div = z_grc_params[z_cfg] & 15;

    if (w_uncompressed) {
        w_grc_div = uncompressed_bits;
    }

    int zdiv = p->use_zero_runs ? z_grc_div : ZDIV_DISABLE;
    int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED;

    if (verbose&1) {
        printf("slice: bitoffset %7d slicelen %5d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %2d\n",
               bb->pos, nvalues, zdiv, wdiv, w_grc_trunc, new_palette, p->palbits, p->palsize);
    }

    // Write slice header
    bitbuf_put( bb, "ZDIV", 3, zdiv);
    bitbuf_put( bb, "SLICELEN", 15, nvalues-1 );
    bitbuf_put( bb, "WDIV", 3, wdiv);
    bitbuf_put( bb, "WTRUNC", 1, w_grc_trunc );
    bitbuf_put( bb, "NEWPAL", 1, new_palette );
    if (new_palette) {
        bitbuf_put( bb, "DIROFS", 5, p->direct_offset );
        bitbuf_put( bb, "PALSIZE", 5, max(0, p->palsize-1));
        bitbuf_put( bb, "PALBITS", 3, p->palbits-2 );
        for(i=0; i<p->palsize; i++) {
            bitbuf_put( bb, "PALETTE", p->palbits, p->lut[i] );
        }
    }

    // The first slice also codes the zero run preceding the first weight,
    // hence one extra zero-run value
    int z_nvalues = nvalues + (new_palette?1:0);
    int w_pos=0, z_pos=0;
    int w_unary0=0, w_unary1=0, w_unary1_len=0, w_q=-1, w_r=0;
    int z_unary=0, z_q=-1, z_r=0;
    // Remainders of the current chunk are buffered and written one
    // iteration later (after the next chunk's unary parts)
    int w_nsymbols=0, w_remain[12]={0};
    int w_prev_enable=0, w_prev_nsymbols=0, w_prev_remain[12]={0};
    int z_nsymbols=0, z_remain[12]={0};
    int z_prev_enable=0, z_prev_nsymbols=0, z_prev_remain[12]={0};
    int z_unary_len = z_grc_div<3 ? 12 : 8;
    do {
        // Interleave weight and zrun chunks: keep the weight stream at most
        // 7 symbols ahead of the zrun stream
        int balance = p->use_zero_runs ? w_pos - z_pos : 0;
        int w_enable = balance<8 && w_pos<nvalues;
        int z_enable = balance>=0 && p->use_zero_runs && z_pos<z_nvalues;
        if (w_enable) {
            // Encode chunk (weights)
            j=0;
            w_nsymbols=0;
            w_unary0=0;
            w_unary1=0;
            w_unary1_len=0;
            int max_symbols = w_uncompressed && w_grc_div>5 ? 8 : 12;
            while(j<max_symbols) {
                if (w_q<0) {
                    if (w_pos<nvalues) {
                        // Split next value into GRC quotient and remainder
                        int value = w_value[w_pos];
                        assert(value<512);
                        w_q = value>>w_grc_div;
                        w_r = value&((1<<w_grc_div)-1);
                        assert( w_q<=31 && (!w_grc_trunc || w_q<=2));
                    } else {
                        w_q = 0;
                        w_r = -1; // don't send remainder
                    }
                }
                // Emit the quotient as interleaved unary: w_unary0 carries
                // "q>0" flags, w_unary1 carries "q>1" flags
                while( w_q>=0 && j<max_symbols) {
                    w_unary0 |= w_q>0 ? (1<<j) : 0;
                    if (w_q>0) {
                        w_unary1 |= w_q>1 ? (1<<w_unary1_len) : 0;
                        w_unary1_len++;
                    }
                    j++;
                    w_q-=2;
                    if (w_grc_trunc)
                        w_q--;
                }
                if (w_q<0 && w_r>=0) {
                    // Value fully emitted; queue its remainder
                    w_remain[w_nsymbols] = w_r;
                    w_nsymbols++;
                    w_pos++;
                }
            }
        }

        if (z_enable) {
            // Encode chunk (zrun)
            j=0;
            z_nsymbols=0;
            z_unary=0;
            while(j<z_unary_len) {
                if (z_q<0) {
                    if (z_pos<z_nvalues) {
                        int value = z_value[z_pos];
                        z_q = value>>z_grc_div;
                        z_r = value&((1<<z_grc_div)-1);
                    } else {
                        z_q = 0;
                        z_r = -1;
                    }
                }
                while( z_q>=0 && j<z_unary_len) {
                    z_unary |= z_q>0 ? (1<<j) : 0;
                    j++;
                    z_q--;
                }
                if (z_q<0 && z_r>=0) {
                    z_remain[z_nsymbols] = z_r;
                    z_nsymbols++;
                    z_pos++;
                }
            }
        }

        // Write chunk to bitstream: unary parts of this chunk, then the
        // remainders buffered from the previous chunk
        if (w_enable && !w_uncompressed) {
            bitbuf_put( bb, "WUNARY0", 12, w_unary0);
        }
        if (z_enable) {
            bitbuf_put( bb, "ZUNARY", z_unary_len, z_unary);
        }
        if (w_enable && !w_uncompressed) {
            bitbuf_put( bb, "WUNARY1", w_unary1_len, w_unary1);
        }
        if (w_prev_enable) {
            for(i=0; i<w_prev_nsymbols; i++) {
                bitbuf_put( bb, "WREMAIN", w_grc_div, w_prev_remain[i]);
            }
        }
        if (z_prev_enable) {
            for(i=0; i<z_prev_nsymbols; i++) {
                bitbuf_put( bb, "ZREMAIN", z_grc_div, z_prev_remain[i]);
            }
        }
        w_prev_enable = w_enable;
        w_prev_nsymbols = w_nsymbols;
        memcpy( w_prev_remain, w_remain, sizeof(w_prev_remain));
        z_prev_enable = z_enable;
        z_prev_nsymbols = z_nsymbols;
        memcpy( z_prev_remain, z_remain, sizeof(z_prev_remain));
    } while( w_prev_enable || z_prev_enable );

    return bb->pos;
}
| 691 | |
| 692 | |
// Encode one palette section of the weight stream: extract weight indices
// and zero runs, search GRC parameters for both streams, then emit slices.
// return new bitpos
static int encode_section( const int16_t *inbuf,
                           int size,
                           palette_t *p,
                           uint8_t *bitbuf,
                           int bitbuf_size,
                           int bitpos,
                           int verbose )
{
    int uncompressed_bits;

    // Uncompressed mode can only be used if either all weights
    // are in the palette OR if the palette is not used.
    if (p->only_palette) {
        // Uncompressed bits derived from palette size
        uncompressed_bits=0;
        while( (1<<uncompressed_bits) < p->palsize )
            uncompressed_bits++;
    } else if (p->palsize==0) {
        // Uncompressed bits is palbits (which is the bitdepth of the greatest weight)
        uncompressed_bits = p->palbits;
    } else {
        // Don't use uncompressed (prohibitive cost in the GRC search)
        uncompressed_bits = 100;
    }

    // NOTE(review): malloc results below are unchecked; OOM would crash.
    int *weight_values = malloc( size*sizeof(int) );
    int *zrun_values = malloc( size*sizeof(int) );

    // Get weights (or weight indices) AND zero-runs from the input weight stream.
    int i=0, n_weights = 0, zcnt;
    while(1) {
        if (p->use_zero_runs) {
            zcnt=0;
            // Count zero run
            // Special case: if all weights in the section are zero, we must
            // still ensure we have one coded weight so the slice length
            // doesn't become 0. Therefore we skip the first zero run and code
            // the zero explicitly as a weight value instead
            if (!p->only_zeros || i>0) {
                while( i<size && inbuf[i]==0) {
                    zcnt++;
                    i++;
                }
            }
            zrun_values[n_weights] = zcnt;
        }
        if (i==size)
            break;
        // Map the raw weight to its encoded index via the inverse palette
        int value = p->inv_lut[inbuf[i]+256];
        weight_values[n_weights] = value;
        n_weights++;
        i++;
    }

    // Search for good GRC parameters for the weight stream
    int n_w_slice, w_bitcnt;
    uint8_t *w_slice_cfg;
    int *w_slice_pos;
    w_slice_cfg = malloc( size );
    w_slice_pos = malloc( size*sizeof(int) );
    n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt);
    if (n_weights==0)
        n_w_slice = 0;

    // Search for good GRC parameters for the zrun stream
    // (transitions at weight-slice boundaries are free)
    int n_z_slice=0, z_bitcnt=0;
    uint8_t *z_slice_cfg=0;
    int *z_slice_pos=0;
    if (p->use_zero_runs) {
        z_slice_cfg = malloc( size );
        z_slice_pos = malloc( size*sizeof(int) );
        n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt);
    }

    // Encode bitstream slices; each slice ends at the nearest weight/zrun
    // config switch, the max slice length, or the end of the section
    int pos=0, i_w_slice=0, i_z_slice=0, new_palette=1;
    while(pos<n_weights || new_palette) {
        int endpos=pos+32767; // max slice length

        if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]<endpos) {
            endpos = w_slice_pos[i_w_slice];
        }

        if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]<endpos) {
            endpos = z_slice_pos[i_z_slice];
        }

        if (n_weights < endpos) {
            endpos = n_weights;
        }

        // The first slice (when new_palette is 1) encodes zero runs both at the
        // beginning and end (i.e. number of zero runs are len+1).
        // The following slices only encode zero runs at the end (there cannot be
        // any zeros in the beginning since they are encoded by the previous slice)
        int len = endpos - pos;
        int *zrun_buf = p->use_zero_runs ? zrun_values+pos+(!new_palette) : 0;
        bitpos = encode_slice( weight_values+pos, zrun_buf, len,
                               p, new_palette, uncompressed_bits,
                               w_slice_cfg[i_w_slice], p->use_zero_runs ? z_slice_cfg[i_z_slice] : 0,
                               bitbuf, bitbuf_size, bitpos, verbose );
        new_palette = 0;

        if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]==endpos) {
            i_w_slice++;
        }
        if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]==endpos) {
            i_z_slice++;
        }
        pos = endpos;
    }

    // Free temporary buffers
    free(w_slice_cfg);
    free(w_slice_pos);
    if (p->use_zero_runs) {
        free(z_slice_cfg);
        free(z_slice_pos);
    }
    free(weight_values);
    free(zrun_values);

    return bitpos;
}
| 818 | |
| 819 | // Encode the given weight stream |
| 820 | // inbuf uncompressed 9bit signed weights |
| 821 | // inbuf_size number of weights |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 822 | // outbuf compressed bitstream, buffer is malloced within this function |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 823 | // verbose if non-zero, printf log |
| 824 | // Return value is the size in bytes of the compressed output |
| 825 | // Return -1 if error |
| 826 | int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { |
| 827 | int i; |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 828 | #ifndef NDEBUG |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 829 | // Range check |
| 830 | for(i=0; i<inbuf_size; i++) { |
| 831 | if (inbuf[i]<-255 || inbuf[i]>255) { |
| 832 | printf("ERROR: weight out of range at index %d, weight value is %d (valid range is -255..255)\n", i, inbuf[i]); |
| 833 | return -1; |
| 834 | } |
| 835 | } |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 836 | #endif |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 837 | |
| 838 | int bitbuf_size = inbuf_size*2+1024; |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 839 | assert(*outbuf == NULL); |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 840 | *outbuf = malloc( bitbuf_size ); |
| 841 | |
| 842 | // Analyse input data to find palette re-programming points |
| 843 | int n_restarts; |
| 844 | int *palette_restart_pos; |
| 845 | n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos); |
| 846 | |
| 847 | // Compress each section (using a single palette) separately |
| 848 | int bitpos=0; |
| 849 | for(i=0; i<n_restarts; i++) { |
| 850 | palette_t palette; |
| 851 | int pos, size; |
| 852 | pos = palette_restart_pos[i]; |
| 853 | size = (i<n_restarts-1 ? palette_restart_pos[i+1] : inbuf_size) - pos; |
| 854 | find_palette( inbuf+pos, size, &palette); |
| 855 | create_inverse_palette( &palette); |
| 856 | bitpos = encode_section( inbuf+pos, size, &palette, |
| 857 | *outbuf, bitbuf_size, bitpos, verbose ); |
| 858 | } |
| 859 | |
| 860 | |
| 861 | // Add end of stream marker and align to 128bit |
| 862 | { |
| 863 | bitbuf_t bitbuf_s, *bb=&bitbuf_s; |
| 864 | bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 ); |
| 865 | bb->pos = bitpos; |
| 866 | bitbuf_put( bb, "ZDIV", 3, ZDIV_EOS); |
| 867 | bitbuf_put( bb, "BYTEALIGN", (8-(bb->pos&7))&7, 0xff ); |
| 868 | |
| 869 | // Pad with 0xff until 64bit aligned |
| 870 | while( bb->pos & 127 ) { |
| 871 | bitbuf_put( bb, "PAD", 8, 0xff ); |
| 872 | } |
| 873 | bitpos = bb->pos; |
| 874 | } |
| 875 | assert((bitpos&127)==0); |
| 876 | int outbuf_size = bitpos/8; |
| 877 | *outbuf = realloc( *outbuf, outbuf_size); |
| 878 | |
| 879 | free(palette_restart_pos); |
| 880 | |
| 881 | return outbuf_size; |
| 882 | } |
| 883 | |
// Free a compressed-output buffer allocated by mlw_encode()/mlw_reorder_encode().
// free(NULL) is a defined no-op in C, so no NULL guard is needed.
void mlw_free_outbuf( uint8_t *outbuf ) {
    free(outbuf);
}
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 888 | |
// Integer division of num by den with the quotient rounded towards
// positive infinity (bias-then-truncate form, as the original).
static int round_up_divide(int num, int den)
{
    const int biased = num + den - 1;
    return biased / den;
}
| 893 | |
// Round num up to the nearest multiple of den.
// Inlined bias-then-truncate division, equivalent to
// round_up_divide(num, den) * den.
static int round_up(int num, int den)
{
    return ((num + den - 1) / den) * den;
}
| 898 | |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 899 | struct brick_buf_s |
| 900 | { |
| 901 | uint8_t* buf; |
| 902 | int* strides; |
| 903 | }; |
| 904 | typedef struct brick_buf_s brick_buf_t; |
| 905 | |
| 906 | static int16_t get_brick_weight(brick_buf_t* buf, int ofm_z, int wy, int wx, int ifm_z) |
| 907 | { |
| 908 | uint8_t* p = buf->buf; |
| 909 | |
| 910 | p += ofm_z * buf->strides[0]; |
| 911 | p += wy * buf->strides[1]; |
| 912 | p += wx * buf->strides[2]; |
| 913 | p += ifm_z * buf->strides[3]; |
| 914 | |
| 915 | return *(int16_t*)p; |
| 916 | } |
| 917 | |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 918 | static void reorder_free(int16_t* buf) |
| 919 | { |
| 920 | if (buf) |
| 921 | { |
| 922 | free(buf); |
| 923 | } |
| 924 | } |
| 925 | |
| 926 | static int16_t* reorder( |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 927 | int ifm_ublock_depth, |
| 928 | int ofm_ublock_depth, |
| 929 | int ofm_depth, |
| 930 | int kernel_height, |
| 931 | int kernel_width, |
| 932 | int ifm_depth, |
| 933 | int* strides, |
| 934 | void* inbuf, |
| 935 | int ofm_block_depth, |
| 936 | int is_depthwise, |
| 937 | int is_partkernel, |
| 938 | int ifm_bitdepth, |
| 939 | int decomp_h, |
| 940 | int decomp_w, |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 941 | int64_t* padded_length) |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 942 | { |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 943 | /* Size unknown. Start with one page at least */ |
| 944 | *padded_length = round_up(max(1, sizeof(int16_t)* |
| 945 | ofm_depth* |
| 946 | kernel_height* |
| 947 | kernel_width* |
| 948 | ifm_depth), |
| 949 | 4*1024) / sizeof(int16_t); |
| 950 | int16_t* weights = (int16_t*)malloc(*padded_length * sizeof(int16_t)); |
| 951 | |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 952 | brick_buf_t brick_buf; |
| 953 | brick_buf.buf = inbuf; |
| 954 | brick_buf.strides = strides; |
| 955 | |
| 956 | int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32; |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 957 | int64_t weight_cnt = 0; |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 958 | for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth) |
| 959 | { |
| 960 | int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z); |
| 961 | // IFM blocks required for the brick |
| 962 | for (int ifm_block_z = 0; ifm_block_z < (is_depthwise ? 1 : ifm_depth); ifm_block_z += ifm_block_depth) |
| 963 | { |
| 964 | int clipped_ifm_block_depth; |
| 965 | if (is_depthwise) |
| 966 | { |
| 967 | clipped_ifm_block_depth = ifm_ublock_depth; |
| 968 | } |
| 969 | else |
| 970 | { |
| 971 | clipped_ifm_block_depth = is_partkernel ? |
| 972 | min(ifm_block_depth, ifm_depth - ifm_block_z) : ifm_block_depth; |
| 973 | } |
| 974 | // Weight decomposition |
| 975 | // Subkernel Splitting (H) |
| 976 | for (int subkernel_y = 0; subkernel_y < kernel_height; subkernel_y += decomp_h) |
| 977 | { |
| 978 | int sub_height = min(kernel_height - subkernel_y, decomp_h); |
| 979 | // Subkernel splitting (W) |
| 980 | for (int subkernel_x = 0; subkernel_x < kernel_width; subkernel_x += decomp_w) |
| 981 | { |
| 982 | int sub_width = min(kernel_width - subkernel_x, decomp_w); |
| 983 | int subkernel_elements = sub_width * sub_height; |
| 984 | // Part kernel first works across the kernel H/W and needs padding |
| 985 | if (is_partkernel) |
| 986 | { |
| 987 | if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0) |
| 988 | { |
| 989 | subkernel_elements = round_up(subkernel_elements, 2); |
| 990 | } |
| 991 | else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0) |
| 992 | { |
| 993 | subkernel_elements = round_up(subkernel_elements, 4); |
| 994 | } |
| 995 | } |
| 996 | else if (is_depthwise) |
| 997 | { |
| 998 | subkernel_elements = round_up(subkernel_elements, 4); |
| 999 | } |
| 1000 | int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1; |
| 1001 | int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth; |
| 1002 | for (int ifm_ublk_outer = 0; ifm_ublk_outer < ifm_block_depth_outer; ifm_ublk_outer += ifm_ublock_depth) |
| 1003 | { |
| 1004 | // OFM Ublocks in OFM-block over depth |
| 1005 | for (int ofm_ublk = 0; ofm_ublk < clipped_ofm_block_depth; ofm_ublk += ofm_ublock_depth) |
| 1006 | { |
| 1007 | // HW Kernel element traversal - cannot be a H/W loop due to element |
| 1008 | // padding requirement on depthwise/part-kernel configurations |
| 1009 | for (int element = 0; element < subkernel_elements; element++) |
| 1010 | { |
| 1011 | int kx = element % sub_width; |
| 1012 | int ky = element / sub_width; |
| 1013 | // IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise) |
| 1014 | // In case of part-kernel-first IFM Ublock traversal have already been handled |
| 1015 | // and this loop is ignored. |
| 1016 | for (int ifm_ublk_inner = 0; ifm_ublk_inner < ifm_block_depth_inner; ifm_ublk_inner += ifm_ublock_depth) |
| 1017 | { |
| 1018 | // Feed OFM ublock elements |
| 1019 | for (int ofm_ublock_z = 0; ofm_ublock_z < ofm_ublock_depth; ofm_ublock_z++) |
| 1020 | { |
| 1021 | // Source IFM ublock elements (only 1 element deep if depthwise) |
| 1022 | for (int ifm_ublock_z = 0; ifm_ublock_z < (is_depthwise ? 1 : ifm_ublock_depth); ifm_ublock_z++) |
| 1023 | { |
| 1024 | // Source position within the current subkernel |
| 1025 | int wx = subkernel_x + kx; |
| 1026 | int wy = subkernel_y + ky; |
| 1027 | // Source IFM/OFM slices |
| 1028 | int ifm_ublk = ifm_ublk_inner + ifm_ublk_outer; |
| 1029 | int ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z; |
| 1030 | int ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z; |
| 1031 | if ((ifm_z < ifm_depth) && (ofm_z < ofm_depth) && (ky < sub_height)) |
| 1032 | { |
| 1033 | weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z); |
| 1034 | } |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 1035 | else |
| 1036 | { |
| 1037 | weights[weight_cnt] = 0; |
| 1038 | } |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 1039 | weight_cnt++; |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 1040 | if (weight_cnt == *padded_length) |
| 1041 | { |
| 1042 | // Reallocate by doubling the buffer size as needed |
| 1043 | *padded_length *= 2; |
| 1044 | weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t)); |
| 1045 | } |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 1046 | } |
| 1047 | } |
| 1048 | } |
| 1049 | } |
| 1050 | } |
| 1051 | } |
| 1052 | } |
| 1053 | } |
| 1054 | } |
| 1055 | } |
| 1056 | |
Fredrik Svedberg | 93d5c35 | 2021-05-11 13:51:47 +0200 | [diff] [blame^] | 1057 | *padded_length = weight_cnt; |
| 1058 | weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t)); |
| 1059 | return weights; |
Mauricio Briceno | 67e11f7 | 2021-05-05 12:47:28 +0200 | [diff] [blame] | 1060 | } |
| 1061 | |
// Reorder and encode the given weight stream
// Return value is the size in bytes of the compressed output
// Return -1 if error
int mlw_reorder_encode(
    int ifm_ublock_depth,
    int ofm_ublock_depth,
    int ofm_depth,
    int kernel_height,
    int kernel_width,
    int ifm_depth,
    int* brick_strides,
    void* inbuf,
    int ofm_block_depth,
    int is_depthwise,
    int is_partkernel,
    int ifm_bitdepth,
    int decomp_h,
    int decomp_w,
    uint8_t **outbuf, // *outbuf must be freed by caller
    int64_t* padded_length,
    int verbose)
{
    /* Step 1: rearrange the brick-format weights into the traversal
       order expected by the encoder. */
    int16_t* reordered = reorder(
        ifm_ublock_depth, ofm_ublock_depth, ofm_depth,
        kernel_height, kernel_width, ifm_depth,
        brick_strides, inbuf, ofm_block_depth,
        is_depthwise, is_partkernel, ifm_bitdepth,
        decomp_h, decomp_w, padded_length);

    /* Step 2: compress the linearised stream; skip when empty. */
    int ret = 0;
    if (*padded_length > 0)
    {
        ret = mlw_encode(reordered, *padded_length, outbuf, verbose);
    }
    reorder_free(reordered);

    return ret;
}