MLBEDSW-8568 Fix mlw_codec memory handling

Added missing memory allocation checks to mlw_codec.

Change-Id: I20c04d5d9c934b9c715a2b2049705f853d90825a
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c
index 8c540d6..1f172ee 100644
--- a/ethosu/mlw_codec/mlw_codecmodule.c
+++ b/ethosu/mlw_codec/mlw_codecmodule.c
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -84,6 +84,7 @@
         NPY_ARRAY_ALIGNED);
     if (input_ndarray_object == NULL)
     {
+        PyErr_SetString(PyExc_ValueError, "Invalid input array");
         return NULL;
     }
 
@@ -137,17 +138,23 @@
         &padded_length,
         verbose);
 
-    PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
-    PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
+    PyObject* ret = NULL;
+    if ( output_length < 0 ) {
+        ret = PyErr_NoMemory();
+    } else {
+        PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
+        PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
+        if ( output_byte_array && padded_length_obj ) {
+            ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
+        }
+        Py_XDECREF(output_byte_array);
+        Py_XDECREF(padded_length_obj);
+    }
 
     /* Discard the output buffer */
     mlw_free_outbuf(output_buffer);
 
-    PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
-
     Py_DECREF(input_ndarray_object);
-    Py_DECREF(output_byte_array);
-    Py_DECREF(padded_length_obj);
     return ret;
 }
 
@@ -216,7 +223,8 @@
 
   int output_length = mlw_encode(input_buffer, (int)input_length, &output_buffer, verbose);
 
-  PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length);
+  PyObject *output_byte_array = output_length < 0 ? PyErr_NoMemory() :
+    PyByteArray_FromStringAndSize ((char *) output_buffer, output_length);
 
   /* Discard the temporary input and output buffers.  */
   free (input_buffer);
diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c
index e8e1a8c..3ec2490 100644
--- a/ethosu/mlw_codec/mlw_encode.c
+++ b/ethosu/mlw_codec/mlw_encode.c
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -87,14 +87,14 @@
     // Preliminary allocation of sufficient size
     restart_pos = (int*)malloc( max_palettes*sizeof(int) );
     if (!restart_pos) {
-        return 0;
+        return -1;
     }
     last_restart_idx=0;
     got_palette=0;
     restart_i=1;
     restart_pos[0] = 0;
     zero_cnt=0;
-    memset( prev_idx, -1, sizeof(prev_idx));
+    memset(prev_idx, -1, sizeof(prev_idx));
     for(i=0; i<size; i++) {
         // Guess if zeros should be excluded from the palette
         int exclude_zero = zero_cnt > (i-last_restart_idx)/4;
@@ -113,7 +113,7 @@
                             max_palettes = max_palettes*2;
                             restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                             if (!restart_pos) {
-                                return 0;
+                                return -1;
                             }
                         }
                         DPRINTF("restart %d pos %d\n", restart_i, i);
@@ -184,7 +184,7 @@
                             max_palettes = max_palettes*2;
                             restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
                             if (!restart_pos) {
-                                return 0;
+                                return -1;
                             }
                         }
                         restart_pos[restart_i++] = last_restart_idx;
@@ -199,7 +199,7 @@
     }
     // Reallocate to actual size
     *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) );
-    return *palette_restart_positions ? restart_i : 0;
+    return *palette_restart_positions ? restart_i : -1;
 }
 
 // Calculate frequency table
@@ -417,11 +417,18 @@
 
     search_state_t *state[MAX_ZWCFG];
     for(i=0; i<n_cfg; i++) {
-        state[i] = malloc( sizeof(search_state_t) * (n_inval+1) );
+        CHECKED_MALLOC(state[i], sizeof(search_state_t) * (n_inval + 1));
         state[i][0].bitcnt=0;
         state[i][0].prev_cfg=i;
     }
 
+    if ( i < n_cfg ) {  // Memory allocation failed - clean up and exit
+        while ( i ) {
+            free(state[--i]);
+        }
+        return -1;
+    }
+
     // Loop over inval_buf
     int existing_idx=0;
     for(i=0; i<n_inval; i++) {
@@ -784,6 +791,10 @@
         CHECKED_MALLOC( w_slice_cfg, size );
         CHECKED_MALLOC( w_slice_pos, size*sizeof(int) );
         n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt);
+        if ( n_w_slice < 0 ) {  // Memory allocation failed
+            bitpos = -1;
+            break;
+        }
         if (n_weights==0)
             n_w_slice = 0;
 
@@ -793,6 +804,10 @@
             CHECKED_MALLOC( z_slice_cfg, size );
             CHECKED_MALLOC( z_slice_pos, size*sizeof(int) );
             n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt);
+            if ( n_z_slice < 0 ) {  // Memory allocation failed
+                bitpos = -1;
+                break;
+            }
         }
 
         // Encode bitstream slice
@@ -875,13 +890,12 @@
     }
 
     // Analyse input data to find palette re-programming points
-    int n_restarts;
     int *palette_restart_pos = NULL;
-    n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos);
+    int n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos);
 
     // Compress each section (using a single palette) separately
-    int bitpos=0;
-    for(i=0; i<n_restarts; i++) {
+    int bitpos = 0;
+    for ( i = 0; i < n_restarts && bitpos >= 0; i++ ) {
         palette_t palette;
         int pos, size;
         pos = palette_restart_pos[i];
@@ -892,9 +906,9 @@
                                  *outbuf, bitbuf_size, bitpos, verbose );
     }
 
-
-    // Add end of stream marker and align to 128bit
-    {
+    int ret = -1;
+    if ( bitpos >= 0 && n_restarts >= 0 ) {  // If allocation fails bitpos or n_restarts < 0
+        // Add end of stream marker and align to 128bit
         bitbuf_t bitbuf_s, *bb=&bitbuf_s;
         bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 );
         bb->pos = bitpos;
@@ -906,14 +920,18 @@
           bitbuf_put( bb, "PAD", 8, 0xff );
         }
         bitpos = bb->pos;
+
+        assert((bitpos&127)==0);
+        int outbuf_size = bitpos/8;
+        *outbuf = realloc(*outbuf, outbuf_size);
+        if ( *outbuf ) {
+            ret = outbuf_size;
+        }
     }
-    assert((bitpos&127)==0);
-    int outbuf_size = bitpos/8;
-    *outbuf = realloc( *outbuf, outbuf_size);
 
     free(palette_restart_pos);
 
-    return *outbuf ? outbuf_size : -1;
+    return ret;
 }
 
 void mlw_free_outbuf( uint8_t *outbuf ) {
@@ -965,7 +983,7 @@
     int decomp_w,
     int64_t* padded_length)
 {
-    *padded_length = 0;
+    *padded_length = -1;
     /* Size unknown. Start with one page at least */
     int64_t length = round_up(max(1, sizeof(int16_t)*
         ofm_depth*
@@ -1090,7 +1108,9 @@
 
 
     weights = (int16_t*)realloc(weights, weight_cnt * sizeof(int16_t));
-    *padded_length = weights ? weight_cnt : 0;
+    if ( weights ) {
+        *padded_length = weight_cnt;
+    }
 
     return weights;
 }
@@ -1136,7 +1156,7 @@
         padded_length);
 
     /* Then encode */
-    int output_length = 0;
+    int output_length = -1;
     if (*padded_length > 0 && *padded_length <= INT32_MAX)
     {
         output_length = mlw_encode(weights, (int)*padded_length, outbuf, verbose);