blob: db67e540a1af8040cdcc1b5b9a1f2bde3dc3f2d1 [file] [log] [blame]
/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include <catch.hpp>
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "MicroNetKwsModel.hpp"
alexander3c798932021-03-26 21:42:19 +000019#include "hal.h"
20
21#include "KwsResult.hpp"
22#include "Labels.hpp"
23#include "UseCaseHandler.hpp"
24#include "Classifier.hpp"
25#include "UseCaseCommonUtils.hpp"
26
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010027namespace arm {
28 namespace app {
29 static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
30 namespace kws {
31 extern uint8_t* GetModelPointer();
32 extern size_t GetModelLen();
Liam Barry213a5432022-05-09 17:06:19 +010033 } /* namespace kws */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010034 } /* namespace app */
35} /* namespace arm */
36
alexander3c798932021-03-26 21:42:19 +000037TEST_CASE("Model info")
38{
39 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000040 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000041
42 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010043 REQUIRE(model.Init(arm::app::tensorArena,
44 sizeof(arm::app::tensorArena),
45 arm::app::kws::GetModelPointer(),
46 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +000047
48 /* Instantiate application context. */
49 arm::app::ApplicationContext caseContext;
50
51 caseContext.Set<arm::app::Model&>("model", model);
52
53 REQUIRE(model.ShowModelInfoHandler());
54}
55
56
57TEST_CASE("Inference by index")
58{
alexander3c798932021-03-26 21:42:19 +000059 /* Initialise the HAL and platform. */
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010060 hal_platform_init();
alexander3c798932021-03-26 21:42:19 +000061
62 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000063 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +000064
65 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010066 REQUIRE(model.Init(arm::app::tensorArena,
67 sizeof(arm::app::tensorArena),
68 arm::app::kws::GetModelPointer(),
69 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +000070
71 /* Instantiate application context. */
72 arm::app::ApplicationContext caseContext;
Isabella Gottardi8df12f32021-04-07 17:15:31 +010073
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010074 arm::app::Profiler profiler{"kws"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +010075 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +000076 caseContext.Set<arm::app::Model&>("model", model);
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010077 caseContext.Set<int>("frameLength", arm::app::kws::g_FrameLength); /* 640 sample length for MicroNetKws. */
78 caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride); /* 320 sample stride for MicroNetKws. */
alexander3c798932021-03-26 21:42:19 +000079 caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
80
81 arm::app::Classifier classifier; /* classifier wrapper object. */
82 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
83
84 auto checker = [&](uint32_t audioIndex, std::vector<uint32_t> labelIndex)
85 {
Richard Burtone6398cd2022-04-13 11:58:28 +010086 caseContext.Set<uint32_t>("clipIndex", audioIndex);
alexander3c798932021-03-26 21:42:19 +000087
88 std::vector<std::string> labels;
89 GetLabelsVector(labels);
90 caseContext.Set<const std::vector<std::string> &>("labels", labels);
91
92 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, audioIndex, false));
93 REQUIRE(caseContext.Has("results"));
94
95 auto results = caseContext.Get<std::vector<arm::app::kws::KwsResult>>("results");
96
97 REQUIRE(results.size() == labelIndex.size());
98
99 for (size_t i = 0; i < results.size(); i++ ) {
100 REQUIRE(results[i].m_resultVec.size());
101 REQUIRE(results[i].m_resultVec[0].m_labelIdx == labelIndex[i]);
102 }
103
104 };
105
106 SECTION("Index = 0, short clip down")
107 {
108 /* Result: down. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000109 checker(0, {0});
alexander3c798932021-03-26 21:42:19 +0000110 }
111
112 SECTION("Index = 1, long clip right->left->up")
113 {
114 /* Result: right->right->left->up->up. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000115 checker(1, {6, 6, 2, 8, 8});
alexander3c798932021-03-26 21:42:19 +0000116 }
117
118 SECTION("Index = 2, short clip yes")
119 {
120 /* Result: yes. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000121 checker(2, {9});
alexander3c798932021-03-26 21:42:19 +0000122 }
123
124 SECTION("Index = 3, long clip yes->no->go->stop")
125 {
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000126 /* Result: yes->no->no->go->go->stop->stop. */
127 checker(3, {9, 3, 3, 1, 1, 7, 7});
alexander3c798932021-03-26 21:42:19 +0000128 }
129}
130
131
132TEST_CASE("Inference run all clips")
133{
alexander3c798932021-03-26 21:42:19 +0000134 /* Initialise the HAL and platform. */
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100135 hal_platform_init();
alexander3c798932021-03-26 21:42:19 +0000136
137 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000138 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000139
140 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100141 REQUIRE(model.Init(arm::app::tensorArena,
142 sizeof(arm::app::tensorArena),
143 arm::app::kws::GetModelPointer(),
144 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +0000145
146 /* Instantiate application context. */
147 arm::app::ApplicationContext caseContext;
148
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100149 arm::app::Profiler profiler{"kws"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100150 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +0000151 caseContext.Set<arm::app::Model&>("model", model);
152 caseContext.Set<uint32_t>("clipIndex", 0);
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100153 caseContext.Set<int>("frameLength", arm::app::kws::g_FrameLength); /* 640 sample length for MicroNet. */
154 caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride); /* 320 sample stride for MicroNet. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000155 caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000156 arm::app::Classifier classifier; /* classifier wrapper object. */
157 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
158
159 std::vector <std::string> labels;
160 GetLabelsVector(labels);
161 caseContext.Set<const std::vector <std::string>&>("labels", labels);
162 REQUIRE(arm::app::ClassifyAudioHandler(caseContext, 0, true));
163}
164
165
166TEST_CASE("List all audio clips")
167{
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100168 /* Initialise the HAL and platform. */
169 hal_platform_init();
alexander3c798932021-03-26 21:42:19 +0000170
171 /* Model wrapper object. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000172 arm::app::MicroNetKwsModel model;
alexander3c798932021-03-26 21:42:19 +0000173
174 /* Load the model. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100175 REQUIRE(model.Init(arm::app::tensorArena,
176 sizeof(arm::app::tensorArena),
177 arm::app::kws::GetModelPointer(),
178 arm::app::kws::GetModelLen()));
alexander3c798932021-03-26 21:42:19 +0000179
180 /* Instantiate application context. */
181 arm::app::ApplicationContext caseContext;
182
alexander3c798932021-03-26 21:42:19 +0000183 caseContext.Set<arm::app::Model&>("model", model);
184
185 REQUIRE(arm::app::ListFilesHandler(caseContext));
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100186}