blob: 11cb8f99a2ed7b2e2531c496efd6bcb2cbcec22b [file] [log] [blame]
/*
 * SPDX-FileCopyrightText: Copyright 2020 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>
#include <getopt.h>
#include <limits.h>
#include <math.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "mlw_encode.h"
#include "mlw_decode.h"

/*
 * Report an unrecoverable error and terminate.
 *
 * format  printf-style format string, followed by its arguments.
 *
 * The formatted message goes to stderr, then the process exits with
 * status 1. This function never returns to its caller.
 */
static void fatal_error(const char *format, ...)
{
    va_list args;

    va_start(args, format);
    vfprintf(stderr, format, args);
    va_end(args);

    exit(1);
}

/*
 * Print command-line help to stdout.
 *
 * Lists every option parsed by main(): -d, -i, -o, -v, -w and -?.
 */
static void print_usage(void)
{
    printf("Usage:\n");
    printf(" Encode: ./mlw_codec [<options>] [-o <outfile.mlw>] infiles.bin\n");
    printf(" Decode: ./mlw_codec [<options>] -d [-o <outfile.bin>] infiles.mlw\n");
    printf("\n");
    printf("Options:\n");
    printf(" -d          Decode the input files (default: encode).\n");
    printf(" -i <file>   Add an input file; input files may also be given as plain arguments.\n");
    printf(" -o <file>   Write the results for all input files, one after another, to <file>.\n");
    printf("             Without -o only size statistics are printed.\n");
    printf(" -v <level>  Verbosity level passed on to the encoder/decoder.\n");
    printf(" -w          The uncompressed weight file is an int16_t (word) stream.\n");
    printf("             This is to support 9bit signed weights. Little endian is assumed.\n");
    printf("             The default format is int8_t (byte) stream (if -w is not specified)\n");
    printf(" -?          Print this help and exit.\n");
    printf("\n");
}

/*
 * Read an entire binary file into a newly allocated buffer.
 *
 * f    Open stream. It is repositioned to the start before reading and is
 *      closed by this function before it returns.
 * buf  Set to a malloc()ed buffer holding the file contents. The caller
 *      owns it and must free() it. It is non-NULL even for an empty file.
 *
 * Returns the file length in bytes.
 *
 * The function prints an error to stderr and exits with status 1 when:
 *   - a seek, tell or read fails,
 *   - the file is too large for an int, or
 *   - allocation fails.
 * These checks are explicit rather than assert()s, so they are not compiled
 * out in NDEBUG builds.
 */
static int read_file(FILE *f, uint8_t **buf)
{
    const char *error = NULL;
    long size = -1;

    *buf = NULL;

    // ftell() reports failure as -1; both fseek() calls can fail as well.
    if (fseek(f, 0, SEEK_END) != 0 || (size = ftell(f)) < 0 || fseek(f, 0, SEEK_SET) != 0) {
        error = "cannot determine size of infile";
        goto fail;
    }
    // The length is returned as an int, so it must fit.
    if (size > INT_MAX) {
        error = "infile too large";
        goto fail;
    }

    // malloc(0) may return NULL. Always request at least one byte so that an
    // empty file cannot be mistaken for an out-of-memory condition.
    *buf = malloc(size > 0 ? (size_t)size : 1);
    if (*buf == NULL) {
        error = "out of memory reading infile";
        goto fail;
    }
    if (fread(*buf, 1, (size_t)size, f) != (size_t)size) {
        error = "cannot read infile";
        goto fail;
    }

    fclose(f);
    return (int)size;

fail:
    fprintf(stderr, "ERROR: %s\n", error);
    free(*buf);
    *buf = NULL;
    fclose(f);
    exit(1);
}

/* Maximum number of input files accepted on one command line. */
#define MAX_INFILES 1000

/*
 * mlw_codec command-line entry point.
 *
 * Each input file is processed in turn:
 *   - It is encoded with mlw_encode() by default, or decoded with
 *     mlw_decode() when -d is given.
 *   - Its size and bits-per-weight statistics are printed.
 *   - When -o is given, its result is written to that single output file,
 *     after the results of earlier input files.
 *
 * Uncompressed weights are int8_t bytes, or int16_t words stored little
 * endian when -w is given. The word byte order is handled explicitly, so
 * the result does not depend on the host's endianness.
 *
 * Returns 0 on success. Errors detected in this function are reported
 * through fatal_error().
 * Running with no arguments prints usage and exits with status 1.
 * -? or an unknown option prints usage and exits with status 0.
 */
int main(int argc, char *argv[])
{
    int c, decode = 0, inbuf_size, outbuf_size;
    char *infile_name[MAX_INFILES], *outfile_name = NULL;
    uint8_t *inbuf = NULL, *outbuf = NULL;
    FILE *infile, *outfile = NULL;
    int verbose = 0, infile_idx = 0;
    int int16_format = 0;

    if (argc == 1) {
        print_usage();
        exit(1);
    }

    // Parse the command line. Each time getopt() returns -1 with arguments
    // left, the next argument is taken as an input file and option parsing
    // resumes after it.
    while (optind < argc) {
        while ((c = getopt(argc, argv, "di:o:v:w?")) != -1) {
            switch (c) {
            case 'd':
                decode = 1;
                break;
            case 'i':
                // Explicit bound check (not assert) so that infile_name[]
                // cannot overflow in NDEBUG builds.
                if (infile_idx >= MAX_INFILES)
                    fatal_error("ERROR: too many infiles (max %d)\n", MAX_INFILES);
                infile_name[infile_idx++] = optarg;
                break;
            case 'o':
                outfile_name = optarg;
                break;
            case 'v':
                verbose = atoi(optarg);
                break;
            case 'w':
                int16_format = 1;
                break;
            case '?':
                print_usage();
                exit(0);
            }
        }

        if (optind < argc) {
            if (infile_idx >= MAX_INFILES)
                fatal_error("ERROR: too many infiles (max %d)\n", MAX_INFILES);
            infile_name[infile_idx++] = argv[optind];
            optind++;
        }
    }

    if (outfile_name) {
        outfile = fopen(outfile_name, "wb");
        if (!outfile)
            fatal_error("ERROR: cannot open outfile %s\n", outfile_name);
    }

    // Loop over input files.
    int nbr_of_infiles = infile_idx;
    for (infile_idx = 0; infile_idx < nbr_of_infiles; infile_idx++) {
        infile = fopen(infile_name[infile_idx], "rb");
        if (!infile)
            fatal_error("ERROR: cannot open infile %s\n", infile_name[infile_idx]);

        // Read the whole infile into inbuf; read_file() also closes infile.
        inbuf_size = read_file(infile, &inbuf);

        if (!decode) {
            // Encode: widen every input weight to int16_t for mlw_encode().
            int i, n = int16_format ? inbuf_size / (int)sizeof(int16_t) : inbuf_size;
            // Allocate at least one element because malloc(0) may return NULL.
            int16_t *weights = malloc((n > 0 ? (size_t)n : 1) * sizeof(int16_t));
            if (!weights)
                fatal_error("ERROR: out of memory\n");
            for (i = 0; i < n; i++) {
                if (int16_format) {
                    // Assemble a little-endian word independent of host byte
                    // order. A trailing odd byte is ignored.
                    uint16_t u = (uint16_t)(inbuf[2 * i] | (inbuf[2 * i + 1] << 8));
                    weights[i] = (int16_t)u;
                } else {
                    weights[i] = (int8_t)inbuf[i];
                }
            }
            outbuf_size = mlw_encode(weights, n, &outbuf, verbose);
            free(weights);
            printf("Input size %d output size %d bpw %4.2f\n", n, outbuf_size, outbuf_size*8.0/n);
        } else {
            // Decode: narrow every decoded int16_t weight to the output format.
            int i, n;
            int16_t *weights;
            n = mlw_decode(inbuf, inbuf_size, &weights, verbose);
            // A negative weight count cannot be valid. It would make
            // outbuf_size negative and the fwrite() size below huge.
            if (n < 0)
                fatal_error("ERROR: cannot decode infile %s\n", infile_name[infile_idx]);
            outbuf_size = int16_format ? n * (int)sizeof(int16_t) : n;
            outbuf = malloc(outbuf_size > 0 ? (size_t)outbuf_size : 1);
            if (!outbuf)
                fatal_error("ERROR: out of memory\n");
            for (i = 0; i < n; i++) {
                if (int16_format) {
                    // Store as a little-endian word regardless of host order.
                    uint16_t u = (uint16_t)weights[i];
                    outbuf[2 * i] = (uint8_t)(u & 0xff);
                    outbuf[2 * i + 1] = (uint8_t)(u >> 8);
                } else {
                    outbuf[i] = (uint8_t)weights[i];
                }
            }
            free(weights);
            printf("Input size %d output size %d bpw %4.2f\n", inbuf_size, n, inbuf_size*8.0/n);
        }

        // Append this file's result; all input files share the one outfile.
        if (outfile && outbuf_size > 0) {
            if (fwrite(outbuf, 1, (size_t)outbuf_size, outfile) != (size_t)outbuf_size)
                fatal_error("ERROR: cannot write outfile %s\n", outfile_name);
        }

        free(inbuf);
        inbuf = NULL;
        free(outbuf);
        outbuf = NULL;
    }

    // fclose() flushes buffered output, so a failure here is a write error.
    if (outfile && fclose(outfile) != 0)
        fatal_error("ERROR: cannot write outfile %s\n", outfile_name);

    return 0;
}