Fix mlw_module

Fixedx size calculation in mlw_reorder_encode.
Fixed build warnings.

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Iac9408b9972a29b5a3403ba11f80dc4eaaa35453
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c
index 1e13dd2..75ea8e9 100644
--- a/ethosu/mlw_codec/mlw_codecmodule.c
+++ b/ethosu/mlw_codec/mlw_codecmodule.c
@@ -40,7 +40,7 @@
  *  decomp_w,
  *  verbose=0)
  *
- * output: bytearray
+ * output: (bytearray, int)
  */
 
 static PyObject *
@@ -79,7 +79,7 @@
         &verbose))
         return NULL;
 
-    PyArrayObject* input_ndarray_object = PyArray_FROM_OTF(
+    PyArrayObject* input_ndarray_object = (PyArrayObject*)PyArray_FROM_OTF(
         input_object,
         NPY_INT64,
         NPY_ARRAY_ALIGNED);
@@ -111,7 +111,7 @@
         return NULL;
     }
     uint8_t* output_buffer = NULL;
-    int padded_length;
+    int64_t padded_length;
 
     int output_length = mlw_reorder_encode(
         ifm_ublock_depth,
@@ -132,11 +132,6 @@
         &padded_length,
         verbose);
 
-    if (output_buffer == NULL)
-    {
-        return PyErr_NoMemory();
-    }
-
     PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
     PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
 
@@ -144,6 +139,8 @@
     mlw_free_outbuf(output_buffer);
 
     PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
+
+    Py_DECREF(input_ndarray_object);
     Py_DECREF(output_byte_array);
     Py_DECREF(padded_length_obj);
     return ret;
diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c
index 62e8360..3a03091 100644
--- a/ethosu/mlw_codec/mlw_encode.c
+++ b/ethosu/mlw_codec/mlw_encode.c
@@ -896,56 +896,6 @@
     return round_up_divide(num, den) * den;
 }
 
-static int get_weight_cnt(
-    int ifm_ublock_depth,
-    int ofm_ublock_depth,
-    int ofm_depth,
-    int kernel_height,
-    int kernel_width,
-    int ifm_depth,
-    int ofm_block_depth,
-    int is_depthwise,
-    int is_partkernel,
-    int ifm_bitdepth,
-    int decomp_h,
-    int decomp_w)
-{
-    int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
-    int subkernel_elements = decomp_w * decomp_h;
-    if (is_partkernel)
-    {
-        if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
-        {
-            subkernel_elements = round_up(subkernel_elements, 2);
-        }
-        else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
-        {
-            subkernel_elements = round_up(subkernel_elements, 4);
-        }
-    }
-    else if (is_depthwise)
-    {
-        subkernel_elements = round_up(subkernel_elements, 4);
-    }
-    int clipped_ifm_block_depth = is_depthwise ? ifm_ublock_depth : ifm_block_depth;
-    int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
-    int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;
-
-    int input_length = 1;
-    input_length *= is_depthwise ? 1 : ifm_ublock_depth;
-    input_length *= ofm_ublock_depth;
-    input_length *= round_up_divide(ifm_block_depth_inner, ifm_ublock_depth);
-    input_length *= subkernel_elements;
-    input_length *= round_up_divide(ofm_block_depth, ofm_ublock_depth);
-    input_length *= round_up_divide(ifm_block_depth_outer, ifm_ublock_depth);
-    input_length *= round_up_divide(kernel_width, decomp_w);
-    input_length *= round_up_divide(kernel_height, decomp_h);
-    input_length *= round_up_divide(is_depthwise ? 1 : ifm_depth, ifm_block_depth);
-    input_length *= round_up_divide(ofm_depth, ofm_block_depth);
-
-    return input_length;
-}
-
 struct brick_buf_s
 {
     uint8_t* buf;
@@ -965,7 +915,15 @@
     return *(int16_t*)p;
 }
 
-static int reorder(
+static void reorder_free(int16_t* buf)
+{
+    if (buf)
+    {
+        free(buf);
+    }
+}
+
+static int16_t* reorder(
     int ifm_ublock_depth,
     int ofm_ublock_depth,
     int ofm_depth,
@@ -980,14 +938,23 @@
     int ifm_bitdepth,
     int decomp_h,
     int decomp_w,
-    int16_t* weights)
+    int64_t* padded_length)
 {
+    /* Size unknown. Start with one page at least */
+    *padded_length = round_up(max(1, sizeof(int16_t)*
+        ofm_depth*
+        kernel_height*
+        kernel_width*
+        ifm_depth),
+    4*1024) / sizeof(int16_t);
+    int16_t* weights = (int16_t*)malloc(*padded_length * sizeof(int16_t));
+
     brick_buf_t brick_buf;
     brick_buf.buf = inbuf;
     brick_buf.strides = strides;
 
     int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
-    int weight_cnt = 0;
+    int64_t weight_cnt = 0;
     for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth)
     {
         int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z);
@@ -1065,7 +1032,17 @@
                                             {
                                                 weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z);
                                             }
+                                            else
+                                            {
+                                                weights[weight_cnt] = 0;
+                                            }
                                             weight_cnt++;
+                                            if (weight_cnt == *padded_length)
+                                            {
+                                                // Reallocate by doubling the buffer size as needed
+                                                *padded_length *= 2;
+                                                weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t));
+                                            }
                                         }
                                     }
                                 }
@@ -1077,7 +1054,9 @@
         }
     }
 
-    return weight_cnt;
+    *padded_length = weight_cnt;
+    weights = (int16_t*)realloc(weights, *padded_length * sizeof(int16_t));
+    return weights;
 }
 
 // Reorder and encode the given weight stream
@@ -1099,32 +1078,11 @@
     int decomp_h,
     int decomp_w,
     uint8_t **outbuf, // *outbuf must be freed by caller
-    int* padded_length,
+    int64_t* padded_length,
     int verbose)
 {
-    /* Get an upper bound of the weight count */
-    int input_length = get_weight_cnt(
-        ifm_ublock_depth,
-        ofm_ublock_depth,
-        ofm_depth,
-        kernel_height,
-        kernel_width,
-        ifm_depth,
-        ofm_block_depth,
-        is_depthwise,
-        is_partkernel,
-        ifm_bitdepth,
-        decomp_h,
-        decomp_w);
-
-    int16_t* weights = (int16_t*)calloc(input_length, sizeof(int16_t));
-    if (weights == NULL)
-    {
-        return 0;
-    }
-
-    /* Reorder weights and update input_length */
-    input_length = reorder(
+    /* Reorder weights */
+    int16_t* weights = reorder(
         ifm_ublock_depth,
         ofm_ublock_depth,
         ofm_depth,
@@ -1139,11 +1097,15 @@
         ifm_bitdepth,
         decomp_h,
         decomp_w,
-        weights);
+        padded_length);
 
-    int output_length = mlw_encode(weights, input_length, outbuf, verbose);
-    free(weights);
-    *padded_length = input_length;
+    /* Then encode */
+    int output_length = 0;
+    if (*padded_length > 0)
+    {
+        output_length = mlw_encode(weights, *padded_length, outbuf, verbose);
+    }
+    reorder_free(weights);
 
     return output_length;
 }
diff --git a/ethosu/mlw_codec/mlw_encode.h b/ethosu/mlw_codec/mlw_encode.h
index a995ac6..743603b 100644
--- a/ethosu/mlw_codec/mlw_encode.h
+++ b/ethosu/mlw_codec/mlw_encode.h
@@ -38,6 +38,26 @@
 EXPORTED
 void mlw_free_outbuf(uint8_t *outbuf);
 
+EXPORTED
+int mlw_reorder_encode(
+    int ifm_ublock_depth,
+    int ofm_ublock_depth,
+    int ofm_depth,
+    int kernel_height,
+    int kernel_width,
+    int ifm_depth,
+    int* brick_strides,
+    void* inbuf,
+    int ofm_block_depth,
+    int is_depthwise,
+    int is_partkernel,
+    int ifm_bitdepth,
+    int decomp_h,
+    int decomp_w,
+    uint8_t **outbuf,
+    int64_t* padded_length,
+    int verbose);
+
 #if __cplusplus
 }
 #endif
diff --git a/setup.py b/setup.py
index d213743..6900a67 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
     "ethosu.mlw_codec",
     ["ethosu/mlw_codec/mlw_encode.c", "ethosu/mlw_codec/mlw_decode.c", "ethosu/mlw_codec/mlw_codecmodule.c"],
     include_dirs=[np.get_include()],
+    define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_9_API_VERSION")],
 )
 
 setup(