blob: d19e0c4ff8e55777526cba483281506ed6bfb465 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Conor Kennedy5cf8e742023-02-13 10:50:40 +00002 * SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Wav2LetterPostprocess.hpp"
18#include "Wav2LetterModel.hpp"
Richard Burton4e002792022-05-04 09:45:02 +010019#include "ClassificationResult.hpp"
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010020#include "BufAttributes.hpp"
alexander3c798932021-03-26 21:42:19 +000021
22#include <algorithm>
23#include <catch.hpp>
24#include <limits>
25
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010026namespace arm {
27 namespace app {
28 static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
29
30 namespace asr {
31 extern uint8_t* GetModelPointer();
32 extern size_t GetModelLen();
33 }
34 namespace kws {
35 extern uint8_t* GetModelPointer();
36 extern size_t GetModelLen();
37 }
38 } /* namespace app */
39} /* namespace arm */
40
alexander3c798932021-03-26 21:42:19 +000041template <typename T>
Richard Burton4e002792022-05-04 09:45:02 +010042static TfLiteTensor GetTestTensor(
43 std::vector<int>& shape,
44 T initVal,
45 std::vector<T>& vectorBuf)
alexander3c798932021-03-26 21:42:19 +000046{
47 REQUIRE(0 != shape.size());
48
49 shape.insert(shape.begin(), shape.size());
50 uint32_t sizeInBytes = sizeof(T);
51 for (size_t i = 1; i < shape.size(); ++i) {
52 sizeInBytes *= shape[i];
53 }
54
55 /* Allocate mem. */
56 vectorBuf = std::vector<T>(sizeInBytes, initVal);
57 TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data());
58 return tflite::testing::CreateQuantizedTensor(
Richard Burton4e002792022-05-04 09:45:02 +010059 vectorBuf.data(), dims,
60 1, 0, "test-tensor");
alexander3c798932021-03-26 21:42:19 +000061}
62
63TEST_CASE("Checking return value")
64{
65 SECTION("Mismatched post processing parameters and tensor size")
66 {
Richard Burton4e002792022-05-04 09:45:02 +010067 const uint32_t outputCtxLen = 5;
68 arm::app::AsrClassifier classifier;
69 arm::app::Wav2LetterModel model;
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010070 model.Init(arm::app::tensorArena,
71 sizeof(arm::app::tensorArena),
72 arm::app::asr::GetModelPointer(),
73 arm::app::asr::GetModelLen());
Conor Kennedy5cf8e742023-02-13 10:50:40 +000074 std::vector<std::string> placeholderLabels = {"a", "b", "$"};
Richard Burton4e002792022-05-04 09:45:02 +010075 const uint32_t blankTokenIdx = 2;
Conor Kennedy5cf8e742023-02-13 10:50:40 +000076 std::vector<arm::app::ClassificationResult> placeholderResult;
alexander3c798932021-03-26 21:42:19 +000077 std::vector <int> tensorShape = {1, 1, 1, 13};
78 std::vector <int8_t> tensorVec;
79 TfLiteTensor tensor = GetTestTensor<int8_t>(
Richard Burton4e002792022-05-04 09:45:02 +010080 tensorShape, 100, tensorVec);
81
Conor Kennedy5cf8e742023-02-13 10:50:40 +000082 arm::app::AsrPostProcess post{&tensor, classifier, placeholderLabels, placeholderResult, outputCtxLen,
Richard Burton4e002792022-05-04 09:45:02 +010083 blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
84
85 REQUIRE(!post.DoPostProcess());
alexander3c798932021-03-26 21:42:19 +000086 }
87
88 SECTION("Post processing succeeds")
89 {
Richard Burton4e002792022-05-04 09:45:02 +010090 const uint32_t outputCtxLen = 5;
91 arm::app::AsrClassifier classifier;
92 arm::app::Wav2LetterModel model;
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010093 model.Init(arm::app::tensorArena,
94 sizeof(arm::app::tensorArena),
95 arm::app::asr::GetModelPointer(),
96 arm::app::asr::GetModelLen());
Conor Kennedy5cf8e742023-02-13 10:50:40 +000097 std::vector<std::string> placeholderLabels = {"a", "b", "$"};
Richard Burton4e002792022-05-04 09:45:02 +010098 const uint32_t blankTokenIdx = 2;
Conor Kennedy5cf8e742023-02-13 10:50:40 +000099 std::vector<arm::app::ClassificationResult> placeholderResult;
Richard Burton4e002792022-05-04 09:45:02 +0100100 std::vector<int> tensorShape = {1, 1, 13, 1};
101 std::vector<int8_t> tensorVec;
alexander3c798932021-03-26 21:42:19 +0000102 TfLiteTensor tensor = GetTestTensor<int8_t>(
Richard Burton4e002792022-05-04 09:45:02 +0100103 tensorShape, 100, tensorVec);
104
Conor Kennedy5cf8e742023-02-13 10:50:40 +0000105 arm::app::AsrPostProcess post{&tensor, classifier, placeholderLabels, placeholderResult, outputCtxLen,
Richard Burton4e002792022-05-04 09:45:02 +0100106 blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
alexander3c798932021-03-26 21:42:19 +0000107
108 /* Copy elements to compare later. */
Richard Burton4e002792022-05-04 09:45:02 +0100109 std::vector<int8_t> originalVec = tensorVec;
alexander3c798932021-03-26 21:42:19 +0000110
111 /* This step should not erase anything. */
Richard Burton4e002792022-05-04 09:45:02 +0100112 REQUIRE(post.DoPostProcess());
alexander3c798932021-03-26 21:42:19 +0000113 }
114}
115
Richard Burton4e002792022-05-04 09:45:02 +0100116
alexander3c798932021-03-26 21:42:19 +0000117TEST_CASE("Postprocessing - erasing required elements")
118{
Richard Burton4e002792022-05-04 09:45:02 +0100119 constexpr uint32_t outputCtxLen = 5;
alexander3c798932021-03-26 21:42:19 +0000120 constexpr uint32_t innerLen = 3;
Richard Burton4e002792022-05-04 09:45:02 +0100121 constexpr uint32_t nRows = 2*outputCtxLen + innerLen;
alexander3c798932021-03-26 21:42:19 +0000122 constexpr uint32_t nCols = 10;
123 constexpr uint32_t blankTokenIdx = nCols - 1;
Richard Burton4e002792022-05-04 09:45:02 +0100124 std::vector<int> tensorShape = {1, 1, nRows, nCols};
125 arm::app::AsrClassifier classifier;
126 arm::app::Wav2LetterModel model;
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100127 model.Init(arm::app::tensorArena,
128 sizeof(arm::app::tensorArena),
129 arm::app::asr::GetModelPointer(),
130 arm::app::asr::GetModelLen());
Conor Kennedy5cf8e742023-02-13 10:50:40 +0000131 std::vector<std::string> placeholderLabels = {"a", "b", "$"};
132 std::vector<arm::app::ClassificationResult> placeholderResult;
alexander3c798932021-03-26 21:42:19 +0000133
134 SECTION("First and last iteration")
135 {
Richard Burton4e002792022-05-04 09:45:02 +0100136 std::vector<int8_t> tensorVec;
137 TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
Conor Kennedy5cf8e742023-02-13 10:50:40 +0000138 arm::app::AsrPostProcess post{&tensor, classifier, placeholderLabels, placeholderResult, outputCtxLen,
Richard Burton4e002792022-05-04 09:45:02 +0100139 blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
alexander3c798932021-03-26 21:42:19 +0000140
141 /* Copy elements to compare later. */
Richard Burton4e002792022-05-04 09:45:02 +0100142 std::vector<int8_t>originalVec = tensorVec;
alexander3c798932021-03-26 21:42:19 +0000143
144 /* This step should not erase anything. */
Richard Burton4e002792022-05-04 09:45:02 +0100145 post.m_lastIteration = true;
146 REQUIRE(post.DoPostProcess());
alexander3c798932021-03-26 21:42:19 +0000147 REQUIRE(originalVec == tensorVec);
148 }
149
150 SECTION("Right context erase")
151 {
alexander3c798932021-03-26 21:42:19 +0000152 std::vector <int8_t> tensorVec;
153 TfLiteTensor tensor = GetTestTensor<int8_t>(
Richard Burton4e002792022-05-04 09:45:02 +0100154 tensorShape, 100, tensorVec);
Conor Kennedy5cf8e742023-02-13 10:50:40 +0000155 arm::app::AsrPostProcess post{&tensor, classifier, placeholderLabels, placeholderResult, outputCtxLen,
Richard Burton4e002792022-05-04 09:45:02 +0100156 blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
alexander3c798932021-03-26 21:42:19 +0000157
158 /* Copy elements to compare later. */
Richard Burton4e002792022-05-04 09:45:02 +0100159 std::vector<int8_t> originalVec = tensorVec;
alexander3c798932021-03-26 21:42:19 +0000160
161 /* This step should erase the right context only. */
Richard Burton4e002792022-05-04 09:45:02 +0100162 post.m_lastIteration = false;
163 REQUIRE(post.DoPostProcess());
alexander3c798932021-03-26 21:42:19 +0000164 REQUIRE(originalVec != tensorVec);
165
166 /* The last ctxLen * 10 elements should be gone. */
Richard Burton4e002792022-05-04 09:45:02 +0100167 for (size_t i = 0; i < outputCtxLen; ++i) {
alexander3c798932021-03-26 21:42:19 +0000168 for (size_t j = 0; j < nCols; ++j) {
Richard Burton4e002792022-05-04 09:45:02 +0100169 /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */
alexander3c798932021-03-26 21:42:19 +0000170 if (j == blankTokenIdx) {
Richard Burton4e002792022-05-04 09:45:02 +0100171 CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
alexander3c798932021-03-26 21:42:19 +0000172 } else {
Richard Burton4e002792022-05-04 09:45:02 +0100173 CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
alexander3c798932021-03-26 21:42:19 +0000174 }
175
176 /* Check left context is preserved. */
177 CHECK(tensorVec[i*nCols + j] == originalVec[i*nCols + j]);
178 }
179 }
180
181 /* Check inner elements are preserved. */
Richard Burton4e002792022-05-04 09:45:02 +0100182 for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
alexander3c798932021-03-26 21:42:19 +0000183 CHECK(tensorVec[i] == originalVec[i]);
184 }
185 }
186
187 SECTION("Left and right context erase")
188 {
alexander3c798932021-03-26 21:42:19 +0000189 std::vector <int8_t> tensorVec;
Richard Burton4e002792022-05-04 09:45:02 +0100190 TfLiteTensor tensor = GetTestTensor<int8_t>(
191 tensorShape, 100, tensorVec);
Conor Kennedy5cf8e742023-02-13 10:50:40 +0000192 arm::app::AsrPostProcess post{&tensor, classifier, placeholderLabels, placeholderResult, outputCtxLen,
Richard Burton4e002792022-05-04 09:45:02 +0100193 blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
alexander3c798932021-03-26 21:42:19 +0000194
195 /* Copy elements to compare later. */
196 std::vector <int8_t> originalVec = tensorVec;
197
198 /* This step should erase right context. */
Richard Burton4e002792022-05-04 09:45:02 +0100199 post.m_lastIteration = false;
200 REQUIRE(post.DoPostProcess());
alexander3c798932021-03-26 21:42:19 +0000201
202 /* Calling it the second time should erase the left context. */
Richard Burton4e002792022-05-04 09:45:02 +0100203 REQUIRE(post.DoPostProcess());
alexander3c798932021-03-26 21:42:19 +0000204
205 REQUIRE(originalVec != tensorVec);
206
207 /* The first and last ctxLen * 10 elements should be gone. */
Richard Burton4e002792022-05-04 09:45:02 +0100208 for (size_t i = 0; i < outputCtxLen; ++i) {
alexander3c798932021-03-26 21:42:19 +0000209 for (size_t j = 0; j < nCols; ++j) {
210 /* Check left and right context elements are zeroed. */
211 if (j == blankTokenIdx) {
Richard Burton4e002792022-05-04 09:45:02 +0100212 CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
213 CHECK(tensorVec[i*nCols + j] == 1);
alexander3c798932021-03-26 21:42:19 +0000214 } else {
Richard Burton4e002792022-05-04 09:45:02 +0100215 CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
216 CHECK(tensorVec[i*nCols + j] == 0);
alexander3c798932021-03-26 21:42:19 +0000217 }
218 }
219 }
220
221 /* Check inner elements are preserved. */
Richard Burton4e002792022-05-04 09:45:02 +0100222 for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
alexander3c798932021-03-26 21:42:19 +0000223 /* Check left context is preserved. */
224 CHECK(tensorVec[i] == originalVec[i]);
225 }
226 }
227
228 SECTION("Try left context erase")
229 {
alexander3c798932021-03-26 21:42:19 +0000230 std::vector <int8_t> tensorVec;
231 TfLiteTensor tensor = GetTestTensor<int8_t>(
Richard Burton4e002792022-05-04 09:45:02 +0100232 tensorShape, 100, tensorVec);
233
234 /* Should not be able to erase the left context if it is the first iteration. */
Conor Kennedy5cf8e742023-02-13 10:50:40 +0000235 arm::app::AsrPostProcess post{&tensor, classifier, placeholderLabels, placeholderResult, outputCtxLen,
Richard Burton4e002792022-05-04 09:45:02 +0100236 blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
alexander3c798932021-03-26 21:42:19 +0000237
238 /* Copy elements to compare later. */
239 std::vector <int8_t> originalVec = tensorVec;
240
241 /* Calling it the second time should erase the left context. */
Richard Burton4e002792022-05-04 09:45:02 +0100242 post.m_lastIteration = true;
243 REQUIRE(post.DoPostProcess());
244
alexander3c798932021-03-26 21:42:19 +0000245 REQUIRE(originalVec == tensorVec);
246 }
Richard Burton4e002792022-05-04 09:45:02 +0100247}