blob: 4ba8c7b7cf1c322b827d12483c6b8619d14b98c0 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "AsrClassifier.hpp"
18
alexander31ae9f02022-02-10 16:15:54 +000019#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000020#include "TensorFlowLiteMicro.hpp"
21#include "Wav2LetterModel.hpp"
22
Richard Burtonb40ecf82022-04-22 16:14:57 +010023namespace arm {
24namespace app {
alexander3c798932021-03-26 21:42:19 +000025
Richard Burtonb40ecf82022-04-22 16:14:57 +010026 template<typename T>
27 bool AsrClassifier::GetTopResults(TfLiteTensor* tensor,
28 std::vector<ClassificationResult>& vecResults,
29 const std::vector <std::string>& labels, double scale, double zeroPoint)
30 {
31 const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx];
32 const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx];
alexanderc350cdc2021-04-29 20:36:09 +010033
Richard Burtonb40ecf82022-04-22 16:14:57 +010034 if (nLetters != labels.size()) {
alexander3c798932021-03-26 21:42:19 +000035 printf("Output size doesn't match the labels' size\n");
36 return false;
37 }
38
Richard Burtonb40ecf82022-04-22 16:14:57 +010039 /* NOTE: tensor's size verification against labels should be
40 * checked by the calling/public function. */
41 if (nLetters < 1) {
alexander3c798932021-03-26 21:42:19 +000042 return false;
43 }
44
Richard Burtonb40ecf82022-04-22 16:14:57 +010045 /* Final results' container. */
46 vecResults = std::vector<ClassificationResult>(nElems);
47
48 T* tensorData = tflite::GetTensorData<T>(tensor);
49
50 /* Get the top 1 results. */
51 for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
52 std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
53
54 for (uint32_t j = 1; j < nLetters; ++j) {
55 if (top_1.first < tensorData[row + j]) {
56 top_1.first = tensorData[row + j];
57 top_1.second = j;
58 }
59 }
60
61 double score = static_cast<int> (top_1.first);
62 vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
63 vecResults[i].m_label = labels[top_1.second];
64 vecResults[i].m_labelIdx = top_1.second;
65 }
66
alexander3c798932021-03-26 21:42:19 +000067 return true;
Richard Burtonb40ecf82022-04-22 16:14:57 +010068 }
69 template bool AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
70 std::vector<ClassificationResult>& vecResults,
71 const std::vector <std::string>& labels,
72 double scale, double zeroPoint);
73 template bool AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
74 std::vector<ClassificationResult>& vecResults,
75 const std::vector <std::string>& labels,
76 double scale, double zeroPoint);
77
78 bool AsrClassifier::GetClassificationResults(
79 TfLiteTensor* outputTensor,
80 std::vector<ClassificationResult>& vecResults,
81 const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax)
82 {
83 UNUSED(use_softmax);
84 vecResults.clear();
85
86 constexpr int minTensorDims = static_cast<int>(
87 (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)?
88 Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx);
89
90 constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx;
91
92 /* Sanity checks. */
93 if (outputTensor == nullptr) {
94 printf_err("Output vector is null pointer.\n");
95 return false;
96 } else if (outputTensor->dims->size < minTensorDims) {
97 printf_err("Output tensor expected to be %dD\n", minTensorDims);
98 return false;
99 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
100 printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
101 return false;
102 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
103 printf("Output size doesn't match the labels' size\n");
104 return false;
105 }
106
107 if (topNCount != 1) {
108 warn("TopNCount value ignored in this implementation\n");
109 }
110
111 /* To return the floating point values, we need quantization parameters. */
112 QuantParams quantParams = GetTensorQuantParams(outputTensor);
113
114 bool resultState;
115
116 switch (outputTensor->type) {
117 case kTfLiteUInt8:
118 resultState = this->GetTopResults<uint8_t>(
119 outputTensor, vecResults,
120 labels, quantParams.scale,
121 quantParams.offset);
122 break;
123 case kTfLiteInt8:
124 resultState = this->GetTopResults<int8_t>(
125 outputTensor, vecResults,
126 labels, quantParams.scale,
127 quantParams.offset);
128 break;
129 default:
130 printf_err("Tensor type %s not supported by classifier\n",
131 TfLiteTypeGetName(outputTensor->type));
132 return false;
133 }
134
135 if (!resultState) {
136 printf_err("Failed to get sorted set\n");
137 return false;
138 }
139
140 return true;
141 }
142
143} /* namespace app */
144} /* namespace arm */