blob: c461875ec37b1f79fe1c89339db312d11fe76f6e [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "AdPostProcessing.hpp"
alexander31ae9f02022-02-10 16:15:54 +000018#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000019
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdlib>
#include <numeric>
#include <string>
#include <vector>
23
24namespace arm {
25namespace app {
26
27 template<typename T>
28 std::vector<float> Dequantize(TfLiteTensor* tensor) {
29
30 if (tensor == nullptr) {
31 printf_err("Tensor is null pointer can not dequantize.\n");
32 return std::vector<float>();
33 }
34 T* tensorData = tflite::GetTensorData<T>(tensor);
35
36 uint32_t totalOutputSize = 1;
37 for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
38 totalOutputSize *= tensor->dims->data[inputDim];
39 }
40
41 /* For getting the floating point values, we need quantization parameters */
42 QuantParams quantParams = GetTensorQuantParams(tensor);
43
44 std::vector<float> dequantizedOutput(totalOutputSize);
45
46 for (size_t i = 0; i < totalOutputSize; ++i) {
47 dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
48 }
49
50 return dequantizedOutput;
51 }
52
/**
 * @brief   Applies the softmax function to a vector in place:
 *          x_i <- exp(x_i - max(x)) / sum_j exp(x_j - max(x)).
 * @param[in,out] inputVector  Values to transform. An empty vector is left
 *                             untouched.
 */
void Softmax(std::vector<float>& inputVector) {
    /* std::max_element returns end() for an empty range and dereferencing it
     * is undefined behaviour, so there is nothing to do for empty input. */
    if (inputVector.empty()) {
        return;
    }

    /* Fix for numerical stability: subtracting the maximum makes the largest
     * term exp(0) = 1, so for finite inputs nothing overflows and the sum
     * below is at least 1. */
    const float maxValue = *std::max_element(inputVector.begin(), inputVector.end());
    for (float& value : inputVector) {
        value = std::exp(value - maxValue);
    }

    const float sumExp = std::accumulate(inputVector.begin(), inputVector.end(), 0.0f);

    for (float& value : inputVector) {
        value /= sumExp;
    }
}
69
70 int8_t OutputIndexFromFileName(std::string wavFileName) {
71 /* Filename is assumed in the form machine_id_00.wav */
72 std::string delimiter = "_"; /* First character used to split the file name up. */
73 size_t delimiterStart;
74 std::string subString;
75 size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */
76
77 for (size_t i = 0; i < machineIdxInString; ++i) {
78 delimiterStart = wavFileName.find(delimiter);
79 subString = wavFileName.substr(0, delimiterStart);
80 wavFileName.erase(0, delimiterStart + delimiter.length());
81 }
82
83 /* At this point substring should be 00.wav */
84 delimiter = "."; /* Second character used to split the file name up. */
85 delimiterStart = subString.find(delimiter);
86 subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
87
88 auto is_number = [](const std::string& str) -> bool
89 {
90 std::string::const_iterator it = str.begin();
91 while (it != str.end() && std::isdigit(*it)) ++it;
92 return !str.empty() && it == str.end();
93 };
94
95 const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
96
97 /* Return corresponding index in the output vector. */
98 if (machineIdx == 0) {
99 return 0;
100 } else if (machineIdx == 2) {
101 return 1;
102 } else if (machineIdx == 4) {
103 return 2;
104 } else if (machineIdx == 6) {
105 return 3;
106 } else {
107 printf_err("%d is an invalid machine index \n", machineIdx);
108 return -1;
109 }
110 }
111
112 template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor);
113 template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor);
114} /* namespace app */
115} /* namespace arm */