blob: a7150684e5347afed3b02195a1585abed45c46c0 [file] [log] [blame]
/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "AsrClassifier.hpp"
18
19#include "hal.h"
20#include "TensorFlowLiteMicro.hpp"
21#include "Wav2LetterModel.hpp"
22
23template<typename T>
alexanderc350cdc2021-04-29 20:36:09 +010024bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
25 std::vector<ClassificationResult>& vecResults,
26 const std::vector <std::string>& labels, double scale, double zeroPoint)
alexander3c798932021-03-26 21:42:19 +000027{
28 const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
29 const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
30
alexanderc350cdc2021-04-29 20:36:09 +010031 if (nLetters != labels.size()) {
32 printf("Output size doesn't match the labels' size\n");
33 return false;
34 }
35
alexander3c798932021-03-26 21:42:19 +000036 /* NOTE: tensor's size verification against labels should be
37 * checked by the calling/public function. */
38 if (nLetters < 1) {
39 return false;
40 }
41
42 /* Final results' container. */
43 vecResults = std::vector<ClassificationResult>(nElems);
44
45 T* tensorData = tflite::GetTensorData<T>(tensor);
46
47 /* Get the top 1 results. */
48 for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
49 std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
50
51 for (uint32_t j = 1; j < nLetters; ++j) {
52 if (top_1.first < tensorData[row + j]) {
53 top_1.first = tensorData[row + j];
54 top_1.second = j;
55 }
56 }
57
58 double score = static_cast<int> (top_1.first);
59 vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
60 vecResults[i].m_label = labels[top_1.second];
61 vecResults[i].m_labelIdx = top_1.second;
62 }
63
64 return true;
65}
alexanderc350cdc2021-04-29 20:36:09 +010066template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
67 std::vector<ClassificationResult>& vecResults,
68 const std::vector <std::string>& labels, double scale, double zeroPoint);
69template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
70 std::vector<ClassificationResult>& vecResults,
71 const std::vector <std::string>& labels, double scale, double zeroPoint);
alexander3c798932021-03-26 21:42:19 +000072
73bool arm::app::AsrClassifier::GetClassificationResults(
74 TfLiteTensor* outputTensor,
75 std::vector<ClassificationResult>& vecResults,
Kshitij Sisodia76a15802021-12-24 11:05:11 +000076 const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax)
alexander3c798932021-03-26 21:42:19 +000077{
Kshitij Sisodia76a15802021-12-24 11:05:11 +000078 UNUSED(use_softmax);
alexander3c798932021-03-26 21:42:19 +000079 vecResults.clear();
80
81 constexpr int minTensorDims = static_cast<int>(
82 (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
83 arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
84
85 constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
86
87 /* Sanity checks. */
88 if (outputTensor == nullptr) {
89 printf_err("Output vector is null pointer.\n");
90 return false;
91 } else if (outputTensor->dims->size < minTensorDims) {
92 printf_err("Output tensor expected to be %dD\n", minTensorDims);
93 return false;
94 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010095 printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
alexander3c798932021-03-26 21:42:19 +000096 return false;
97 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
98 printf("Output size doesn't match the labels' size\n");
99 return false;
100 }
101
102 if (topNCount != 1) {
103 warn("TopNCount value ignored in this implementation\n");
104 }
105
106 /* To return the floating point values, we need quantization parameters. */
107 QuantParams quantParams = GetTensorQuantParams(outputTensor);
108
109 bool resultState;
110
111 switch (outputTensor->type) {
112 case kTfLiteUInt8:
alexanderc350cdc2021-04-29 20:36:09 +0100113 resultState = this->GetTopResults<uint8_t>(
114 outputTensor, vecResults,
115 labels, quantParams.scale,
116 quantParams.offset);
alexander3c798932021-03-26 21:42:19 +0000117 break;
118 case kTfLiteInt8:
alexanderc350cdc2021-04-29 20:36:09 +0100119 resultState = this->GetTopResults<int8_t>(
120 outputTensor, vecResults,
121 labels, quantParams.scale,
122 quantParams.offset);
alexander3c798932021-03-26 21:42:19 +0000123 break;
124 default:
125 printf_err("Tensor type %s not supported by classifier\n",
126 TfLiteTypeGetName(outputTensor->type));
127 return false;
128 }
129
130 if (!resultState) {
131 printf_err("Failed to get sorted set\n");
132 return false;
133 }
134
135 return true;
136}