blob: 9157a6f4db8bc9b2f24140712d2eaaf8268557e6 [file] [log] [blame]
/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "Wav2LetterPostprocess.hpp"
18
19#include "Wav2LetterModel.hpp"
20
21
22namespace arm {
23namespace app {
24namespace audio {
25namespace asr {
26
27 Postprocess::Postprocess(const uint32_t contextLen,
28 const uint32_t innerLen,
29 const uint32_t blankTokenIdx)
30 : _m_contextLen(contextLen),
31 _m_innerLen(innerLen),
32 _m_totalLen(2 * this->_m_contextLen + this->_m_innerLen),
33 _m_countIterations(0),
34 _m_blankTokenIdx(blankTokenIdx)
35 {}
36
37 bool Postprocess::Invoke(TfLiteTensor* tensor,
38 const uint32_t axisIdx,
39 const bool lastIteration)
40 {
41 /* Basic checks. */
alexanderc350cdc2021-04-29 20:36:09 +010042 if (!this->IsInputValid(tensor, axisIdx)) {
alexander3c798932021-03-26 21:42:19 +000043 return false;
44 }
45
46 /* Irrespective of tensor type, we use unsigned "byte" */
47 uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
alexanderc350cdc2021-04-29 20:36:09 +010048 const uint32_t elemSz = this->GetTensorElementSize(tensor);
alexander3c798932021-03-26 21:42:19 +000049
50 /* Other sanity checks. */
51 if (0 == elemSz) {
52 printf_err("Tensor type not supported for post processing\n");
53 return false;
54 } else if (elemSz * this->_m_totalLen > tensor->bytes) {
55 printf_err("Insufficient number of tensor bytes\n");
56 return false;
57 }
58
59 /* Which axis do we need to process? */
60 switch (axisIdx) {
61 case arm::app::Wav2LetterModel::ms_outputRowsIdx:
alexanderc350cdc2021-04-29 20:36:09 +010062 return this->EraseSectionsRowWise(ptrData,
63 elemSz *
64 tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
65 lastIteration);
alexander3c798932021-03-26 21:42:19 +000066 case arm::app::Wav2LetterModel::ms_outputColsIdx:
alexanderc350cdc2021-04-29 20:36:09 +010067 return this->EraseSectionsColWise(ptrData,
68 elemSz *
69 tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx],
70 lastIteration);
alexander3c798932021-03-26 21:42:19 +000071 default:
72 printf_err("Unsupported axis index: %u\n", axisIdx);
73 }
74
75 return false;
76 }
77
alexanderc350cdc2021-04-29 20:36:09 +010078 bool Postprocess::IsInputValid(TfLiteTensor* tensor,
79 const uint32_t axisIdx) const
alexander3c798932021-03-26 21:42:19 +000080 {
81 if (nullptr == tensor) {
82 return false;
83 }
84
85 if (static_cast<int>(axisIdx) >= tensor->dims->size) {
86 printf_err("Invalid axis index: %u; Max: %d\n",
87 axisIdx, tensor->dims->size);
88 return false;
89 }
90
91 if (static_cast<int>(this->_m_totalLen) !=
92 tensor->dims->data[axisIdx]) {
93 printf_err("Unexpected tensor dimension for axis %d, \n",
94 tensor->dims->data[axisIdx]);
95 return false;
96 }
97
98 return true;
99 }
100
alexanderc350cdc2021-04-29 20:36:09 +0100101 uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
alexander3c798932021-03-26 21:42:19 +0000102 {
103 switch(tensor->type) {
104 case kTfLiteUInt8:
alexander3c798932021-03-26 21:42:19 +0000105 case kTfLiteInt8:
106 return 1;
107 case kTfLiteInt16:
108 return 2;
109 case kTfLiteInt32:
alexander3c798932021-03-26 21:42:19 +0000110 case kTfLiteFloat32:
111 return 4;
112 default:
113 printf_err("Unsupported tensor type %s\n",
114 TfLiteTypeGetName(tensor->type));
115 }
116
117 return 0;
118 }
119
alexanderc350cdc2021-04-29 20:36:09 +0100120 bool Postprocess::EraseSectionsRowWise(
alexander3c798932021-03-26 21:42:19 +0000121 uint8_t* ptrData,
122 const uint32_t strideSzBytes,
123 const bool lastIteration)
124 {
125 /* In this case, the "zero-ing" is quite simple as the region
126 * to be zeroed sits in contiguous memory (row-major). */
127 const uint32_t eraseLen = strideSzBytes * this->_m_contextLen;
128
129 /* Erase left context? */
130 if (this->_m_countIterations > 0) {
131 /* Set output of each classification window to the blank token. */
132 std::memset(ptrData, 0, eraseLen);
133 for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) {
134 ptrData[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1;
135 }
136 }
137
138 /* Erase right context? */
139 if (false == lastIteration) {
140 uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->_m_contextLen + this->_m_innerLen));
141 /* Set output of each classification window to the blank token. */
142 std::memset(rightCtxPtr, 0, eraseLen);
143 for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) {
144 rightCtxPtr[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1;
145 }
146 }
147
148 if (lastIteration) {
149 this->_m_countIterations = 0;
150 } else {
151 ++this->_m_countIterations;
152 }
153
154 return true;
155 }
156
alexanderc350cdc2021-04-29 20:36:09 +0100157 bool Postprocess::EraseSectionsColWise(
158 const uint8_t* ptrData,
alexander3c798932021-03-26 21:42:19 +0000159 const uint32_t strideSzBytes,
160 const bool lastIteration)
161 {
162 /* Not implemented. */
163 UNUSED(ptrData);
164 UNUSED(strideSzBytes);
165 UNUSED(lastIteration);
166 return false;
167 }
168
169} /* namespace asr */
170} /* namespace audio */
171} /* namespace app */
172} /* namespace arm */