blob: dee2f6fed7846bf11fa24195df3ffdfb4283a476 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include <catch.hpp>
18#include "DsCnnModel.hpp"
19#include "hal.h"
20
21#include "KwsResult.hpp"
22#include "Labels.hpp"
23#include "UseCaseHandler.hpp"
24#include "Classifier.hpp"
25#include "UseCaseCommonUtils.hpp"
26
27TEST_CASE("Model info")
28{
29 /* Model wrapper object. */
30 arm::app::DsCnnModel model;
31
32 /* Load the model. */
33 REQUIRE(model.Init());
34
35 /* Instantiate application context. */
36 arm::app::ApplicationContext caseContext;
37
38 caseContext.Set<arm::app::Model&>("model", model);
39
40 REQUIRE(model.ShowModelInfoHandler());
41}
42
43
44TEST_CASE("Inference by index")
45{
46 hal_platform platform;
47 data_acq_module data_acq;
48 data_psn_module data_psn;
49 platform_timer timer;
50
51 /* Initialise the HAL and platform. */
52 hal_init(&platform, &data_acq, &data_psn, &timer);
53 hal_platform_init(&platform);
54
55 /* Model wrapper object. */
56 arm::app::DsCnnModel model;
57
58 /* Load the model. */
59 REQUIRE(model.Init());
60
61 /* Instantiate application context. */
62 arm::app::ApplicationContext caseContext;
63 caseContext.Set<hal_platform&>("platform", platform);
64 caseContext.Set<arm::app::Model&>("model", model);
65 caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */
66 caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */
67 caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
68
69 arm::app::Classifier classifier; /* classifier wrapper object. */
70 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
71
72 auto checker = [&](uint32_t audioIndex, std::vector<uint32_t> labelIndex)
73 {
74 caseContext.Set<uint32_t>("audioIndex", audioIndex);
75
76 std::vector<std::string> labels;
77 GetLabelsVector(labels);
78 caseContext.Set<const std::vector<std::string> &>("labels", labels);
79
80 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, audioIndex, false));
81 REQUIRE(caseContext.Has("results"));
82
83 auto results = caseContext.Get<std::vector<arm::app::kws::KwsResult>>("results");
84
85 REQUIRE(results.size() == labelIndex.size());
86
87 for (size_t i = 0; i < results.size(); i++ ) {
88 REQUIRE(results[i].m_resultVec.size());
89 REQUIRE(results[i].m_resultVec[0].m_labelIdx == labelIndex[i]);
90 }
91
92 };
93
94 SECTION("Index = 0, short clip down")
95 {
96 /* Result: down. */
97 checker(0, {5});
98 }
99
100 SECTION("Index = 1, long clip right->left->up")
101 {
102 /* Result: right->right->left->up->up. */
103 checker(1, {7, 1, 6, 4, 4});
104 }
105
106 SECTION("Index = 2, short clip yes")
107 {
108 /* Result: yes. */
109 checker(2, {2});
110 }
111
112 SECTION("Index = 3, long clip yes->no->go->stop")
113 {
114 /* Result: yes->go->no->go->go->go->stop. */
115 checker(3, {2, 11, 3, 11, 11, 11, 10});
116 }
117}
118
119
120TEST_CASE("Inference run all clips")
121{
122 hal_platform platform;
123 data_acq_module data_acq;
124 data_psn_module data_psn;
125 platform_timer timer;
126
127 /* Initialise the HAL and platform. */
128 hal_init(&platform, &data_acq, &data_psn, &timer);
129 hal_platform_init(&platform);
130
131 /* Model wrapper object. */
132 arm::app::DsCnnModel model;
133
134 /* Load the model. */
135 REQUIRE(model.Init());
136
137 /* Instantiate application context. */
138 arm::app::ApplicationContext caseContext;
139
140 caseContext.Set<hal_platform&>("platform", platform);
141 caseContext.Set<arm::app::Model&>("model", model);
142 caseContext.Set<uint32_t>("clipIndex", 0);
143 caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */
144 caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */
145 caseContext.Set<float>("scoreThreshold", 0.9); /* Normalised score threshold. */
146 arm::app::Classifier classifier; /* classifier wrapper object. */
147 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
148
149 std::vector <std::string> labels;
150 GetLabelsVector(labels);
151 caseContext.Set<const std::vector <std::string>&>("labels", labels);
152 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, 0, true));
153}
154
155
156TEST_CASE("List all audio clips")
157{
158 hal_platform platform;
159 data_acq_module data_acq;
160 data_psn_module data_psn;
161 platform_timer timer;
162
163 /* Initialise the HAL and platform. */
164 hal_init(&platform, &data_acq, &data_psn, &timer);
165 hal_platform_init(&platform);
166
167 /* Model wrapper object. */
168 arm::app::DsCnnModel model;
169
170 /* Load the model. */
171 REQUIRE(model.Init());
172
173 /* Instantiate application context. */
174 arm::app::ApplicationContext caseContext;
175
176 caseContext.Set<hal_platform&>("platform", platform);
177 caseContext.Set<arm::app::Model&>("model", model);
178
179 REQUIRE(arm::app::ListFilesHandler(caseContext));
180}