blob: 30136118c2ad3d74df9aa96ee165f25b41f690bd [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include <catch.hpp>
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "MicroNetKwsModel.hpp"
alexander3c798932021-03-26 21:42:19 +000019#include "hal.h"
20
21#include "KwsResult.hpp"
22#include "Labels.hpp"
23#include "UseCaseHandler.hpp"
24#include "Classifier.hpp"
25#include "UseCaseCommonUtils.hpp"
26
27TEST_CASE("Model info")
28{
29 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000030 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000031
32 /* Load the model. */
33 REQUIRE(model.Init());
34
35 /* Instantiate application context. */
36 arm::app::ApplicationContext caseContext;
37
38 caseContext.Set<arm::app::Model&>("model", model);
39
40 REQUIRE(model.ShowModelInfoHandler());
41}
42
43
44TEST_CASE("Inference by index")
45{
46 hal_platform platform;
alexander3c798932021-03-26 21:42:19 +000047 platform_timer timer;
48
49 /* Initialise the HAL and platform. */
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010050 hal_init(&platform, &timer);
alexander3c798932021-03-26 21:42:19 +000051 hal_platform_init(&platform);
52
53 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000054 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000055
56 /* Load the model. */
57 REQUIRE(model.Init());
58
59 /* Instantiate application context. */
60 arm::app::ApplicationContext caseContext;
Isabella Gottardi8df12f32021-04-07 17:15:31 +010061
62 arm::app::Profiler profiler{&platform, "kws"};
63 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +000064 caseContext.Set<hal_platform&>("platform", platform);
65 caseContext.Set<arm::app::Model&>("model", model);
Kshitij Sisodia76a15802021-12-24 11:05:11 +000066 caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */
67 caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */
alexander3c798932021-03-26 21:42:19 +000068 caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
69
70 arm::app::Classifier classifier; /* classifier wrapper object. */
71 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
72
73 auto checker = [&](uint32_t audioIndex, std::vector<uint32_t> labelIndex)
74 {
75 caseContext.Set<uint32_t>("audioIndex", audioIndex);
76
77 std::vector<std::string> labels;
78 GetLabelsVector(labels);
79 caseContext.Set<const std::vector<std::string> &>("labels", labels);
80
81 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, audioIndex, false));
82 REQUIRE(caseContext.Has("results"));
83
84 auto results = caseContext.Get<std::vector<arm::app::kws::KwsResult>>("results");
85
86 REQUIRE(results.size() == labelIndex.size());
87
88 for (size_t i = 0; i < results.size(); i++ ) {
89 REQUIRE(results[i].m_resultVec.size());
90 REQUIRE(results[i].m_resultVec[0].m_labelIdx == labelIndex[i]);
91 }
92
93 };
94
95 SECTION("Index = 0, short clip down")
96 {
97 /* Result: down. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000098 checker(0, {0});
alexander3c798932021-03-26 21:42:19 +000099 }
100
101 SECTION("Index = 1, long clip right->left->up")
102 {
103 /* Result: right->right->left->up->up. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000104 checker(1, {6, 6, 2, 8, 8});
alexander3c798932021-03-26 21:42:19 +0000105 }
106
107 SECTION("Index = 2, short clip yes")
108 {
109 /* Result: yes. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000110 checker(2, {9});
alexander3c798932021-03-26 21:42:19 +0000111 }
112
113 SECTION("Index = 3, long clip yes->no->go->stop")
114 {
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000115 /* Result: yes->no->no->go->go->stop->stop. */
116 checker(3, {9, 3, 3, 1, 1, 7, 7});
alexander3c798932021-03-26 21:42:19 +0000117 }
118}
119
120
121TEST_CASE("Inference run all clips")
122{
123 hal_platform platform;
alexander3c798932021-03-26 21:42:19 +0000124 platform_timer timer;
125
126 /* Initialise the HAL and platform. */
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100127 hal_init(&platform, &timer);
alexander3c798932021-03-26 21:42:19 +0000128 hal_platform_init(&platform);
129
130 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000131 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000132
133 /* Load the model. */
134 REQUIRE(model.Init());
135
136 /* Instantiate application context. */
137 arm::app::ApplicationContext caseContext;
138
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100139 arm::app::Profiler profiler{&platform, "kws"};
140 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +0000141 caseContext.Set<hal_platform&>("platform", platform);
142 caseContext.Set<arm::app::Model&>("model", model);
143 caseContext.Set<uint32_t>("clipIndex", 0);
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000144 caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */
145 caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */
146 caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000147 arm::app::Classifier classifier; /* classifier wrapper object. */
148 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
149
150 std::vector <std::string> labels;
151 GetLabelsVector(labels);
152 caseContext.Set<const std::vector <std::string>&>("labels", labels);
153 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, 0, true));
154}
155
156
157TEST_CASE("List all audio clips")
158{
159 hal_platform platform;
alexander3c798932021-03-26 21:42:19 +0000160 platform_timer timer;
161
162 /* Initialise the HAL and platform. */
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100163 hal_init(&platform, &timer);
alexander3c798932021-03-26 21:42:19 +0000164 hal_platform_init(&platform);
165
166 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000167 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000168
169 /* Load the model. */
170 REQUIRE(model.Init());
171
172 /* Instantiate application context. */
173 arm::app::ApplicationContext caseContext;
174
175 caseContext.Set<hal_platform&>("platform", platform);
176 caseContext.Set<arm::app::Model&>("model", model);
177
178 REQUIRE(arm::app::ListFilesHandler(caseContext));
179}