blob: 05a31a479eedb86d5976b85704d20a0f448dc804 [file] [log] [blame]
Éanna Ó Catháin8f958872021-09-15 09:32:30 +01001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include <catch.hpp>
18#include "VisualWakeWordModel.hpp"
19#include "hal.h"
20
21#include "ClassificationResult.hpp"
22#include "Labels.hpp"
23#include "UseCaseHandler.hpp"
24#include "Classifier.hpp"
25#include "UseCaseCommonUtils.hpp"
26
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010027namespace arm {
28 namespace app {
29 static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
30 } /* namespace app */
31} /* namespace arm */
32
/*
 * Accessors for the model data passed to Model::Init() in each test case.
 * They are only declared here; presumably the definitions live in a
 * generated model source linked into this test binary — confirm in the
 * build setup.
 */
extern size_t GetModelLen();      /* length argument for Model::Init() */
extern uint8_t* GetModelPointer(); /* data pointer argument for Model::Init() */
35
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010036TEST_CASE("Model info")
37{
38 arm::app::VisualWakeWordModel model; /* model wrapper object */
39
40 /* Load the model */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010041 REQUIRE(model.Init(arm::app::tensorArena,
42 sizeof(arm::app::tensorArena),
43 GetModelPointer(),
44 GetModelLen()));
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010045
46 /* Instantiate application context */
47 arm::app::ApplicationContext caseContext;
48
49 caseContext.Set<arm::app::Model&>("model", model);
50
51 REQUIRE(model.ShowModelInfoHandler());
52}
53
54TEST_CASE("Inference by index")
55{
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010056 hal_platform_init();
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010057
58 arm::app::VisualWakeWordModel model; /* model wrapper object */
59
60 /* Load the model */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010061 REQUIRE(model.Init(arm::app::tensorArena,
62 sizeof(arm::app::tensorArena),
63 GetModelPointer(),
64 GetModelLen()));
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010065
66 /* Instantiate application context */
67 arm::app::ApplicationContext caseContext;
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010068 arm::app::Profiler profiler{"pd"};
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010069 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010070 caseContext.Set<arm::app::Model&>("model", model);
71 caseContext.Set<uint32_t>("imgIndex", 0);
72 arm::app::Classifier classifier; /* classifier wrapper object */
73 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
74
75 std::vector <std::string> labels;
76 GetLabelsVector(labels);
77 caseContext.Set<const std::vector <std::string>&>("labels", labels);
78
79 REQUIRE(arm::app::ClassifyImageHandler(caseContext, 0, false));
80
81 auto results = caseContext.Get<std::vector<arm::app::ClassificationResult>>("results");
82
Isabella Gottardi79d41542021-10-20 15:52:32 +010083 REQUIRE(results[0].m_labelIdx == 1);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010084}
85
86TEST_CASE("Inference run all images")
87{
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010088 /* Initialise the HAL and platform */
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010089 hal_platform_init();
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010090
91 arm::app::VisualWakeWordModel model; /* model wrapper object */
92
93 /* Load the model */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010094 REQUIRE(model.Init(arm::app::tensorArena,
95 sizeof(arm::app::tensorArena),
96 GetModelPointer(),
97 GetModelLen()));
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010098
99 /* Instantiate application context */
100 arm::app::ApplicationContext caseContext;
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100101 arm::app::Profiler profiler{"pd"};
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100102 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100103 caseContext.Set<arm::app::Model&>("model", model);
104 caseContext.Set<uint32_t>("imgIndex", 0);
105 arm::app::Classifier classifier; /* classifier wrapper object */
106 caseContext.Set<arm::app::Classifier&>("classifier", classifier);
107
108 std::vector <std::string> labels;
109 GetLabelsVector(labels);
110 caseContext.Set<const std::vector <std::string>&>("labels", labels);
111
112 REQUIRE(arm::app::ClassifyImageHandler(caseContext, 0, true));
113}
114
115TEST_CASE("List all images")
116{
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100117 /* Initialise the HAL and platform */
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100118 hal_platform_init();
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100119
120 arm::app::VisualWakeWordModel model; /* model wrapper object */
121
122 /* Load the model */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100123 REQUIRE(model.Init(arm::app::tensorArena,
124 sizeof(arm::app::tensorArena),
125 GetModelPointer(),
126 GetModelLen()));
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100127
128 /* Instantiate application context */
129 arm::app::ApplicationContext caseContext;
130
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100131 caseContext.Set<arm::app::Model&>("model", model);
132
133 REQUIRE(arm::app::ListFilesHandler(caseContext));
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100134}