blob: 00e689b3216df5fc85aebd755e7f9da21823c64f [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Richard Burtonc2911442022-04-22 09:08:21 +01002 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Wav2LetterPostprocess.hpp"
Richard Burtonc2911442022-04-22 09:08:21 +010018
alexander3c798932021-03-26 21:42:19 +000019#include "Wav2LetterModel.hpp"
alexander31ae9f02022-02-10 16:15:54 +000020#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000021
Richard Burtonc2911442022-04-22 09:08:21 +010022#include <cmath>
23
alexander3c798932021-03-26 21:42:19 +000024namespace arm {
25namespace app {
alexander3c798932021-03-26 21:42:19 +000026
Richard Burtonb40ecf82022-04-22 16:14:57 +010027 AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
Richard Burtonc2911442022-04-22 09:08:21 +010028 const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
29 const uint32_t outputContextLen,
30 const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
31 ):
32 m_classifier(classifier),
33 m_outputTensor(outputTensor),
34 m_labels{labels},
35 m_results(results),
36 m_outputContextLen(outputContextLen),
Isabella Gottardi56ee6202021-05-12 08:27:15 +010037 m_countIterations(0),
Richard Burtonc2911442022-04-22 09:08:21 +010038 m_blankTokenIdx(blankTokenIdx),
39 m_reductionAxisIdx(reductionAxisIdx)
40 {
Richard Burtonb40ecf82022-04-22 16:14:57 +010041 this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
Richard Burtonc2911442022-04-22 09:08:21 +010042 this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
43 }
alexander3c798932021-03-26 21:42:19 +000044
Richard Burtonb40ecf82022-04-22 16:14:57 +010045 bool AsrPostProcess::DoPostProcess()
alexander3c798932021-03-26 21:42:19 +000046 {
47 /* Basic checks. */
Richard Burtonc2911442022-04-22 09:08:21 +010048 if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
alexander3c798932021-03-26 21:42:19 +000049 return false;
50 }
51
52 /* Irrespective of tensor type, we use unsigned "byte" */
Richard Burtonc2911442022-04-22 09:08:21 +010053 auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
Richard Burtonb40ecf82022-04-22 16:14:57 +010054 const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor);
alexander3c798932021-03-26 21:42:19 +000055
56 /* Other sanity checks. */
57 if (0 == elemSz) {
58 printf_err("Tensor type not supported for post processing\n");
59 return false;
Richard Burtonc2911442022-04-22 09:08:21 +010060 } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
alexander3c798932021-03-26 21:42:19 +000061 printf_err("Insufficient number of tensor bytes\n");
62 return false;
63 }
64
65 /* Which axis do we need to process? */
Richard Burtonc2911442022-04-22 09:08:21 +010066 switch (this->m_reductionAxisIdx) {
67 case Wav2LetterModel::ms_outputRowsIdx:
68 this->EraseSectionsRowWise(
69 ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
70 this->m_lastIteration);
71 break;
alexander3c798932021-03-26 21:42:19 +000072 default:
Richard Burtonc2911442022-04-22 09:08:21 +010073 printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
74 return false;
alexander3c798932021-03-26 21:42:19 +000075 }
Richard Burtonc2911442022-04-22 09:08:21 +010076 this->m_classifier.GetClassificationResults(this->m_outputTensor,
77 this->m_results, this->m_labels, 1);
alexander3c798932021-03-26 21:42:19 +000078
Richard Burtonc2911442022-04-22 09:08:21 +010079 return true;
alexander3c798932021-03-26 21:42:19 +000080 }
81
Richard Burtonb40ecf82022-04-22 16:14:57 +010082 bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
alexander3c798932021-03-26 21:42:19 +000083 {
84 if (nullptr == tensor) {
85 return false;
86 }
87
88 if (static_cast<int>(axisIdx) >= tensor->dims->size) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010089 printf_err("Invalid axis index: %" PRIu32 "; Max: %d\n",
alexander3c798932021-03-26 21:42:19 +000090 axisIdx, tensor->dims->size);
91 return false;
92 }
93
Isabella Gottardi56ee6202021-05-12 08:27:15 +010094 if (static_cast<int>(this->m_totalLen) !=
alexander3c798932021-03-26 21:42:19 +000095 tensor->dims->data[axisIdx]) {
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010096 printf_err("Unexpected tensor dimension for axis %" PRIu32", got %d.\n",
Richard Burtonc2911442022-04-22 09:08:21 +010097 axisIdx, tensor->dims->data[axisIdx]);
alexander3c798932021-03-26 21:42:19 +000098 return false;
99 }
100
101 return true;
102 }
103
Richard Burtonb40ecf82022-04-22 16:14:57 +0100104 uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
alexander3c798932021-03-26 21:42:19 +0000105 {
106 switch(tensor->type) {
107 case kTfLiteUInt8:
alexander3c798932021-03-26 21:42:19 +0000108 case kTfLiteInt8:
109 return 1;
110 case kTfLiteInt16:
111 return 2;
112 case kTfLiteInt32:
alexander3c798932021-03-26 21:42:19 +0000113 case kTfLiteFloat32:
114 return 4;
115 default:
116 printf_err("Unsupported tensor type %s\n",
117 TfLiteTypeGetName(tensor->type));
118 }
119
120 return 0;
121 }
122
Richard Burtonb40ecf82022-04-22 16:14:57 +0100123 bool AsrPostProcess::EraseSectionsRowWise(
Richard Burtonc2911442022-04-22 09:08:21 +0100124 uint8_t* ptrData,
125 const uint32_t strideSzBytes,
126 const bool lastIteration)
alexander3c798932021-03-26 21:42:19 +0000127 {
128 /* In this case, the "zero-ing" is quite simple as the region
129 * to be zeroed sits in contiguous memory (row-major). */
Richard Burtonc2911442022-04-22 09:08:21 +0100130 const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;
alexander3c798932021-03-26 21:42:19 +0000131
132 /* Erase left context? */
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100133 if (this->m_countIterations > 0) {
alexander3c798932021-03-26 21:42:19 +0000134 /* Set output of each classification window to the blank token. */
135 std::memset(ptrData, 0, eraseLen);
Richard Burtonc2911442022-04-22 09:08:21 +0100136 for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100137 ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
alexander3c798932021-03-26 21:42:19 +0000138 }
139 }
140
141 /* Erase right context? */
142 if (false == lastIteration) {
Richard Burtonc2911442022-04-22 09:08:21 +0100143 uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
alexander3c798932021-03-26 21:42:19 +0000144 /* Set output of each classification window to the blank token. */
145 std::memset(rightCtxPtr, 0, eraseLen);
Richard Burtonc2911442022-04-22 09:08:21 +0100146 for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100147 rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
alexander3c798932021-03-26 21:42:19 +0000148 }
149 }
150
151 if (lastIteration) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100152 this->m_countIterations = 0;
alexander3c798932021-03-26 21:42:19 +0000153 } else {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100154 ++this->m_countIterations;
alexander3c798932021-03-26 21:42:19 +0000155 }
156
157 return true;
158 }
159
Richard Burtonb40ecf82022-04-22 16:14:57 +0100160 uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model)
alexander3c798932021-03-26 21:42:19 +0000161 {
Richard Burtonc2911442022-04-22 09:08:21 +0100162 TfLiteTensor* inputTensor = model.GetInputTensor(0);
163 const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
164 if (inputRows == 0) {
165 printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
166 Wav2LetterModel::ms_inputRowsIdx);
167 }
168 return inputRows;
alexander3c798932021-03-26 21:42:19 +0000169 }
170
Richard Burtonb40ecf82022-04-22 16:14:57 +0100171 uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
Richard Burtonc2911442022-04-22 09:08:21 +0100172 {
173 const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
174 if (outputRows == 0) {
175 printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
176 Wav2LetterModel::ms_outputRowsIdx);
177 }
Richard Burtonb40ecf82022-04-22 16:14:57 +0100178
179 /* Watching for underflow. */
Richard Burtonc2911442022-04-22 09:08:21 +0100180 int innerLen = (outputRows - (2 * outputCtxLen));
181
182 return std::max(innerLen, 0);
183 }
184
Richard Burtonb40ecf82022-04-22 16:14:57 +0100185 uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
Richard Burtonc2911442022-04-22 09:08:21 +0100186 {
Richard Burtonb40ecf82022-04-22 16:14:57 +0100187 const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model);
Richard Burtonc2911442022-04-22 09:08:21 +0100188 const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
189 constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
190
191 /* Check to make sure that the input tensor supports the above
192 * context and inner lengths. */
193 if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
194 printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
195 inputCtxLen);
196 return 0;
197 }
198
199 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
200 const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
201 if (outputRows == 0) {
202 printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
203 Wav2LetterModel::ms_outputRowsIdx);
204 return 0;
205 }
206
207 const float inOutRowRatio = static_cast<float>(inputRows) /
208 static_cast<float>(outputRows);
209
210 return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio);
211 }
212
alexander3c798932021-03-26 21:42:19 +0000213} /* namespace app */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100214} /* namespace arm */