blob: f2d9357d1fe1f73c482d21c75de795d13af5c7e3 [file] [log] [blame]
/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "Wav2LetterPostprocess.hpp"
18
19#include "Wav2LetterModel.hpp"
20
21namespace arm {
22namespace app {
23namespace audio {
24namespace asr {
25
26 Postprocess::Postprocess(const uint32_t contextLen,
27 const uint32_t innerLen,
28 const uint32_t blankTokenIdx)
Isabella Gottardi56ee6202021-05-12 08:27:15 +010029 : m_contextLen(contextLen),
30 m_innerLen(innerLen),
31 m_totalLen(2 * this->m_contextLen + this->m_innerLen),
32 m_countIterations(0),
33 m_blankTokenIdx(blankTokenIdx)
alexander3c798932021-03-26 21:42:19 +000034 {}
35
36 bool Postprocess::Invoke(TfLiteTensor* tensor,
37 const uint32_t axisIdx,
38 const bool lastIteration)
39 {
40 /* Basic checks. */
alexanderc350cdc2021-04-29 20:36:09 +010041 if (!this->IsInputValid(tensor, axisIdx)) {
alexander3c798932021-03-26 21:42:19 +000042 return false;
43 }
44
45 /* Irrespective of tensor type, we use unsigned "byte" */
46 uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
alexanderc350cdc2021-04-29 20:36:09 +010047 const uint32_t elemSz = this->GetTensorElementSize(tensor);
alexander3c798932021-03-26 21:42:19 +000048
49 /* Other sanity checks. */
50 if (0 == elemSz) {
51 printf_err("Tensor type not supported for post processing\n");
52 return false;
Isabella Gottardi56ee6202021-05-12 08:27:15 +010053 } else if (elemSz * this->m_totalLen > tensor->bytes) {
alexander3c798932021-03-26 21:42:19 +000054 printf_err("Insufficient number of tensor bytes\n");
55 return false;
56 }
57
58 /* Which axis do we need to process? */
59 switch (axisIdx) {
60 case arm::app::Wav2LetterModel::ms_outputRowsIdx:
alexanderc350cdc2021-04-29 20:36:09 +010061 return this->EraseSectionsRowWise(ptrData,
62 elemSz *
63 tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
64 lastIteration);
alexander3c798932021-03-26 21:42:19 +000065 default:
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010066 printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
alexander3c798932021-03-26 21:42:19 +000067 }
68
69 return false;
70 }
71
alexanderc350cdc2021-04-29 20:36:09 +010072 bool Postprocess::IsInputValid(TfLiteTensor* tensor,
73 const uint32_t axisIdx) const
alexander3c798932021-03-26 21:42:19 +000074 {
75 if (nullptr == tensor) {
76 return false;
77 }
78
79 if (static_cast<int>(axisIdx) >= tensor->dims->size) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010080 printf_err("Invalid axis index: %" PRIu32 "; Max: %d\n",
alexander3c798932021-03-26 21:42:19 +000081 axisIdx, tensor->dims->size);
82 return false;
83 }
84
Isabella Gottardi56ee6202021-05-12 08:27:15 +010085 if (static_cast<int>(this->m_totalLen) !=
alexander3c798932021-03-26 21:42:19 +000086 tensor->dims->data[axisIdx]) {
87 printf_err("Unexpected tensor dimension for axis %d, \n",
88 tensor->dims->data[axisIdx]);
89 return false;
90 }
91
92 return true;
93 }
94
alexanderc350cdc2021-04-29 20:36:09 +010095 uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
alexander3c798932021-03-26 21:42:19 +000096 {
97 switch(tensor->type) {
98 case kTfLiteUInt8:
99 return 1;
100 case kTfLiteInt8:
101 return 1;
102 case kTfLiteInt16:
103 return 2;
104 case kTfLiteInt32:
105 return 4;
106 case kTfLiteFloat32:
107 return 4;
108 default:
109 printf_err("Unsupported tensor type %s\n",
110 TfLiteTypeGetName(tensor->type));
111 }
112
113 return 0;
114 }
115
alexanderc350cdc2021-04-29 20:36:09 +0100116 bool Postprocess::EraseSectionsRowWise(
alexander3c798932021-03-26 21:42:19 +0000117 uint8_t* ptrData,
118 const uint32_t strideSzBytes,
119 const bool lastIteration)
120 {
121 /* In this case, the "zero-ing" is quite simple as the region
122 * to be zeroed sits in contiguous memory (row-major). */
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100123 const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
alexander3c798932021-03-26 21:42:19 +0000124
125 /* Erase left context? */
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100126 if (this->m_countIterations > 0) {
alexander3c798932021-03-26 21:42:19 +0000127 /* Set output of each classification window to the blank token. */
128 std::memset(ptrData, 0, eraseLen);
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100129 for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
130 ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
alexander3c798932021-03-26 21:42:19 +0000131 }
132 }
133
134 /* Erase right context? */
135 if (false == lastIteration) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100136 uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
alexander3c798932021-03-26 21:42:19 +0000137 /* Set output of each classification window to the blank token. */
138 std::memset(rightCtxPtr, 0, eraseLen);
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100139 for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
140 rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
alexander3c798932021-03-26 21:42:19 +0000141 }
142 }
143
144 if (lastIteration) {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100145 this->m_countIterations = 0;
alexander3c798932021-03-26 21:42:19 +0000146 } else {
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100147 ++this->m_countIterations;
alexander3c798932021-03-26 21:42:19 +0000148 }
149
150 return true;
151 }
152
153} /* namespace asr */
154} /* namespace audio */
155} /* namespace app */
156} /* namespace arm */