blob: 6fd7df362aac3005feb13e337a47f333b33f03b4 [file] [log] [blame]
/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "Wav2LetterPostprocess.hpp"
18#include "Wav2LetterModel.hpp"
19
20#include <algorithm>
21#include <catch.hpp>
22#include <limits>
23
24template <typename T>
25static TfLiteTensor GetTestTensor(std::vector <int>& shape,
26 T initVal,
27 std::vector<T>& vectorBuf)
28{
29 REQUIRE(0 != shape.size());
30
31 shape.insert(shape.begin(), shape.size());
32 uint32_t sizeInBytes = sizeof(T);
33 for (size_t i = 1; i < shape.size(); ++i) {
34 sizeInBytes *= shape[i];
35 }
36
37 /* Allocate mem. */
38 vectorBuf = std::vector<T>(sizeInBytes, initVal);
39 TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data());
40 return tflite::testing::CreateQuantizedTensor(
41 vectorBuf.data(), dims,
42 1, 0, "test-tensor");
43}
44
45TEST_CASE("Checking return value")
46{
47 SECTION("Mismatched post processing parameters and tensor size")
48 {
49 const uint32_t ctxLen = 5;
50 const uint32_t innerLen = 3;
51 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
52
53 std::vector <int> tensorShape = {1, 1, 1, 13};
54 std::vector <int8_t> tensorVec;
55 TfLiteTensor tensor = GetTestTensor<int8_t>(
56 tensorShape, 100, tensorVec);
57 REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
58 }
59
60 SECTION("Post processing succeeds")
61 {
62 const uint32_t ctxLen = 5;
63 const uint32_t innerLen = 3;
64 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
65
66 std::vector <int> tensorShape = {1, 1, 13, 1};
67 std::vector <int8_t> tensorVec;
68 TfLiteTensor tensor = GetTestTensor<int8_t>(
69 tensorShape, 100, tensorVec);
70
71 /* Copy elements to compare later. */
72 std::vector <int8_t> originalVec = tensorVec;
73
74 /* This step should not erase anything. */
75 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
76 }
77}
78
79TEST_CASE("Postprocessing - erasing required elements")
80{
81 constexpr uint32_t ctxLen = 5;
82 constexpr uint32_t innerLen = 3;
83 constexpr uint32_t nRows = 2*ctxLen + innerLen;
84 constexpr uint32_t nCols = 10;
85 constexpr uint32_t blankTokenIdx = nCols - 1;
86 std::vector <int> tensorShape = {1, 1, nRows, nCols};
87
88 SECTION("First and last iteration")
89 {
90 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
91 std::vector <int8_t> tensorVec;
92 TfLiteTensor tensor = GetTestTensor<int8_t>(
93 tensorShape, 100, tensorVec);
94
95 /* Copy elements to compare later. */
96 std::vector <int8_t> originalVec = tensorVec;
97
98 /* This step should not erase anything. */
99 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
100 REQUIRE(originalVec == tensorVec);
101 }
102
103 SECTION("Right context erase")
104 {
105 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
106
107 std::vector <int8_t> tensorVec;
108 TfLiteTensor tensor = GetTestTensor<int8_t>(
109 tensorShape, 100, tensorVec);
110
111 /* Copy elements to compare later. */
112 std::vector <int8_t> originalVec = tensorVec;
113
114 /* This step should erase the right context only. */
115 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
116 REQUIRE(originalVec != tensorVec);
117
118 /* The last ctxLen * 10 elements should be gone. */
119 for (size_t i = 0; i < ctxLen; ++i) {
120 for (size_t j = 0; j < nCols; ++j) {
121 /* Check right context elements are zeroed. */
122 if (j == blankTokenIdx) {
123 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
124 } else {
125 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
126 }
127
128 /* Check left context is preserved. */
129 CHECK(tensorVec[i*nCols + j] == originalVec[i*nCols + j]);
130 }
131 }
132
133 /* Check inner elements are preserved. */
134 for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
135 CHECK(tensorVec[i] == originalVec[i]);
136 }
137 }
138
139 SECTION("Left and right context erase")
140 {
141 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
142
143 std::vector <int8_t> tensorVec;
144 TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
145
146 /* Copy elements to compare later. */
147 std::vector <int8_t> originalVec = tensorVec;
148
149 /* This step should erase right context. */
150 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
151
152 /* Calling it the second time should erase the left context. */
153 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
154
155 REQUIRE(originalVec != tensorVec);
156
157 /* The first and last ctxLen * 10 elements should be gone. */
158 for (size_t i = 0; i < ctxLen; ++i) {
159 for (size_t j = 0; j < nCols; ++j) {
160 /* Check left and right context elements are zeroed. */
161 if (j == blankTokenIdx) {
162 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 1);
163 CHECK(tensorVec[i * nCols + j] == 1);
164 } else {
165 CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 0);
166 CHECK(tensorVec[i * nCols + j] == 0);
167 }
168 }
169 }
170
171 /* Check inner elements are preserved. */
172 for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
173 /* Check left context is preserved. */
174 CHECK(tensorVec[i] == originalVec[i]);
175 }
176 }
177
178 SECTION("Try left context erase")
179 {
180 /* Should not be able to erase the left context if it is the first iteration. */
181 arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
182
183 std::vector <int8_t> tensorVec;
184 TfLiteTensor tensor = GetTestTensor<int8_t>(
185 tensorShape, 100, tensorVec);
186
187 /* Copy elements to compare later. */
188 std::vector <int8_t> originalVec = tensorVec;
189
190 /* Calling it the second time should erase the left context. */
191 REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
192 REQUIRE(originalVec == tensorVec);
193 }
194}