Fix mlw_module

Fixed size calculation in mlw_reorder_encode.
Fixed build warnings.

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Iac9408b9972a29b5a3403ba11f80dc4eaaa35453
diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c
index 62e8360..3a03091 100644
--- a/ethosu/mlw_codec/mlw_encode.c
+++ b/ethosu/mlw_codec/mlw_encode.c
@@ -896,56 +896,6 @@
     return round_up_divide(num, den) * den;
 }
 
-static int get_weight_cnt(
-    int ifm_ublock_depth,
-    int ofm_ublock_depth,
-    int ofm_depth,
-    int kernel_height,
-    int kernel_width,
-    int ifm_depth,
-    int ofm_block_depth,
-    int is_depthwise,
-    int is_partkernel,
-    int ifm_bitdepth,
-    int decomp_h,
-    int decomp_w)
-{
-    int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
-    int subkernel_elements = decomp_w * decomp_h;
-    if (is_partkernel)
-    {
-        if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
-        {
-            subkernel_elements = round_up(subkernel_elements, 2);
-        }
-        else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
-        {
-            subkernel_elements = round_up(subkernel_elements, 4);
-        }
-    }
-    else if (is_depthwise)
-    {
-        subkernel_elements = round_up(subkernel_elements, 4);
-    }
-    int clipped_ifm_block_depth = is_depthwise ? ifm_ublock_depth : ifm_block_depth;
-    int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
-    int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;
-
-    int input_length = 1;
-    input_length *= is_depthwise ? 1 : ifm_ublock_depth;
-    input_length *= ofm_ublock_depth;
-    input_length *= round_up_divide(ifm_block_depth_inner, ifm_ublock_depth);
-    input_length *= subkernel_elements;
-    input_length *= round_up_divide(ofm_block_depth, ofm_ublock_depth);
-    input_length *= round_up_divide(ifm_block_depth_outer, ifm_ublock_depth);
-    input_length *= round_up_divide(kernel_width, decomp_w);
-    input_length *= round_up_divide(kernel_height, decomp_h);
-    input_length *= round_up_divide(is_depthwise ? 1 : ifm_depth, ifm_block_depth);
-    input_length *= round_up_divide(ofm_depth, ofm_block_depth);
-
-    return input_length;
-}
-
 struct brick_buf_s
 {
     uint8_t* buf;
@@ -965,7 +915,15 @@
     return *(int16_t*)p;
 }
 
-static int reorder(
+static void reorder_free(int16_t* buf)
+{
+    if (buf)
+    {
+        free(buf);
+    }
+}
+
+static int16_t* reorder(
     int ifm_ublock_depth,
     int ofm_ublock_depth,
     int ofm_depth,
@@ -980,14 +938,23 @@
     int ifm_bitdepth,
     int decomp_h,
     int decomp_w,
-    int16_t* weights)
+    int64_t* padded_length)
 {
+    /* Size unknown. Start with one page at least */
+    *padded_length = round_up(max(1, sizeof(int16_t)*
+        ofm_depth*
+        kernel_height*
+        kernel_width*
+        ifm_depth),
+    4*1024) / sizeof(int16_t);
+    int16_t* weights = (int16_t*)malloc(*padded_length * sizeof(int16_t));
+
     brick_buf_t brick_buf;
     brick_buf.buf = inbuf;
     brick_buf.strides = strides;
 
     int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
-    int weight_cnt = 0;
+    int64_t weight_cnt = 0;
     for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth)
     {
         int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z);
@@ -1065,7 +1032,17 @@
                                             {
                                                 weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z);
                                             }
+                                            else
+                                            {
+                                                weights[weight_cnt] = 0;
+                                            }
                                             weight_cnt++;
+                                            if (weight_cnt == *padded_length)
+                                            {
+                                                // Reallocate by doubling the buffer size as needed
+                                                *padded_length *= 2;
+                                                weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t));
+                                            }
                                         }
                                     }
                                 }
@@ -1077,7 +1054,9 @@
         }
     }
 
-    return weight_cnt;
+    *padded_length = weight_cnt;
+    weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t));
+    return weights;
 }
 
 // Reorder and encode the given weight stream
@@ -1099,32 +1078,11 @@
     int decomp_h,
     int decomp_w,
     uint8_t **outbuf, // *outbuf must be freed by caller
-    int* padded_length,
+    int64_t* padded_length,
     int verbose)
 {
-    /* Get an upper bound of the weight count */
-    int input_length = get_weight_cnt(
-        ifm_ublock_depth,
-        ofm_ublock_depth,
-        ofm_depth,
-        kernel_height,
-        kernel_width,
-        ifm_depth,
-        ofm_block_depth,
-        is_depthwise,
-        is_partkernel,
-        ifm_bitdepth,
-        decomp_h,
-        decomp_w);
-
-    int16_t* weights = (int16_t*)calloc(input_length, sizeof(int16_t));
-    if (weights == NULL)
-    {
-        return 0;
-    }
-
-    /* Reorder weights and update input_length */
-    input_length = reorder(
+    /* Reorder weights */
+    int16_t* weights = reorder(
         ifm_ublock_depth,
         ofm_ublock_depth,
         ofm_depth,
@@ -1139,11 +1097,15 @@
         ifm_bitdepth,
         decomp_h,
         decomp_w,
-        weights);
+        padded_length);
 
-    int output_length = mlw_encode(weights, input_length, outbuf, verbose);
-    free(weights);
-    *padded_length = input_length;
+    /* Then encode */
+    int output_length = 0;
+    if (*padded_length > 0)
+    {
+        output_length = mlw_encode(weights, *padded_length, outbuf, verbose);
+    }
+    reorder_free(weights);
 
     return output_length;
 }