blob: 157784b9263b51af55e532615550279db21101d4 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
#include "AdPostProcessing.hpp"

#include "hal.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdlib>
#include <numeric>
#include <string>
24
25namespace arm {
26namespace app {
27
28 template<typename T>
29 std::vector<float> Dequantize(TfLiteTensor* tensor) {
30
31 if (tensor == nullptr) {
32 printf_err("Tensor is null pointer can not dequantize.\n");
33 return std::vector<float>();
34 }
35 T* tensorData = tflite::GetTensorData<T>(tensor);
36
37 uint32_t totalOutputSize = 1;
38 for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
39 totalOutputSize *= tensor->dims->data[inputDim];
40 }
41
42 /* For getting the floating point values, we need quantization parameters */
43 QuantParams quantParams = GetTensorQuantParams(tensor);
44
45 std::vector<float> dequantizedOutput(totalOutputSize);
46
47 for (size_t i = 0; i < totalOutputSize; ++i) {
48 dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
49 }
50
51 return dequantizedOutput;
52 }
53
/**
 * @brief   Applies softmax to the given vector in place.
 *          Each element x_i becomes exp(x_i - max) / sum_j exp(x_j - max).
 *          Subtracting the maximum first keeps std::exp from overflowing.
 * @param[in,out] inputVector  Values to normalise. An empty vector is left untouched.
 */
void Softmax(std::vector<float>& inputVector) {
    /* std::max_element returns end() for an empty range; dereferencing it would be UB. */
    if (inputVector.empty()) {
        return;
    }

    /* Fix for numerical stability and apply exp. */
    const float maxValue = *std::max_element(inputVector.begin(), inputVector.end());
    for (float& value : inputVector) {
        value = std::exp(value - maxValue);
    }

    /* The maximum element contributes exp(0) == 1, so the sum is >= 1 for finite inputs. */
    const float sumExp = std::accumulate(inputVector.begin(), inputVector.end(), 0.0f);
    for (float& value : inputVector) {
        value /= sumExp;
    }
}
70
71 int8_t OutputIndexFromFileName(std::string wavFileName) {
72 /* Filename is assumed in the form machine_id_00.wav */
73 std::string delimiter = "_"; /* First character used to split the file name up. */
74 size_t delimiterStart;
75 std::string subString;
76 size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */
77
78 for (size_t i = 0; i < machineIdxInString; ++i) {
79 delimiterStart = wavFileName.find(delimiter);
80 subString = wavFileName.substr(0, delimiterStart);
81 wavFileName.erase(0, delimiterStart + delimiter.length());
82 }
83
84 /* At this point substring should be 00.wav */
85 delimiter = "."; /* Second character used to split the file name up. */
86 delimiterStart = subString.find(delimiter);
87 subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
88
89 auto is_number = [](const std::string& str) -> bool
90 {
91 std::string::const_iterator it = str.begin();
92 while (it != str.end() && std::isdigit(*it)) ++it;
93 return !str.empty() && it == str.end();
94 };
95
96 const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
97
98 /* Return corresponding index in the output vector. */
99 if (machineIdx == 0) {
100 return 0;
101 } else if (machineIdx == 2) {
102 return 1;
103 } else if (machineIdx == 4) {
104 return 2;
105 } else if (machineIdx == 6) {
106 return 3;
107 } else {
108 printf_err("%d is an invalid machine index \n", machineIdx);
109 return -1;
110 }
111 }
112
113 template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor);
114 template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor);
115} /* namespace app */
116} /* namespace arm */