blob: a7e75fbbfa9c38d56c7541f55989abc0f9e743ec [file] [log] [blame]
/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include <catch.hpp>
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "MicroNetKwsModel.hpp"
alexander3c798932021-03-26 21:42:19 +000019#include "hal.h"
20
21#include "KwsResult.hpp"
22#include "Labels.hpp"
23#include "UseCaseHandler.hpp"
24#include "Classifier.hpp"
25#include "UseCaseCommonUtils.hpp"
26
27TEST_CASE("Model info")
28{
29 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000030 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000031
32 /* Load the model. */
33 REQUIRE(model.Init());
34
35 /* Instantiate application context. */
36 arm::app::ApplicationContext caseContext;
37
38 caseContext.Set<arm::app::Model&>("model", model);
39
40 REQUIRE(model.ShowModelInfoHandler());
41}
42
43
44TEST_CASE("Inference by index")
45{
46 hal_platform platform;
47 data_acq_module data_acq;
48 data_psn_module data_psn;
49 platform_timer timer;
50
51 /* Initialise the HAL and platform. */
52 hal_init(&platform, &data_acq, &data_psn, &timer);
53 hal_platform_init(&platform);
54
55 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000056 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000057
58 /* Load the model. */
59 REQUIRE(model.Init());
60
61 /* Instantiate application context. */
62 arm::app::ApplicationContext caseContext;
Isabella Gottardi8df12f32021-04-07 17:15:31 +010063
64 arm::app::Profiler profiler{&platform, "kws"};
65 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +000066 caseContext.Set<hal_platform&>("platform", platform);
67 caseContext.Set<arm::app::Model&>("model", model);
Kshitij Sisodia76a15802021-12-24 11:05:11 +000068 caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */
69 caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */
alexander3c798932021-03-26 21:42:19 +000070 caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
71
72 arm::app::Classifier classifier; /* classifier wrapper object. */
73 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
74
75 auto checker = [&](uint32_t audioIndex, std::vector<uint32_t> labelIndex)
76 {
77 caseContext.Set<uint32_t>("audioIndex", audioIndex);
78
79 std::vector<std::string> labels;
80 GetLabelsVector(labels);
81 caseContext.Set<const std::vector<std::string> &>("labels", labels);
82
83 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, audioIndex, false));
84 REQUIRE(caseContext.Has("results"));
85
86 auto results = caseContext.Get<std::vector<arm::app::kws::KwsResult>>("results");
87
88 REQUIRE(results.size() == labelIndex.size());
89
90 for (size_t i = 0; i < results.size(); i++ ) {
91 REQUIRE(results[i].m_resultVec.size());
92 REQUIRE(results[i].m_resultVec[0].m_labelIdx == labelIndex[i]);
93 }
94
95 };
96
97 SECTION("Index = 0, short clip down")
98 {
99 /* Result: down. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000100 checker(0, {0});
alexander3c798932021-03-26 21:42:19 +0000101 }
102
103 SECTION("Index = 1, long clip right->left->up")
104 {
105 /* Result: right->right->left->up->up. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000106 checker(1, {6, 6, 2, 8, 8});
alexander3c798932021-03-26 21:42:19 +0000107 }
108
109 SECTION("Index = 2, short clip yes")
110 {
111 /* Result: yes. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000112 checker(2, {9});
alexander3c798932021-03-26 21:42:19 +0000113 }
114
115 SECTION("Index = 3, long clip yes->no->go->stop")
116 {
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000117 /* Result: yes->no->no->go->go->stop->stop. */
118 checker(3, {9, 3, 3, 1, 1, 7, 7});
alexander3c798932021-03-26 21:42:19 +0000119 }
120}
121
122
123TEST_CASE("Inference run all clips")
124{
125 hal_platform platform;
126 data_acq_module data_acq;
127 data_psn_module data_psn;
128 platform_timer timer;
129
130 /* Initialise the HAL and platform. */
131 hal_init(&platform, &data_acq, &data_psn, &timer);
132 hal_platform_init(&platform);
133
134 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000135 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000136
137 /* Load the model. */
138 REQUIRE(model.Init());
139
140 /* Instantiate application context. */
141 arm::app::ApplicationContext caseContext;
142
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100143 arm::app::Profiler profiler{&platform, "kws"};
144 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +0000145 caseContext.Set<hal_platform&>("platform", platform);
146 caseContext.Set<arm::app::Model&>("model", model);
147 caseContext.Set<uint32_t>("clipIndex", 0);
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000148 caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */
149 caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */
150 caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000151 arm::app::Classifier classifier; /* classifier wrapper object. */
152 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
153
154 std::vector <std::string> labels;
155 GetLabelsVector(labels);
156 caseContext.Set<const std::vector <std::string>&>("labels", labels);
157 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, 0, true));
158}
159
160
161TEST_CASE("List all audio clips")
162{
163 hal_platform platform;
164 data_acq_module data_acq;
165 data_psn_module data_psn;
166 platform_timer timer;
167
168 /* Initialise the HAL and platform. */
169 hal_init(&platform, &data_acq, &data_psn, &timer);
170 hal_platform_init(&platform);
171
172 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000173 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000174
175 /* Load the model. */
176 REQUIRE(model.Init());
177
178 /* Instantiate application context. */
179 arm::app::ApplicationContext caseContext;
180
181 caseContext.Set<hal_platform&>("platform", platform);
182 caseContext.Set<arm::app::Model&>("model", model);
183
184 REQUIRE(arm::app::ListFilesHandler(caseContext));
185}