MLECO-2079 Add the Python KWS example

Signed-off-by: Eanna O Cathain <eanna.ocathain@arm.com>
Change-Id: Ie1463aaeb5e3cade22df8f560ae99a8e1c4a9c17
diff --git a/python/pyarmnn/examples/common/audio_capture.py b/python/pyarmnn/examples/common/audio_capture.py
new file mode 100644
index 0000000..1bd53b4
--- /dev/null
+++ b/python/pyarmnn/examples/common/audio_capture.py
@@ -0,0 +1,149 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+"""Contains CaptureAudioStream class for capturing chunks of audio data from incoming
+  stream and generic capture_audio function for capturing from files."""
+import collections
+import time
+from queue import Queue
+from typing import Generator
+
+import numpy as np
+import sounddevice as sd
+import soundfile as sf
+
+AudioCaptureParams = collections.namedtuple('AudioCaptureParams',
+                                            ['dtype', 'overlap', 'min_samples', 'sampling_freq', 'mono'])
+
+
+def capture_audio(audio_file_path, params_tuple) -> Generator[np.ndarray, None, None]:
+    """Creates a generator that yields audio data from a file. Data is padded with
+    zeros if necessary to make up the minimum number of samples.
+    Args:
+        audio_file_path: Path to audio file provided by user.
+        params_tuple: Sampling parameters for the model used.
+    Yields:
+        Blocks of audio data of minimum sample size.
+    """
+    with sf.SoundFile(audio_file_path) as audio_file:
+        for block in audio_file.blocks(
+                blocksize=params_tuple.min_samples,
+                dtype=params_tuple.dtype,
+                always_2d=True,
+                fill_value=0,
+                overlap=params_tuple.overlap
+        ):
+            if params_tuple.mono and block.shape[0] > 1:
+                block = np.mean(block, dtype=block.dtype, axis=1)
+            yield block
+
+
+class CaptureAudioStream:
+
+    def __init__(self, audio_capture_params):
+        self.audio_capture_params = audio_capture_params
+        self.collection = np.zeros(self.audio_capture_params.min_samples + self.audio_capture_params.overlap).astype(
+            dtype=self.audio_capture_params.dtype)
+        self.is_active = True
+        self.is_first_window = True
+        self.duration = False
+        self.block_count = 0
+        self.current_block = 0
+        self.queue = Queue(2)
+
+    def set_stream_defaults(self):
+        """Discovers input devices on the system and sets default stream parameters."""
+        print(sd.query_devices())
+        device = input("Select input device by index or name: ")
+
+        try:
+            sd.default.device = int(device)
+        except ValueError:
+            sd.default.device = str(device)
+
+        sd.default.samplerate = self.audio_capture_params.sampling_freq
+        sd.default.blocksize = self.audio_capture_params.min_samples
+        sd.default.dtype = self.audio_capture_params.dtype
+        sd.default.channels = 1 if self.audio_capture_params.mono else 2
+
+    def set_recording_duration(self, duration):
+        """Sets a time duration (in integer seconds) for recording audio. Total time duration is
+        adjusted up to a minimum determined by the parameters of the model used. Durations of
+        less than 1 second result in endless recording.
+
+        Args:
+            duration (int): User-provided command line argument for time duration of recording.
+        """
+        if duration > 0:
+            min_duration = int(
+                np.ceil(self.audio_capture_params.min_samples / self.audio_capture_params.sampling_freq)
+            )
+            if duration < min_duration:
+                print(f"Minimum duration must be {min_duration} seconds of audio")
+                print(f"Setting minimum recording duration...")
+                duration = min_duration
+
+            print(f"Recording duration is {duration} seconds")
+            self.duration = self.audio_capture_params.sampling_freq * duration
+            self.block_count, remainder_samples = divmod(
+                self.duration, self.audio_capture_params.min_samples
+            )
+
+            if remainder_samples > 0.5 * self.audio_capture_params.sampling_freq:
+                self.block_count += 1
+        else:
+            self.duration = False  # Record forever
+
+    def countdown(self, delay=3):
+        """3 second countdown prior to recording audio."""
+        print("Beginning recording in...")
+        for i in range(delay, 0, -1):
+            print(f"{i}...")
+            time.sleep(1)
+
+    def update(self):
+        """If a duration has been set, increments a counter to update the number of blocks of audio
+        data left to be collected. The stream is deactivated upon reaching the maximum block count
+        determined by the duration.
+        """
+        if self.duration:
+            self.current_block += 1
+            if self.current_block == self.block_count:
+                self.is_active = False
+
+    def capture_data(self):
+        """Gets the next window of audio data by retrieving the newest data from a queue and
+        shifting the position of the data in the collection. Overlap values of less than `min_samples` are supported.
+        """
+        new_data = self.queue.get()
+
+        if self.is_first_window or self.audio_capture_params.overlap == 0:
+            self.collection[:self.audio_capture_params.min_samples] = new_data[:]
+
+        elif self.audio_capture_params.overlap < self.audio_capture_params.min_samples:
+            # Copy the last `overlap` samples of the previous window to the start of the collection
+            self.collection[0:self.audio_capture_params.overlap] = \
+                self.collection[(self.audio_capture_params.min_samples - self.audio_capture_params.overlap):
+                                self.audio_capture_params.min_samples]
+
+            self.collection[self.audio_capture_params.overlap:(
+                    self.audio_capture_params.overlap + self.audio_capture_params.min_samples)] = new_data[:]
+        else:
+            raise ValueError(
+                "Capture Error: Overlap must be less than {}".format(self.audio_capture_params.min_samples))
+        audio_data = self.collection[0:self.audio_capture_params.min_samples]
+        return np.asarray(audio_data).astype(self.audio_capture_params.dtype)
+
+    def callback(self, data, frames, time, status):
+        """Places audio data from active stream into a queue for processing.
+        Update counter if recording duration is finite.
+         """
+
+        if self.duration:
+            self.update()
+
+        if self.audio_capture_params.mono:
+            audio_data = data.copy().flatten()
+        else:
+            audio_data = data.copy()
+
+        self.queue.put(audio_data)
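For reference, a minimal usage sketch of the two capture paths above. This is not part of the patch; the parameter values, the "keyword.wav" path and the print-based processing are illustrative assumptions only:

    import sounddevice as sd
    from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio

    # Assumed parameters: 16 kHz mono float32 audio, 1-second windows, no overlap.
    params = AudioCaptureParams(dtype="float32", overlap=0, min_samples=16000,
                                sampling_freq=16000, mono=True)

    # File capture: blocks are zero-padded to min_samples and yielded until end of file.
    for block in capture_audio("keyword.wav", params):
        print(block.shape)  # stand-in for feature extraction / inference

    # Live capture: the sounddevice callback feeds the queue, capture_data() drains it.
    stream = CaptureAudioStream(params)
    stream.set_stream_defaults()        # prompts for an input device
    stream.set_recording_duration(3)    # roughly 3 seconds of audio
    stream.countdown()
    with sd.InputStream(callback=stream.callback):
        while stream.is_active:
            window = stream.capture_data()
            print(window.shape)         # stand-in for feature extraction / inference
            stream.is_first_window = False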
diff --git a/python/pyarmnn/examples/common/cv_utils.py b/python/pyarmnn/examples/common/cv_utils.py
index fd848b8..e12ff50 100644
--- a/python/pyarmnn/examples/common/cv_utils.py
+++ b/python/pyarmnn/examples/common/cv_utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
 # SPDX-License-Identifier: MIT
 
 """
@@ -14,7 +14,7 @@
 import pyarmnn as ann
 
 
-def preprocess(frame: np.ndarray, input_binding_info: tuple):
+def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool):
     """
     Takes a frame, resizes, swaps channels and converts data type to match
     model input layer. The converted frame is wrapped in a const tensor
@@ -23,6 +23,7 @@
     Args:
         frame: Captured frame from video.
         input_binding_info:  Contains shape and data type of model input layer.
+        is_normalised: Whether the input layer expects normalised (0 to 1) data.
 
     Returns:
         Input tensor.
@@ -34,7 +35,8 @@
     # Expand dimensions and convert data type to match model input
     if input_binding_info[1].GetDataType() == ann.DataType_Float32:
         data_type = np.float32
-        resized_frame = resized_frame.astype("float32")/255
+        if is_normalised:
+            resized_frame = resized_frame.astype("float32")/255
     else:
         data_type = np.uint8
 
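A short sketch of how the new flag might be used by a caller. This is not part of the patch; `frame` (a BGR image array) and `executor` (an ArmnnNetworkExecutor, as used elsewhere in these examples) are assumptions:

    # Float32 models trained on raw 0-255 pixel values can now skip the /255 scaling.
    input_tensors = cv_utils.preprocess(frame, executor.input_binding_info, is_normalised=False)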
diff --git a/python/pyarmnn/examples/common/mfcc.py b/python/pyarmnn/examples/common/mfcc.py
new file mode 100644
index 0000000..2bab669
--- /dev/null
+++ b/python/pyarmnn/examples/common/mfcc.py
@@ -0,0 +1,238 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""
+
+import numpy as np
+import collections
+
+MFCCParams = collections.namedtuple('MFCCParams', ['sampling_freq', 'num_fbank_bins', 'mel_lo_freq', 'mel_hi_freq',
+                                                   'num_mfcc_feats', 'frame_len', 'use_htk_method', 'n_fft'])
+
+
+class MFCC:
+
+    def __init__(self, mfcc_params):
+        self.mfcc_params = mfcc_params
+        self.FREQ_STEP = 200.0 / 3
+        self.MIN_LOG_HZ = 1000.0
+        self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
+        self.LOG_STEP = 1.8562979903656 / 27.0
+        self._frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
+        self._filter_bank_initialised = False
+        self.__frame = np.zeros(self._frame_len_padded)
+        self.__buffer = np.zeros(self._frame_len_padded)
+        self._filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
+        self._filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
+        self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
+        self._dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
+        self.__mel_filter_bank = self.create_mel_filter_bank()
+        self._np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_fft / 2) + 1])
+
+        for i in range(self.mfcc_params.num_fbank_bins):
+            k = 0
+            for j in range(int(self._filter_bank_filter_first[i]), int(self._filter_bank_filter_last[i]) + 1):
+                self._np_mel_bank[i, j] = self.__mel_filter_bank[i][k]
+                k += 1
+
+    def mel_scale(self, freq, use_htk_method):
+        """
+        Converts a frequency in Hz to the mel scale.
+
+        Args:
+            freq: The frequency in Hz.
+            use_htk_method: Boolean to set whether to use HTK method or not.
+
+        Returns:
+            the mel scale
+        """
+        if use_htk_method:
+            return 1127.0 * np.log(1.0 + freq / 700.0)
+        else:
+            mel = freq / self.FREQ_STEP
+
+        if freq >= self.MIN_LOG_HZ:
+            mel = self.MIN_LOG_MEL + np.log(freq / self.MIN_LOG_HZ) / self.LOG_STEP
+        return mel
+
+    def inv_mel_scale(self, mel_freq, use_htk_method):
+        """
+        Converts a mel-scale value back to a frequency in Hz.
+
+        Args:
+            mel_freq: The mel frequency.
+            use_htk_method: Boolean to set whether to use HTK method or not.
+
+        Returns:
+            the frequency in Hz
+        """
+        if use_htk_method:
+            return 700.0 * (np.exp(mel_freq / 1127.0) - 1.0)
+        else:
+            freq = self.FREQ_STEP * mel_freq
+
+            if mel_freq >= self.MIN_LOG_MEL:
+                freq = self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
+            return freq
+
+    def spectrum_calc(self, audio_data):
+        return np.abs(np.fft.rfft(np.hanning(self.mfcc_params.frame_len + 1)[0:self.mfcc_params.frame_len] * audio_data,
+                                  self.mfcc_params.n_fft))
+
+    def log_mel(self, mel_energy):
+        mel_energy += 1e-10  # Avoid taking the log of zero
+        return np.log(mel_energy)
+
+    def mfcc_compute(self, audio_data):
+        """
+        Extracts the MFCC for a single frame.
+
+        Args:
+            audio_data: The audio data to process.
+
+        Returns:
+            the MFCC features
+        """
+        if len(audio_data) != self.mfcc_params.frame_len:
+            raise ValueError(
+                f"audio_data buffer size {len(audio_data)} does not match frame length {self.mfcc_params.frame_len}")
+
+        audio_data = np.array(audio_data)
+        spec = self.spectrum_calc(audio_data)
+        mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
+                            np.transpose(spec).astype(np.float32))
+        log_mel_energy = self.log_mel(mel_energy)
+        mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
+        return mfcc_feats
+
+    def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
+        """
+        Creates the Discrete Cosine Transform matrix to be used in the compute function.
+
+        Args:
+            num_fbank_bins: The number of filter bank bins
+            num_mfcc_feats: the number of MFCC features
+
+        Returns:
+            the DCT matrix
+        """
+
+        dct_m = np.zeros(num_fbank_bins * num_mfcc_feats)
+        for k in range(num_mfcc_feats):
+            for n in range(num_fbank_bins):
+                dct_m[(k * num_fbank_bins) + n] = (np.sqrt(2 / num_fbank_bins)) * np.cos(
+                    (np.pi / num_fbank_bins) * (n + 0.5) * k)
+        dct_m = np.reshape(dct_m, [self.mfcc_params.num_mfcc_feats, self.mfcc_params.num_fbank_bins])
+        return dct_m
+
+    def mel_norm(self, weight, right_mel, left_mel):
+        """
+        Placeholder function overridden in child classes.
+        """
+        return weight
+
+    def create_mel_filter_bank(self):
+        """
+        Creates the Mel filter bank.
+
+        Returns:
+            the mel filter bank
+        """
+        # FFT calculations are greatly accelerated for frame lengths which are powers of 2
+        # Frames are padded and FFT bin width/length calculated accordingly
+        num_fft_bins = int(self._frame_len_padded / 2)
+        fft_bin_width = self.mfcc_params.sampling_freq / self._frame_len_padded
+
+        mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, self.mfcc_params.use_htk_method)
+        mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, self.mfcc_params.use_htk_method)
+        mel_freq_delta = (mel_high_freq - mel_low_freq) / (self.mfcc_params.num_fbank_bins + 1)
+
+        this_bin = np.zeros(num_fft_bins)
+        mel_fbank = [0] * self.mfcc_params.num_fbank_bins
+        for bin_num in range(self.mfcc_params.num_fbank_bins):
+            left_mel = mel_low_freq + bin_num * mel_freq_delta
+            center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
+            right_mel = mel_low_freq + (bin_num + 2) * mel_freq_delta
+            first_index = last_index = -1
+
+            for i in range(num_fft_bins):
+                freq = (fft_bin_width * i)
+                mel = self.mel_scale(freq, self.mfcc_params.use_htk_method)
+                this_bin[i] = 0.0
+
+                if (mel > left_mel) and (mel < right_mel):
+                    if mel <= center_mel:
+                        weight = (mel - left_mel) / (center_mel - left_mel)
+                    else:
+                        weight = (right_mel - mel) / (right_mel - center_mel)
+
+                    this_bin[i] = self.mel_norm(weight, right_mel, left_mel)
+
+                    if first_index == -1:
+                        first_index = i
+                    last_index = i
+
+            self._filter_bank_filter_first[bin_num] = first_index
+            self._filter_bank_filter_last[bin_num] = last_index
+            mel_fbank[bin_num] = np.zeros(last_index - first_index + 1)
+            j = 0
+
+            for i in range(first_index, last_index + 1):
+                mel_fbank[bin_num][j] = this_bin[i]
+                j += 1
+
+        return mel_fbank
+
+
+class AudioPreprocessor:
+
+    def __init__(self, mfcc, model_input_size, stride):
+        self.model_input_size = model_input_size
+        self.stride = stride
+        self._mfcc_calc = mfcc
+
+    def _normalize(self, values):
+        """
+        Normalize values to mean 0 and std 1
+        """
+        ret_val = (values - np.mean(values)) / np.std(values)
+        return ret_val
+
+    def _get_features(self, features, mfcc_instance, audio_data):
+        idx = 0
+        while len(features) < self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats:
+            current_frame_feats = mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)])
+            features.extend(current_frame_feats)
+            idx += self.stride
+
+    def mfcc_delta_calc(self, features):
+        """
+        Placeholder function overridden in child classes.
+        """
+        return features
+
+    def extract_features(self, audio_data):
+        """
+        Extracts the MFCC features. Also calculates each feature's first and second order derivatives
+        if the mfcc_delta_calc() function has been implemented by a child class.
+        The matrix returned should be sized appropriately for input to the model, based
+        on the model info specified in the MFCC instance.
+
+        Args:
+            audio_data: the audio data to be used for this calculation
+        Returns:
+            the derived MFCC feature vector, sized appropriately for inference
+        """
+
+        num_samples_per_inference = ((self.model_input_size - 1)
+                                     * self.stride) + self._mfcc_calc.mfcc_params.frame_len
+
+        if len(audio_data) < num_samples_per_inference:
+            raise ValueError("audio_data size for feature extraction is smaller than "
+                             "the expected number of samples needed for inference")
+
+        features = []
+        self._get_features(features, self._mfcc_calc, np.asarray(audio_data))
+        features = np.reshape(np.array(features), (self.model_input_size, self._mfcc_calc.mfcc_params.num_mfcc_feats))
+        features = self.mfcc_delta_calc(features)
+        return np.float32(features)
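A minimal sketch of how a keyword-spotting application might wire these classes together. This is not part of the patch; the MFCC parameter values, model_input_size=49, stride=320 and the zero-filled audio window are illustrative assumptions, chosen so that one window covers exactly (49 - 1) * 320 + 640 = 16000 samples:

    import numpy as np
    from mfcc import MFCCParams, MFCC, AudioPreprocessor

    # Assumed front-end configuration: 40 ms frames (640 samples at 16 kHz), 40 mel bins.
    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20,
                             mel_hi_freq=4000, num_mfcc_feats=10, frame_len=640,
                             use_htk_method=True, n_fft=1024)
    mfcc = MFCC(mfcc_params)

    # 49 feature frames of 10 MFCCs each, computed with a 320-sample (20 ms) stride.
    preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)

    audio_window = np.zeros(16000, dtype=np.float32)  # placeholder for one second of audio
    features = preprocessor.extract_features(audio_window)
    print(features.shape)  # (49, 10)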
diff --git a/python/pyarmnn/examples/common/tests/conftest.py b/python/pyarmnn/examples/common/tests/conftest.py
deleted file mode 100644
index 5e027a0..0000000
--- a/python/pyarmnn/examples/common/tests/conftest.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-import ntpath
-
-import urllib.request
-import zipfile
-
-import pytest
-
-script_dir = os.path.dirname(__file__)
-@pytest.fixture(scope="session")
-def test_data_folder(request):
-    """
-        This fixture returns path to folder with shared test resources among all tests
-    """
-
-    data_dir = os.path.join(script_dir, "testdata")
-    if not os.path.exists(data_dir):
-        os.mkdir(data_dir)
-
-    files_to_download = ["https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/messi5.jpg",
-                         "https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/basketball1.png",
-                         "https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/Megamind.avi",
-                         "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"
-                         ]
-
-    for file in files_to_download:
-        path, filename = ntpath.split(file)
-        file_path = os.path.join(data_dir, filename)
-        if not os.path.exists(file_path):
-            print("\nDownloading test file: " + file_path + "\n")
-            urllib.request.urlretrieve(file, file_path)
-
-    # Any unzipping needed, and moving around of files
-    with zipfile.ZipFile(os.path.join(data_dir, "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"), 'r') as zip_ref:
-        zip_ref.extractall(data_dir)
-
-    return data_dir
diff --git a/python/pyarmnn/examples/common/tests/context.py b/python/pyarmnn/examples/common/tests/context.py
deleted file mode 100644
index 72246c0..0000000
--- a/python/pyarmnn/examples/common/tests/context.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import os
-import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-
-import cv_utils
-import network_executor
-import utils
diff --git a/python/pyarmnn/examples/common/tests/test_network_executor.py b/python/pyarmnn/examples/common/tests/test_network_executor.py
deleted file mode 100644
index e27b382..0000000
--- a/python/pyarmnn/examples/common/tests/test_network_executor.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-
-import cv2
-
-from context import network_executor
-from context import cv_utils
-
-
-def test_execute_network(test_data_folder):
-    model_path = os.path.join(test_data_folder, "detect.tflite")
-    backends = ["CpuAcc", "CpuRef"]
-
-    executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
-    img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
-    input_tensors = cv_utils.preprocess(img, executor.input_binding_info)
-
-    output_result = executor.run(input_tensors)
-
-    # Ensure it detects a person
-    classes = output_result[1]
-    assert classes[0][0] == 0
diff --git a/python/pyarmnn/examples/common/tests/test_utils.py b/python/pyarmnn/examples/common/tests/test_utils.py
deleted file mode 100644
index 28d68ea..0000000
--- a/python/pyarmnn/examples/common/tests/test_utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-
-from context import cv_utils
-from context import utils
-
-
-def test_get_source_encoding(test_data_folder):
-    video_file = os.path.join(test_data_folder, "Megamind.avi")
-    video, video_writer, frame_count = cv_utils.init_video_file_capture(video_file, "/tmp")
-    assert cv_utils.get_source_encoding_int(video) == 1145656920
-
-
-def test_read_existing_labels_file(test_data_folder):
-    label_file = os.path.join(test_data_folder, "labelmap.txt")
-    labels_map = utils.dict_labels(label_file)
-    assert labels_map is not None
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py
index cf09fde..d4dadf8 100644
--- a/python/pyarmnn/examples/common/utils.py
+++ b/python/pyarmnn/examples/common/utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
 # SPDX-License-Identifier: MIT
 
 """Contains helper functions that can be used across the example apps."""
@@ -8,6 +8,7 @@
 from pathlib import Path
 
 import numpy as np
+import pyarmnn as ann
 
 
 def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
@@ -39,3 +40,69 @@
             else:
                 labels[idx] = line.strip("\n")
         return labels
+
+
+def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+    """
+    Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
+    input tensors.
+
+    Args:
+        audio_data: The audio data to process.
+        input_binding_info: The model's input binding info.
+        mfcc_preprocessor: The MFCC preprocessor instance used to extract
+            features from the audio data.
+    Returns:
+        input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
+    """
+
+    data_type = input_binding_info[1].GetDataType()
+    input_tensor = mfcc_preprocessor.extract_features(audio_data)
+    if data_type != ann.DataType_Float32:
+        input_tensor = quantize_input(input_tensor, input_binding_info)
+    input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+    return input_tensors
+
+
+def quantize_input(data, input_binding_info):
+    """Quantize the float input to (u)int8 ready for inputting to model."""
+    if data.ndim != 2:
+        raise RuntimeError("Audio data must have 2 dimensions for quantization")
+
+    quant_scale = input_binding_info[1].GetQuantizationScale()
+    quant_offset = input_binding_info[1].GetQuantizationOffset()
+    data_type = input_binding_info[1].GetDataType()
+
+    if data_type == ann.DataType_QAsymmS8:
+        data_type = np.int8
+    elif data_type == ann.DataType_QAsymmU8:
+        data_type = np.uint8
+    else:
+        raise ValueError("Could not quantize data to required data type")
+
+    d_min = np.iinfo(data_type).min
+    d_max = np.iinfo(data_type).max
+
+    for row in range(data.shape[0]):
+        for col in range(data.shape[1]):
+            data[row, col] = (data[row, col] / quant_scale) + quant_offset
+            data[row, col] = np.clip(data[row, col], d_min, d_max)
+    data = data.astype(data_type)
+    return data
+
+
+def dequantize_output(data, output_binding_info):
+    """Dequantize the (u)int8 output to float"""
+
+    if output_binding_info[1].IsQuantized():
+        if data.ndim != 2:
+            raise RuntimeError("Data must have 2 dimensions for quantization")
+
+        quant_scale = output_binding_info[1].GetQuantizationScale()
+        quant_offset = output_binding_info[1].GetQuantizationOffset()
+
+        data = data.astype(float)
+        for row in range(data.shape[0]):
+            for col in range(data.shape[1]):
+                data[row, col] = (data[row, col] - quant_offset) * quant_scale
+    return data
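Finally, a hedged end-to-end sketch of how these helpers might feed a quantized KWS model. This is not part of the patch; the model path, the backends list and the `output_binding_info` attribute are assumptions, while ArmnnNetworkExecutor, `run()` and `input_binding_info` follow the usage shown in the removed tests above; `audio_window` and `preprocessor` are assumed to come from the audio_capture and mfcc sketches:

    from network_executor import ArmnnNetworkExecutor
    from utils import prepare_input_tensors, dequantize_output

    executor = ArmnnNetworkExecutor("kws_model.tflite", ["CpuAcc", "CpuRef"])

    # Features are quantized to (u)int8 only when the input layer is not float32.
    input_tensors = prepare_input_tensors(audio_window, executor.input_binding_info, preprocessor)
    output = executor.run(input_tensors)

    # If the output layer is quantized, convert back to float before applying softmax/argmax.
    scores = dequantize_output(output[0], executor.output_binding_info[0])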