blob: b17396825612b478bf186d03e67bfe97536229a8 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Wav2LetterPostprocess.hpp"
18
19#include "Wav2LetterModel.hpp"
20
21namespace arm {
22namespace app {
23namespace audio {
24namespace asr {
25
26 Postprocess::Postprocess(const uint32_t contextLen,
27 const uint32_t innerLen,
28 const uint32_t blankTokenIdx)
29 : _m_contextLen(contextLen),
30 _m_innerLen(innerLen),
31 _m_totalLen(2 * this->_m_contextLen + this->_m_innerLen),
32 _m_countIterations(0),
33 _m_blankTokenIdx(blankTokenIdx)
34 {}
35
36 bool Postprocess::Invoke(TfLiteTensor* tensor,
37 const uint32_t axisIdx,
38 const bool lastIteration)
39 {
40 /* Basic checks. */
41 if (!this->_IsInputValid(tensor, axisIdx)) {
42 return false;
43 }
44
45 /* Irrespective of tensor type, we use unsigned "byte" */
46 uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
47 const uint32_t elemSz = this->_GetTensorElementSize(tensor);
48
49 /* Other sanity checks. */
50 if (0 == elemSz) {
51 printf_err("Tensor type not supported for post processing\n");
52 return false;
53 } else if (elemSz * this->_m_totalLen > tensor->bytes) {
54 printf_err("Insufficient number of tensor bytes\n");
55 return false;
56 }
57
58 /* Which axis do we need to process? */
59 switch (axisIdx) {
60 case arm::app::Wav2LetterModel::ms_outputRowsIdx:
61 return this->_EraseSectionsRowWise(ptrData,
62 elemSz * tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
63 lastIteration);
64 default:
65 printf_err("Unsupported axis index: %u\n", axisIdx);
66 }
67
68 return false;
69 }
70
71 bool Postprocess::_IsInputValid(TfLiteTensor* tensor,
72 const uint32_t axisIdx) const
73 {
74 if (nullptr == tensor) {
75 return false;
76 }
77
78 if (static_cast<int>(axisIdx) >= tensor->dims->size) {
79 printf_err("Invalid axis index: %u; Max: %d\n",
80 axisIdx, tensor->dims->size);
81 return false;
82 }
83
84 if (static_cast<int>(this->_m_totalLen) !=
85 tensor->dims->data[axisIdx]) {
86 printf_err("Unexpected tensor dimension for axis %d, \n",
87 tensor->dims->data[axisIdx]);
88 return false;
89 }
90
91 return true;
92 }
93
94 uint32_t Postprocess::_GetTensorElementSize(TfLiteTensor* tensor)
95 {
96 switch(tensor->type) {
97 case kTfLiteUInt8:
98 return 1;
99 case kTfLiteInt8:
100 return 1;
101 case kTfLiteInt16:
102 return 2;
103 case kTfLiteInt32:
104 return 4;
105 case kTfLiteFloat32:
106 return 4;
107 default:
108 printf_err("Unsupported tensor type %s\n",
109 TfLiteTypeGetName(tensor->type));
110 }
111
112 return 0;
113 }
114
115 bool Postprocess::_EraseSectionsRowWise(
116 uint8_t* ptrData,
117 const uint32_t strideSzBytes,
118 const bool lastIteration)
119 {
120 /* In this case, the "zero-ing" is quite simple as the region
121 * to be zeroed sits in contiguous memory (row-major). */
122 const uint32_t eraseLen = strideSzBytes * this->_m_contextLen;
123
124 /* Erase left context? */
125 if (this->_m_countIterations > 0) {
126 /* Set output of each classification window to the blank token. */
127 std::memset(ptrData, 0, eraseLen);
128 for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) {
129 ptrData[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1;
130 }
131 }
132
133 /* Erase right context? */
134 if (false == lastIteration) {
135 uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->_m_contextLen + this->_m_innerLen));
136 /* Set output of each classification window to the blank token. */
137 std::memset(rightCtxPtr, 0, eraseLen);
138 for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) {
139 rightCtxPtr[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1;
140 }
141 }
142
143 if (lastIteration) {
144 this->_m_countIterations = 0;
145 } else {
146 ++this->_m_countIterations;
147 }
148
149 return true;
150 }
151
152} /* namespace asr */
153} /* namespace audio */
154} /* namespace app */
155} /* namespace arm */