blob: df26a7fa88d5337814ddec11addaf7cdd77e202a [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "AsrClassifier.hpp"
18
19#include "hal.h"
20#include "TensorFlowLiteMicro.hpp"
21#include "Wav2LetterModel.hpp"
22
23template<typename T>
alexanderc350cdc2021-04-29 20:36:09 +010024bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
25 std::vector<ClassificationResult>& vecResults,
26 const std::vector <std::string>& labels, double scale, double zeroPoint)
alexander3c798932021-03-26 21:42:19 +000027{
28 const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
29 const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
30
alexanderc350cdc2021-04-29 20:36:09 +010031 if (nLetters != labels.size()) {
32 printf("Output size doesn't match the labels' size\n");
33 return false;
34 }
35
alexander3c798932021-03-26 21:42:19 +000036 /* NOTE: tensor's size verification against labels should be
37 * checked by the calling/public function. */
38 if (nLetters < 1) {
39 return false;
40 }
41
42 /* Final results' container. */
43 vecResults = std::vector<ClassificationResult>(nElems);
44
45 T* tensorData = tflite::GetTensorData<T>(tensor);
46
47 /* Get the top 1 results. */
48 for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
49 std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
50
51 for (uint32_t j = 1; j < nLetters; ++j) {
52 if (top_1.first < tensorData[row + j]) {
53 top_1.first = tensorData[row + j];
54 top_1.second = j;
55 }
56 }
57
58 double score = static_cast<int> (top_1.first);
59 vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
60 vecResults[i].m_label = labels[top_1.second];
61 vecResults[i].m_labelIdx = top_1.second;
62 }
63
64 return true;
65}
alexanderc350cdc2021-04-29 20:36:09 +010066template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
67 std::vector<ClassificationResult>& vecResults,
68 const std::vector <std::string>& labels, double scale, double zeroPoint);
69template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
70 std::vector<ClassificationResult>& vecResults,
71 const std::vector <std::string>& labels, double scale, double zeroPoint);
alexander3c798932021-03-26 21:42:19 +000072
73bool arm::app::AsrClassifier::GetClassificationResults(
74 TfLiteTensor* outputTensor,
75 std::vector<ClassificationResult>& vecResults,
76 const std::vector <std::string>& labels, uint32_t topNCount)
77{
78 vecResults.clear();
79
80 constexpr int minTensorDims = static_cast<int>(
81 (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
82 arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
83
84 constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
85
86 /* Sanity checks. */
87 if (outputTensor == nullptr) {
88 printf_err("Output vector is null pointer.\n");
89 return false;
90 } else if (outputTensor->dims->size < minTensorDims) {
91 printf_err("Output tensor expected to be %dD\n", minTensorDims);
92 return false;
93 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
94 printf_err("Output vectors are smaller than %u\n", topNCount);
95 return false;
96 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
97 printf("Output size doesn't match the labels' size\n");
98 return false;
99 }
100
101 if (topNCount != 1) {
102 warn("TopNCount value ignored in this implementation\n");
103 }
104
105 /* To return the floating point values, we need quantization parameters. */
106 QuantParams quantParams = GetTensorQuantParams(outputTensor);
107
108 bool resultState;
109
110 switch (outputTensor->type) {
111 case kTfLiteUInt8:
alexanderc350cdc2021-04-29 20:36:09 +0100112 resultState = this->GetTopResults<uint8_t>(
113 outputTensor, vecResults,
114 labels, quantParams.scale,
115 quantParams.offset);
alexander3c798932021-03-26 21:42:19 +0000116 break;
117 case kTfLiteInt8:
alexanderc350cdc2021-04-29 20:36:09 +0100118 resultState = this->GetTopResults<int8_t>(
119 outputTensor, vecResults,
120 labels, quantParams.scale,
121 quantParams.offset);
alexander3c798932021-03-26 21:42:19 +0000122 break;
123 default:
124 printf_err("Tensor type %s not supported by classifier\n",
125 TfLiteTypeGetName(outputTensor->type));
126 return false;
127 }
128
129 if (!resultState) {
130 printf_err("Failed to get sorted set\n");
131 return false;
132 }
133
134 return true;
135}