blob: 039206136531972c1e19674f00693ba0af0f6b0b [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Wav2LetterPostprocess.hpp"
alexander3c798932021-03-26 21:42:19 +000018#include "Wav2LetterModel.hpp"
alexander31ae9f02022-02-10 16:15:54 +000019#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000020
21namespace arm {
22namespace app {
23namespace audio {
24namespace asr {
25
26 Postprocess::Postprocess(const uint32_t contextLen,
27 const uint32_t innerLen,
28 const uint32_t blankTokenIdx)
Isabella Gottardi56ee6202021-05-12 08:27:15 +010029 : m_contextLen(contextLen),
30 m_innerLen(innerLen),
31 m_totalLen(2 * this->m_contextLen + this->m_innerLen),
32 m_countIterations(0),
33 m_blankTokenIdx(blankTokenIdx)
alexander3c798932021-03-26 21:42:19 +000034 {}
35
36 bool Postprocess::Invoke(TfLiteTensor* tensor,
37 const uint32_t axisIdx,
38 const bool lastIteration)
39 {
40 /* Basic checks. */
alexanderc350cdc2021-04-29 20:36:09 +010041 if (!this->IsInputValid(tensor, axisIdx)) {
alexander3c798932021-03-26 21:42:19 +000042 return false;
43 }
44
45 /* Irrespective of tensor type, we use unsigned "byte" */
46 uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
alexanderc350cdc2021-04-29 20:36:09 +010047 const uint32_t elemSz = this->GetTensorElementSize(tensor);
alexander3c798932021-03-26 21:42:19 +000048
49 /* Other sanity checks. */
50 if (0 == elemSz) {
51 printf_err("Tensor type not supported for post processing\n");
52 return false;
Isabella Gottardi56ee6202021-05-12 08:27:15 +010053 } else if (elemSz * this->m_totalLen > tensor->bytes) {
alexander3c798932021-03-26 21:42:19 +000054 printf_err("Insufficient number of tensor bytes\n");
55 return false;
56 }
57
58 /* Which axis do we need to process? */
59 switch (axisIdx) {
60 case arm::app::Wav2LetterModel::ms_outputRowsIdx:
alexanderc350cdc2021-04-29 20:36:09 +010061 return this->EraseSectionsRowWise(ptrData,
62 elemSz *
63 tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
64 lastIteration);
alexander3c798932021-03-26 21:42:19 +000065 case arm::app::Wav2LetterModel::ms_outputColsIdx:
alexanderc350cdc2021-04-29 20:36:09 +010066 return this->EraseSectionsColWise(ptrData,
67 elemSz *
68 tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx],
69 lastIteration);
alexander3c798932021-03-26 21:42:19 +000070 default:
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010071 printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
alexander3c798932021-03-26 21:42:19 +000072 }
73
74 return false;
75 }
76
alexanderc350cdc2021-04-29 20:36:09 +010077 bool Postprocess::IsInputValid(TfLiteTensor* tensor,
78 const uint32_t axisIdx) const
alexander3c798932021-03-26 21:42:19 +000079 {
80 if (nullptr == tensor) {
81 return false;
82 }
83
84 if (static_cast<int>(axisIdx) >= tensor->dims->size) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010085 printf_err("Invalid axis index: %" PRIu32 "; Max: %d\n",
alexander3c798932021-03-26 21:42:19 +000086 axisIdx, tensor->dims->size);
87 return false;
88 }
89
Isabella Gottardi56ee6202021-05-12 08:27:15 +010090 if (static_cast<int>(this->m_totalLen) !=
alexander3c798932021-03-26 21:42:19 +000091 tensor->dims->data[axisIdx]) {
92 printf_err("Unexpected tensor dimension for axis %d, \n",
93 tensor->dims->data[axisIdx]);
94 return false;
95 }
96
97 return true;
98 }
99
alexanderc350cdc2021-04-29 20:36:09 +0100100 uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
alexander3c798932021-03-26 21:42:19 +0000101 {
102 switch(tensor->type) {
103 case kTfLiteUInt8:
alexander3c798932021-03-26 21:42:19 +0000104 case kTfLiteInt8:
105 return 1;
106 case kTfLiteInt16:
107 return 2;
108 case kTfLiteInt32:
alexander3c798932021-03-26 21:42:19 +0000109 case kTfLiteFloat32:
110 return 4;
111 default:
112 printf_err("Unsupported tensor type %s\n",
113 TfLiteTypeGetName(tensor->type));
114 }
115
116 return 0;
117 }
118
alexanderc350cdc2021-04-29 20:36:09 +0100119 bool Postprocess::EraseSectionsRowWise(
alexander3c798932021-03-26 21:42:19 +0000120 uint8_t* ptrData,
121 const uint32_t strideSzBytes,
122 const bool lastIteration)
123 {
124 /* In this case, the "zero-ing" is quite simple as the region
125 * to be zeroed sits in contiguous memory (row-major). */
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100126 const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
alexander3c798932021-03-26 21:42:19 +0000127
128 /* Erase left context? */
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100129 if (this->m_countIterations > 0) {
alexander3c798932021-03-26 21:42:19 +0000130 /* Set output of each classification window to the blank token. */
131 std::memset(ptrData, 0, eraseLen);
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100132 for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
133 ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
alexander3c798932021-03-26 21:42:19 +0000134 }
135 }
136
137 /* Erase right context? */
138 if (false == lastIteration) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100139 uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
alexander3c798932021-03-26 21:42:19 +0000140 /* Set output of each classification window to the blank token. */
141 std::memset(rightCtxPtr, 0, eraseLen);
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100142 for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
143 rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
alexander3c798932021-03-26 21:42:19 +0000144 }
145 }
146
147 if (lastIteration) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100148 this->m_countIterations = 0;
alexander3c798932021-03-26 21:42:19 +0000149 } else {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100150 ++this->m_countIterations;
alexander3c798932021-03-26 21:42:19 +0000151 }
152
153 return true;
154 }
155
alexanderc350cdc2021-04-29 20:36:09 +0100156 bool Postprocess::EraseSectionsColWise(
157 const uint8_t* ptrData,
alexander3c798932021-03-26 21:42:19 +0000158 const uint32_t strideSzBytes,
159 const bool lastIteration)
160 {
161 /* Not implemented. */
162 UNUSED(ptrData);
163 UNUSED(strideSzBytes);
164 UNUSED(lastIteration);
165 return false;
166 }
167
168} /* namespace asr */
169} /* namespace audio */
170} /* namespace app */
171} /* namespace arm */