Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (c) 2020 Arm Limited. All rights reserved. |
| 3 | * |
| 4 | * SPDX-License-Identifier: Apache-2.0 |
| 5 | * |
| 6 | * Licensed under the Apache License, Version 2.0 (the License); you may |
| 7 | * not use this file except in compliance with the License. |
| 8 | * You may obtain a copy of the License at |
| 9 | * |
| 10 | * www.apache.org/licenses/LICENSE-2.0 |
| 11 | * |
| 12 | * Unless required by applicable law or agreed to in writing, software |
| 13 | * distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 14 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | * See the License for the specific language governing permissions and |
| 16 | * limitations under the License. |
| 17 | */ |
| 18 | |
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <assert.h>
#include <math.h>
#include <limits.h>
#include <getopt.h>
#include <stdarg.h>
#include "mlw_encode.h"
#include "mlw_decode.h"
| 30 | |
/*
 * Report an unrecoverable error and terminate.
 *
 * format, ... - printf-style message, written verbatim to stderr (include a
 *               trailing newline in the format if one is wanted).
 *
 * Never returns: the process exits with status 1.
 */
static void fatal_error(const char *format, ...)
{
    va_list args;

    va_start(args, format);
    (void)vfprintf(stderr, format, args);
    va_end(args);

    exit(1);
}
| 38 | |
/*
 * Print command-line help to stdout.
 *
 * Every option accepted by the getopt() string in main() ("di:o:v:w?") is
 * listed here, so users can discover all of them from the help text.
 */
static void print_usage(void) {
    printf("Usage:\n");
    printf(" Encode: ./mlw_codec [<options>] [-o <outfile.mlw>] infiles.bin\n");
    printf(" Decode: ./mlw_codec [<options>] -d [-o <outfile.bin>] infiles.mlw\n");
    printf("\n");
    printf("Options:\n");
    printf(" -d           Decode the input files (the default is to encode).\n");
    printf(" -i <infile>  Add an input file. Input files may also be given as plain arguments.\n");
    printf(" -o <outfile> Write the results for all input files, one after another, to <outfile>.\n");
    printf(" -v <level>   Verbosity level passed to the encoder/decoder (default 0).\n");
    printf(" -w           The uncompressed weight file is an int16_t (word) stream.\n");
    printf("              This is to support 9bit signed weights. Little endian is assumed.\n");
    printf("              The default format is int8_t (byte) stream (if -w is not specified)\n");
    printf(" -?           Print this help and exit.\n");
    printf("\n");
}
| 50 | |
/*
 * Read an entire file into a newly allocated buffer.
 *
 * f   - stream opened for binary reading. It is rewound, read in full and
 *       then closed by this function. The caller must not use or fclose()
 *       it afterwards.
 * buf - receives a malloc()ed buffer that the caller must free(). It is
 *       non-NULL on return, even for an empty file.
 *
 * Returns the file length in bytes.
 *
 * On a seek/tell/read error, a file larger than INT_MAX bytes, or an
 * out-of-memory condition, prints an error to stderr and exits with
 * status 1. These are explicit checks, not assert()s, so they still apply
 * in NDEBUG builds.
 */
static int read_file(FILE *f, uint8_t **buf) {
    const char *err = NULL;
    long size;
    size_t nread;

    *buf = NULL;

    if (fseek(f, 0, SEEK_END) != 0) {
        err = "cannot seek in infile";
        goto fail;
    }
    /* ftell() returns a long (-1 on error); the size is returned as int. */
    size = ftell(f);
    if (size < 0 || size > INT_MAX) {
        err = "cannot determine infile size, or infile is too large";
        goto fail;
    }
    if (fseek(f, 0, SEEK_SET) != 0) {
        err = "cannot rewind infile";
        goto fail;
    }

    /* malloc(0) may legitimately return NULL; allocate at least one byte so
     * an empty file is not mistaken for an allocation failure. */
    *buf = malloc(size > 0 ? (size_t)size : 1);
    if (!*buf) {
        err = "out of memory";
        goto fail;
    }

    nread = fread(*buf, 1, (size_t)size, f);
    if (nread != (size_t)size) {
        err = "cannot read infile";
        goto fail;
    }

    fclose(f);
    return (int)size;

fail:
    fprintf(stderr, "ERROR: %s\n", err);
    exit(1);
}
| 64 | |
| 65 | |
#define MAX_INFILES 1000

/*
 * Convert a raw weight stream into int16 weights.
 *
 * With int16_format set, the stream is read as little-endian int16 words
 * and a trailing odd byte is ignored. Otherwise each byte is one int8
 * weight.
 *
 * Returns a newly malloc()ed array that the caller must free(). The number
 * of weights is stored in *count. Exits via fatal_error() on out-of-memory.
 */
static int16_t *bytes_to_weights(const uint8_t *bytes, int nbytes,
                                 int int16_format, int *count)
{
    int n = int16_format ? nbytes / 2 : nbytes;
    /* Allocate at least one element: malloc(0) may legitimately return NULL. */
    int16_t *weights = malloc((n > 0 ? (size_t)n : 1) * sizeof *weights);

    if (!weights)
        fatal_error("ERROR: out of memory\n");
    for (int i = 0; i < n; i++) {
        if (int16_format) {
            /* Assemble each word explicitly as little-endian, as the usage
             * text promises, instead of reinterpreting the buffer in host
             * byte order. */
            uint32_t u = (uint32_t)bytes[2 * i] | ((uint32_t)bytes[2 * i + 1] << 8);
            weights[i] = (int16_t)(u >= 0x8000u ? (int32_t)u - 0x10000 : (int32_t)u);
        } else {
            weights[i] = ((const int8_t *)bytes)[i];
        }
    }
    *count = n;
    return weights;
}

/*
 * Convert int16 weights into a raw byte stream.
 *
 * With int16_format set, each weight is written as a little-endian int16
 * word. Otherwise each weight is written as one byte, keeping only its low
 * 8 bits.
 *
 * Returns a newly malloc()ed buffer that the caller must free(). The stream
 * length is stored in *nbytes. Exits via fatal_error() on out-of-memory or
 * if the length would not fit in an int.
 */
static uint8_t *weights_to_bytes(const int16_t *weights, int n,
                                 int int16_format, int *nbytes)
{
    if (int16_format && n > INT_MAX / 2)
        fatal_error("ERROR: decoded stream is too large\n");

    int size = int16_format ? n * 2 : n;
    uint8_t *bytes = malloc(size > 0 ? (size_t)size : 1);

    if (!bytes)
        fatal_error("ERROR: out of memory\n");
    for (int i = 0; i < n; i++) {
        /* Conversion to an unsigned type is modular, hence well defined. */
        uint16_t u = (uint16_t)weights[i];
        if (int16_format) {
            bytes[2 * i]     = (uint8_t)(u & 0xffu);
            bytes[2 * i + 1] = (uint8_t)(u >> 8);
        } else {
            bytes[i] = (uint8_t)u;
        }
    }
    *nbytes = size;
    return bytes;
}

/*
 * Command-line entry point.
 *
 * Encodes (the default) or decodes (-d) each input file in turn and prints
 * size statistics for each to stdout. When -o is given, each file's result
 * is written, one after another, to that single output file.
 *
 * Returns 0 on success. Errors go through fatal_error(), which exits with
 * status 1.
 */
int main(int argc, char *argv[])
{
    int c, decode = 0;
    char *infile_name[MAX_INFILES], *outfile_name = NULL;
    FILE *outfile = NULL;
    int verbose = 0, infile_idx = 0;
    int int16_format = 0;

    if (argc == 1) {
        print_usage();
        exit(1);
    }

    /*
     * Parse the command line. getopt() is restarted from an outer loop so
     * that options may also follow positional input file names, even with a
     * getopt() that stops at the first non-option argument.
     */
    while (optind < argc) {
        while ((c = getopt(argc, argv, "di:o:v:w?")) != -1) {
            switch (c) {
            case 'd':
                decode = 1;
                break;
            case 'i':
                /* Explicit check, not assert(): must hold in NDEBUG builds. */
                if (infile_idx >= MAX_INFILES)
                    fatal_error("ERROR: too many input files (max %d)\n", MAX_INFILES);
                infile_name[infile_idx++] = optarg;
                break;
            case 'o':
                outfile_name = optarg;
                break;
            case 'v': {
                /* strtol() instead of atoi() so malformed levels are rejected. */
                char *end;
                long level = strtol(optarg, &end, 10);
                if (end == optarg || *end != '\0' || level < INT_MIN || level > INT_MAX)
                    fatal_error("ERROR: invalid verbosity level '%s'\n", optarg);
                verbose = (int)level;
                break;
            }
            case 'w':
                int16_format = 1;
                break;
            case '?':
                /* Reached for an explicit -? and for unrecognised options. */
                print_usage();
                exit(0);
            default:
                break;
            }
        }

        if (optind < argc) {
            if (infile_idx >= MAX_INFILES)
                fatal_error("ERROR: too many input files (max %d)\n", MAX_INFILES);
            infile_name[infile_idx++] = argv[optind];
            optind++;
        }
    }

    int nbr_of_infiles = infile_idx;
    if (nbr_of_infiles == 0)
        fatal_error("ERROR: no input files given\n");

    if (outfile_name) {
        outfile = fopen(outfile_name, "wb");
        if (!outfile)
            fatal_error("ERROR: cannot open outfile %s\n", outfile_name);
    }

    // Loop over input files
    for (infile_idx = 0; infile_idx < nbr_of_infiles; infile_idx++) {
        const char *name = infile_name[infile_idx];
        FILE *infile = fopen(name, "rb");
        if (!infile)
            fatal_error("ERROR: cannot open infile %s\n", name);

        /* read_file() also closes infile. */
        uint8_t *inbuf = NULL;
        int inbuf_size = read_file(infile, &inbuf);

        /* Fresh per file, so a pointer freed on the previous iteration is
         * never handed back to the codec. */
        uint8_t *outbuf = NULL;
        int outbuf_size = 0;

        if (!decode) {
            // Encode
            int n;
            if (int16_format && (inbuf_size & 1))
                fprintf(stderr, "WARNING: %s has an odd number of bytes; the last byte is ignored\n", name);
            int16_t *weights = bytes_to_weights(inbuf, inbuf_size, int16_format, &n);
            outbuf_size = mlw_encode(weights, n, &outbuf, verbose);
            free(weights);
            /* Report 0 bpw for an empty input rather than dividing by zero. */
            printf("Input size %d output size %d bpw %4.2f\n", n, outbuf_size,
                   n > 0 ? outbuf_size * 8.0 / n : 0.0);
        } else {
            // Decode
            int16_t *weights = NULL;
            int n = mlw_decode(inbuf, inbuf_size, &weights, verbose);
            /* A negative weight count can never be valid; treat it as a
             * decode failure. NOTE(review): confirm mlw_decode's error
             * convention. */
            if (n < 0)
                fatal_error("ERROR: cannot decode infile %s\n", name);
            outbuf = weights_to_bytes(weights, n, int16_format, &outbuf_size);
            free(weights);
            printf("Input size %d output size %d bpw %4.2f\n", inbuf_size, n,
                   n > 0 ? inbuf_size * 8.0 / n : 0.0);
        }

        if (outfile && outbuf_size > 0 &&
            fwrite(outbuf, 1, (size_t)outbuf_size, outfile) != (size_t)outbuf_size)
            fatal_error("ERROR: cannot write outfile %s\n", outfile_name);

        free(inbuf);
        free(outbuf);
    }

    /* fclose() flushes buffered output, so its result must be checked too. */
    if (outfile && fclose(outfile) != 0)
        fatal_error("ERROR: cannot write outfile %s\n", outfile_name);

    return 0;
}