blob: 9ed2e1be6735865b72f038debe912421d8d1116c [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Wav2LetterPostprocess.hpp"
18#include "Wav2LetterModel.hpp"
19
20#include <algorithm>
21#include <catch.hpp>
22#include <limits>
23
24template <typename T>
25static TfLiteTensor GetTestTensor(
26 std::vector <int>& shape,
27 T initVal,
28 std::vector<T>& vectorBuf)
29{
30 REQUIRE(0 != shape.size());
31
32 shape.insert(shape.begin(), shape.size());
33 uint32_t sizeInBytes = sizeof(T);
34 for (size_t i = 1; i < shape.size(); ++i) {
35 sizeInBytes *= shape[i];
36 }
37
38 /* Allocate mem. */
39 vectorBuf = std::vector<T>(sizeInBytes, initVal);
40 TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data());
41 return tflite::testing::CreateQuantizedTensor(
42 vectorBuf.data(), dims,
43 1, 0, "test-tensor");
44}
45
46TEST_CASE("Checking return value")
47{
48 SECTION("Mismatched post processing parameters and tensor size")
49 {
50 const uint32_t ctxLen = 5;
51 const uint32_t innerLen = 3;
52 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
53
54 std::vector <int> tensorShape = {1, 1, 1, 13};
55 std::vector <int8_t> tensorVec;
56 TfLiteTensor tensor = GetTestTensor<int8_t>(
57 tensorShape, 100, tensorVec);
58 REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
59 }
60
61 SECTION("Post processing succeeds")
62 {
63 const uint32_t ctxLen = 5;
64 const uint32_t innerLen = 3;
65 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
66
67 std::vector <int> tensorShape = {1, 1, 13, 1};
68 std::vector <int8_t> tensorVec;
69 TfLiteTensor tensor = GetTestTensor<int8_t>(
70 tensorShape, 100, tensorVec);
71
72 /* Copy elements to compare later. */
73 std::vector <int8_t> originalVec = tensorVec;
74
75 /* This step should not erase anything. */
76 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
77 }
78}
79
80
81TEST_CASE("Postprocessing - erasing required elements")
82{
83 constexpr uint32_t ctxLen = 5;
84 constexpr uint32_t innerLen = 3;
85 constexpr uint32_t nRows = 2*ctxLen + innerLen;
86 constexpr uint32_t nCols = 10;
87 constexpr uint32_t blankTokenIdx = nCols - 1;
88 std::vector <int> tensorShape = {1, 1, nRows, nCols};
89
90 SECTION("First and last iteration")
91 {
92 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
93
94 std::vector <int8_t> tensorVec;
95 TfLiteTensor tensor = GetTestTensor<int8_t>(
96 tensorShape, 100, tensorVec);
97
98 /* Copy elements to compare later. */
99 std::vector <int8_t> originalVec = tensorVec;
100
101 /* This step should not erase anything. */
102 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
103 REQUIRE(originalVec == tensorVec);
104 }
105
106 SECTION("Right context erase")
107 {
108 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
109
110 std::vector <int8_t> tensorVec;
111 TfLiteTensor tensor = GetTestTensor<int8_t>(
112 tensorShape, 100, tensorVec);
113
114 /* Copy elements to compare later. */
115 std::vector <int8_t> originalVec = tensorVec;
116
117 /* This step should erase the right context only. */
118 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
119 REQUIRE(originalVec != tensorVec);
120
121 /* The last ctxLen * 10 elements should be gone. */
122 for (size_t i = 0; i < ctxLen; ++i) {
123 for (size_t j = 0; j < nCols; ++j) {
124 /* Check right context elements are zeroed. */
125 if (j == blankTokenIdx) {
126 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
127 } else {
128 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
129 }
130
131 /* Check left context is preserved. */
132 CHECK(tensorVec[i*nCols + j] == originalVec[i*nCols + j]);
133 }
134 }
135
136 /* Check inner elements are preserved. */
137 for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
138 CHECK(tensorVec[i] == originalVec[i]);
139 }
140 }
141
142 SECTION("Left and right context erase")
143 {
144 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
145
146 std::vector <int8_t> tensorVec;
147 TfLiteTensor tensor = GetTestTensor<int8_t>(
148 tensorShape, 100, tensorVec);
149
150 /* Copy elements to compare later. */
151 std::vector <int8_t> originalVec = tensorVec;
152
153 /* This step should erase right context. */
154 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
155
156 /* Calling it the second time should erase the left context. */
157 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
158
159 REQUIRE(originalVec != tensorVec);
160
161 /* The first and last ctxLen * 10 elements should be gone. */
162 for (size_t i = 0; i < ctxLen; ++i) {
163 for (size_t j = 0; j < nCols; ++j) {
164 /* Check left and right context elements are zeroed. */
165 if (j == blankTokenIdx) {
166 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
167 CHECK(tensorVec[i*nCols + j] == 1);
168 } else {
169 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
170 CHECK(tensorVec[i*nCols + j] == 0);
171 }
172 }
173 }
174
175 /* Check inner elements are preserved. */
176 for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
177 /* Check left context is preserved. */
178 CHECK(tensorVec[i] == originalVec[i]);
179 }
180 }
181
182 SECTION("Try left context erase")
183 {
184 /* Should not be able to erase the left context if it is the first iteration. */
185 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
186
187 std::vector <int8_t> tensorVec;
188 TfLiteTensor tensor = GetTestTensor<int8_t>(
189 tensorShape, 100, tensorVec);
190
191 /* Copy elements to compare later. */
192 std::vector <int8_t> originalVec = tensorVec;
193
194 /* Calling it the second time should erase the left context. */
195 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
196
197 REQUIRE(originalVec == tensorVec);
198 }
199}