blob: 7377d303645c486ef836f8bb03c6213e402b949d [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "AsrClassifier.hpp"
18
19#include "hal.h"
20#include "TensorFlowLiteMicro.hpp"
21#include "Wav2LetterModel.hpp"
22
23template<typename T>
24bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor,
25 std::vector<ClassificationResult>& vecResults,
26 const std::vector <std::string>& labels, double scale, double zeroPoint)
27{
28 const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
29 const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
30
31 /* NOTE: tensor's size verification against labels should be
32 * checked by the calling/public function. */
33 if (nLetters < 1) {
34 return false;
35 }
36
37 /* Final results' container. */
38 vecResults = std::vector<ClassificationResult>(nElems);
39
40 T* tensorData = tflite::GetTensorData<T>(tensor);
41
42 /* Get the top 1 results. */
43 for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
44 std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
45
46 for (uint32_t j = 1; j < nLetters; ++j) {
47 if (top_1.first < tensorData[row + j]) {
48 top_1.first = tensorData[row + j];
49 top_1.second = j;
50 }
51 }
52
53 double score = static_cast<int> (top_1.first);
54 vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
55 vecResults[i].m_label = labels[top_1.second];
56 vecResults[i].m_labelIdx = top_1.second;
57 }
58
59 return true;
60}
61template bool arm::app::AsrClassifier::_GetTopResults<uint8_t>(TfLiteTensor* tensor,
62 std::vector<ClassificationResult>& vecResults,
63 const std::vector <std::string>& labels, double scale, double zeroPoint);
64template bool arm::app::AsrClassifier::_GetTopResults<int8_t>(TfLiteTensor* tensor,
65 std::vector<ClassificationResult>& vecResults,
66 const std::vector <std::string>& labels, double scale, double zeroPoint);
67
68bool arm::app::AsrClassifier::GetClassificationResults(
69 TfLiteTensor* outputTensor,
70 std::vector<ClassificationResult>& vecResults,
71 const std::vector <std::string>& labels, uint32_t topNCount)
72{
73 vecResults.clear();
74
75 constexpr int minTensorDims = static_cast<int>(
76 (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
77 arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
78
79 constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
80
81 /* Sanity checks. */
82 if (outputTensor == nullptr) {
83 printf_err("Output vector is null pointer.\n");
84 return false;
85 } else if (outputTensor->dims->size < minTensorDims) {
86 printf_err("Output tensor expected to be %dD\n", minTensorDims);
87 return false;
88 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
89 printf_err("Output vectors are smaller than %u\n", topNCount);
90 return false;
91 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
92 printf("Output size doesn't match the labels' size\n");
93 return false;
94 }
95
96 if (topNCount != 1) {
97 warn("TopNCount value ignored in this implementation\n");
98 }
99
100 /* To return the floating point values, we need quantization parameters. */
101 QuantParams quantParams = GetTensorQuantParams(outputTensor);
102
103 bool resultState;
104
105 switch (outputTensor->type) {
106 case kTfLiteUInt8:
107 resultState = this->_GetTopResults<uint8_t>(
108 outputTensor, vecResults,
109 labels, quantParams.scale,
110 quantParams.offset);
111 break;
112 case kTfLiteInt8:
113 resultState = this->_GetTopResults<int8_t>(
114 outputTensor, vecResults,
115 labels, quantParams.scale,
116 quantParams.offset);
117 break;
118 default:
119 printf_err("Tensor type %s not supported by classifier\n",
120 TfLiteTypeGetName(outputTensor->type));
121 return false;
122 }
123
124 if (!resultState) {
125 printf_err("Failed to get sorted set\n");
126 return false;
127 }
128
129 return true;
130}