blob: bc86e09cf044f11c469ca983c2f895e482de1cab [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "AsrClassifier.hpp"
18
19#include "hal.h"
20#include "TensorFlowLiteMicro.hpp"
21#include "Wav2LetterModel.hpp"
22
23template<typename T>
24bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor,
25 std::vector<ClassificationResult>& vecResults,
26 const std::vector <std::string>& labels, double scale, double zeroPoint)
27{
28 const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
29 const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
30
31
32 /* NOTE: tensor's size verification against labels should be
33 * checked by the calling/public function. */
34 if (nLetters < 1) {
35 return false;
36 }
37
38 /* Final results' container. */
39 vecResults = std::vector<ClassificationResult>(nElems);
40
41 T* tensorData = tflite::GetTensorData<T>(tensor);
42
43 /* Get the top 1 results. */
44 for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
45 std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
46
47 for (uint32_t j = 1; j < nLetters; ++j) {
48 if (top_1.first < tensorData[row + j]) {
49 top_1.first = tensorData[row + j];
50 top_1.second = j;
51 }
52 }
53
54 double score = static_cast<int> (top_1.first);
55 vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
56 vecResults[i].m_label = labels[top_1.second];
57 vecResults[i].m_labelIdx = top_1.second;
58 }
59
60 return true;
61}
62template bool arm::app::AsrClassifier::_GetTopResults<uint8_t>(TfLiteTensor* tensor,
63 std::vector<ClassificationResult>& vecResults,
64 const std::vector <std::string>& labels, double scale, double zeroPoint);
65template bool arm::app::AsrClassifier::_GetTopResults<int8_t>(TfLiteTensor* tensor,
66 std::vector<ClassificationResult>& vecResults,
67 const std::vector <std::string>& labels, double scale, double zeroPoint);
68
69bool arm::app::AsrClassifier::GetClassificationResults(
70 TfLiteTensor* outputTensor,
71 std::vector<ClassificationResult>& vecResults,
72 const std::vector <std::string>& labels, uint32_t topNCount)
73{
74 vecResults.clear();
75
76 constexpr int minTensorDims = static_cast<int>(
77 (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
78 arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
79
80 constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
81
82 /* Sanity checks. */
83 if (outputTensor == nullptr) {
84 printf_err("Output vector is null pointer.\n");
85 return false;
86 } else if (outputTensor->dims->size < minTensorDims) {
87 printf_err("Output tensor expected to be 3D (1, m, n)\n");
88 return false;
89 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
90 printf_err("Output vectors are smaller than %u\n", topNCount);
91 return false;
92 } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
93 printf("Output size doesn't match the labels' size\n");
94 return false;
95 }
96
97 if (topNCount != 1) {
98 warn("TopNCount value ignored in this implementation\n");
99 }
100
101 /* To return the floating point values, we need quantization parameters. */
102 QuantParams quantParams = GetTensorQuantParams(outputTensor);
103
104 bool resultState;
105
106 switch (outputTensor->type) {
107 case kTfLiteUInt8:
108 resultState = this->_GetTopResults<uint8_t>(
109 outputTensor, vecResults,
110 labels, quantParams.scale,
111 quantParams.offset);
112 break;
113 case kTfLiteInt8:
114 resultState = this->_GetTopResults<int8_t>(
115 outputTensor, vecResults,
116 labels, quantParams.scale,
117 quantParams.offset);
118 break;
119 default:
120 printf_err("Tensor type %s not supported by classifier\n",
121 TfLiteTypeGetName(outputTensor->type));
122 return false;
123 }
124
125 if (!resultState) {
126 printf_err("Failed to get sorted set\n");
127 return false;
128 }
129
130 return true;
131}