blob: 60ee51e0eb2462ecf094c754c85241506aa94592 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Wav2LetterPostprocess.hpp"
18
19#include "Wav2LetterModel.hpp"
20
21
22namespace arm {
23namespace app {
24namespace audio {
25namespace asr {
26
27 Postprocess::Postprocess(const uint32_t contextLen,
28 const uint32_t innerLen,
29 const uint32_t blankTokenIdx)
30 : _m_contextLen(contextLen),
31 _m_innerLen(innerLen),
32 _m_totalLen(2 * this->_m_contextLen + this->_m_innerLen),
33 _m_countIterations(0),
34 _m_blankTokenIdx(blankTokenIdx)
35 {}
36
37 bool Postprocess::Invoke(TfLiteTensor* tensor,
38 const uint32_t axisIdx,
39 const bool lastIteration)
40 {
41 /* Basic checks. */
42 if (!this->_IsInputValid(tensor, axisIdx)) {
43 return false;
44 }
45
46 /* Irrespective of tensor type, we use unsigned "byte" */
47 uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
48 const uint32_t elemSz = this->_GetTensorElementSize(tensor);
49
50 /* Other sanity checks. */
51 if (0 == elemSz) {
52 printf_err("Tensor type not supported for post processing\n");
53 return false;
54 } else if (elemSz * this->_m_totalLen > tensor->bytes) {
55 printf_err("Insufficient number of tensor bytes\n");
56 return false;
57 }
58
59 /* Which axis do we need to process? */
60 switch (axisIdx) {
61 case arm::app::Wav2LetterModel::ms_outputRowsIdx:
62 return this->_EraseSectionsRowWise(ptrData,
63 elemSz * tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
64 lastIteration);
65 case arm::app::Wav2LetterModel::ms_outputColsIdx:
66 return this->_EraseSectionsColWise(ptrData,
67 elemSz * tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx],
68 lastIteration);
69 default:
70 printf_err("Unsupported axis index: %u\n", axisIdx);
71 }
72
73 return false;
74 }
75
76 bool Postprocess::_IsInputValid(TfLiteTensor* tensor,
77 const uint32_t axisIdx) const
78 {
79 if (nullptr == tensor) {
80 return false;
81 }
82
83 if (static_cast<int>(axisIdx) >= tensor->dims->size) {
84 printf_err("Invalid axis index: %u; Max: %d\n",
85 axisIdx, tensor->dims->size);
86 return false;
87 }
88
89 if (static_cast<int>(this->_m_totalLen) !=
90 tensor->dims->data[axisIdx]) {
91 printf_err("Unexpected tensor dimension for axis %d, \n",
92 tensor->dims->data[axisIdx]);
93 return false;
94 }
95
96 return true;
97 }
98
99 uint32_t Postprocess::_GetTensorElementSize(TfLiteTensor* tensor)
100 {
101 switch(tensor->type) {
102 case kTfLiteUInt8:
103 return 1;
104 case kTfLiteInt8:
105 return 1;
106 case kTfLiteInt16:
107 return 2;
108 case kTfLiteInt32:
109 return 4;
110 case kTfLiteFloat32:
111 return 4;
112 default:
113 printf_err("Unsupported tensor type %s\n",
114 TfLiteTypeGetName(tensor->type));
115 }
116
117 return 0;
118 }
119
120 bool Postprocess::_EraseSectionsRowWise(
121 uint8_t* ptrData,
122 const uint32_t strideSzBytes,
123 const bool lastIteration)
124 {
125 /* In this case, the "zero-ing" is quite simple as the region
126 * to be zeroed sits in contiguous memory (row-major). */
127 const uint32_t eraseLen = strideSzBytes * this->_m_contextLen;
128
129 /* Erase left context? */
130 if (this->_m_countIterations > 0) {
131 /* Set output of each classification window to the blank token. */
132 std::memset(ptrData, 0, eraseLen);
133 for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) {
134 ptrData[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1;
135 }
136 }
137
138 /* Erase right context? */
139 if (false == lastIteration) {
140 uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->_m_contextLen + this->_m_innerLen));
141 /* Set output of each classification window to the blank token. */
142 std::memset(rightCtxPtr, 0, eraseLen);
143 for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) {
144 rightCtxPtr[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1;
145 }
146 }
147
148 if (lastIteration) {
149 this->_m_countIterations = 0;
150 } else {
151 ++this->_m_countIterations;
152 }
153
154 return true;
155 }
156
157 bool Postprocess::_EraseSectionsColWise(
158 uint8_t* ptrData,
159 const uint32_t strideSzBytes,
160 const bool lastIteration)
161 {
162 /* Not implemented. */
163 UNUSED(ptrData);
164 UNUSED(strideSzBytes);
165 UNUSED(lastIteration);
166 return false;
167 }
168
169} /* namespace asr */
170} /* namespace audio */
171} /* namespace app */
172} /* namespace arm */