blob: 6b368dad0008eda9ebdf6b4aace83af707dfd3b0 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Richard Burtonf32a86a2022-11-15 11:46:11 +00002 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include <catch.hpp>
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "MicroNetKwsModel.hpp"
alexander3c798932021-03-26 21:42:19 +000019#include "hal.h"
20
21#include "KwsResult.hpp"
22#include "Labels.hpp"
23#include "UseCaseHandler.hpp"
24#include "Classifier.hpp"
25#include "UseCaseCommonUtils.hpp"
Richard Burtonec5e99b2022-10-05 11:00:37 +010026#include "BufAttributes.hpp"
alexander3c798932021-03-26 21:42:19 +000027
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010028namespace arm {
29 namespace app {
30 static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
31 namespace kws {
32 extern uint8_t* GetModelPointer();
33 extern size_t GetModelLen();
Liam Barry213a5432022-05-09 17:06:19 +010034 } /* namespace kws */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010035 } /* namespace app */
36} /* namespace arm */
37
alexander3c798932021-03-26 21:42:19 +000038TEST_CASE("Model info")
39{
40 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000041 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000042
43 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010044 REQUIRE(model.Init(arm::app::tensorArena,
45 sizeof(arm::app::tensorArena),
46 arm::app::kws::GetModelPointer(),
47 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +000048
49 /* Instantiate application context. */
50 arm::app::ApplicationContext caseContext;
51
52 caseContext.Set<arm::app::Model&>("model", model);
53
54 REQUIRE(model.ShowModelInfoHandler());
55}
56
57
58TEST_CASE("Inference by index")
59{
alexander3c798932021-03-26 21:42:19 +000060 /* Initialise the HAL and platform. */
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010061 hal_platform_init();
alexander3c798932021-03-26 21:42:19 +000062
63 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000064 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000065
66 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010067 REQUIRE(model.Init(arm::app::tensorArena,
68 sizeof(arm::app::tensorArena),
69 arm::app::kws::GetModelPointer(),
70 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +000071
72 /* Instantiate application context. */
73 arm::app::ApplicationContext caseContext;
Isabella Gottardi8df12f32021-04-07 17:15:31 +010074
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010075 arm::app::Profiler profiler{"kws"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +010076 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +000077 caseContext.Set<arm::app::Model&>("model", model);
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010078 caseContext.Set<int>("frameLength", arm::app::kws::g_FrameLength); /* 640 sample length for MicroNetKws. */
79 caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride); /* 320 sample stride for MicroNetKws. */
alexander3c798932021-03-26 21:42:19 +000080 caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
81
82 arm::app::Classifier classifier; /* classifier wrapper object. */
83 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
84
85 auto checker = [&](uint32_t audioIndex, std::vector<uint32_t> labelIndex)
86 {
Richard Burtone6398cd2022-04-13 11:58:28 +010087 caseContext.Set<uint32_t>("clipIndex", audioIndex);
alexander3c798932021-03-26 21:42:19 +000088
89 std::vector<std::string> labels;
90 GetLabelsVector(labels);
91 caseContext.Set<const std::vector<std::string> &>("labels", labels);
92
93 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, audioIndex, false));
94 REQUIRE(caseContext.Has("results"));
95
96 auto results = caseContext.Get<std::vector<arm::app::kws::KwsResult>>("results");
97
98 REQUIRE(results.size() == labelIndex.size());
99
100 for (size_t i = 0; i < results.size(); i++ ) {
101 REQUIRE(results[i].m_resultVec.size());
102 REQUIRE(results[i].m_resultVec[0].m_labelIdx == labelIndex[i]);
103 }
104
105 };
106
107 SECTION("Index = 0, short clip down")
108 {
109 /* Result: down. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000110 checker(0, {0});
alexander3c798932021-03-26 21:42:19 +0000111 }
112
113 SECTION("Index = 1, long clip right->left->up")
114 {
115 /* Result: right->right->left->up->up. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000116 checker(1, {6, 6, 2, 8, 8});
alexander3c798932021-03-26 21:42:19 +0000117 }
118
119 SECTION("Index = 2, short clip yes")
120 {
121 /* Result: yes. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000122 checker(2, {9});
alexander3c798932021-03-26 21:42:19 +0000123 }
124
125 SECTION("Index = 3, long clip yes->no->go->stop")
126 {
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000127 /* Result: yes->no->no->go->go->stop->stop. */
128 checker(3, {9, 3, 3, 1, 1, 7, 7});
alexander3c798932021-03-26 21:42:19 +0000129 }
130}
131
132
133TEST_CASE("Inference run all clips")
134{
alexander3c798932021-03-26 21:42:19 +0000135 /* Initialise the HAL and platform. */
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100136 hal_platform_init();
alexander3c798932021-03-26 21:42:19 +0000137
138 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000139 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000140
141 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100142 REQUIRE(model.Init(arm::app::tensorArena,
143 sizeof(arm::app::tensorArena),
144 arm::app::kws::GetModelPointer(),
145 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +0000146
147 /* Instantiate application context. */
148 arm::app::ApplicationContext caseContext;
149
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100150 arm::app::Profiler profiler{"kws"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100151 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +0000152 caseContext.Set<arm::app::Model&>("model", model);
153 caseContext.Set<uint32_t>("clipIndex", 0);
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100154 caseContext.Set<int>("frameLength", arm::app::kws::g_FrameLength); /* 640 sample length for MicroNet. */
155 caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride); /* 320 sample stride for MicroNet. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000156 caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000157 arm::app::Classifier classifier; /* classifier wrapper object. */
158 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
159
160 std::vector <std::string> labels;
161 GetLabelsVector(labels);
162 caseContext.Set<const std::vector <std::string>&>("labels", labels);
163 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, 0, true));
164}
165
166
167TEST_CASE("List all audio clips")
168{
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100169 /* Initialise the HAL and platform. */
170 hal_platform_init();
alexander3c798932021-03-26 21:42:19 +0000171
172 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000173 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000174
175 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100176 REQUIRE(model.Init(arm::app::tensorArena,
177 sizeof(arm::app::tensorArena),
178 arm::app::kws::GetModelPointer(),
179 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +0000180
181 /* Instantiate application context. */
182 arm::app::ApplicationContext caseContext;
183
alexander3c798932021-03-26 21:42:19 +0000184 caseContext.Set<arm::app::Model&>("model", model);
185
186 REQUIRE(arm::app::ListFilesHandler(caseContext));
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100187}