blob: 4e2fe17e10b8466ed8443eb2a46bf4490106e41c [file] [log] [blame]
/*
 * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "KwsClassifier.hpp"
18
19#include <catch.hpp>
20
21TEST_CASE("Test invalid classifier")
22{
23 TfLiteTensor* outputTens = nullptr;
24 std::vector<arm::app::ClassificationResult> resultVec;
25 arm::app::KwsClassifier classifier;
26 std::vector<std::vector<float>> resultHistory;
27 REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5, true, resultHistory));
28}
29
30TEST_CASE("Test valid classifier, average=0 should be same as 1)")
31{
32 int dimArray[] = {1, 5};
33 std::vector<std::string> labels(5);
34 std::vector<uint8_t> outputVec = {0, 1, 2, 3, 4};
35 std::vector<std::vector<float>> resultHistory = {};
36 TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
37 TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor(
38 outputVec.data(), dims, 1, 0);
39 TfLiteTensor* outputTensor = &tfTensor;
40 std::vector<arm::app::ClassificationResult> resultVec;
41 arm::app::KwsClassifier classifier;
42
43 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
44 REQUIRE(resultVec[0].m_labelIdx == 4);
45 REQUIRE(resultVec[0].m_normalisedVal == 4);
46
47 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
48 REQUIRE(resultVec[0].m_labelIdx == 4);
49 REQUIRE(resultVec[0].m_normalisedVal == 4);
50
51 std::vector<std::vector<float>> expectedHistory = {};
52 REQUIRE(resultHistory == expectedHistory);
53}
54
55TEST_CASE("Test valid classifier UINT8, average=1, softmax=false")
56{
57 int dimArray[] = {1, 5};
58 std::vector<std::string> labels(5);
59 std::vector<uint8_t> outputVec = {0, 1, 2, 3, 4};
60 std::vector<std::vector<float>> resultHistory = {{0, 0, 0, 0, 0}};
61 TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
62 TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor(
63 outputVec.data(), dims, 1, 0);
64 TfLiteTensor* outputTensor = &tfTensor;
65 std::vector<arm::app::ClassificationResult> resultVec;
66 arm::app::KwsClassifier classifier;
67
68 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
69 REQUIRE(resultVec[0].m_labelIdx == 4);
70 REQUIRE(resultVec[0].m_normalisedVal == 4);
71
72 /* We do not update history if not >1 in size. */
73 std::vector<std::vector<float>> expectedHistory = {{0, 0, 0, 0, 0}};
74 REQUIRE(resultHistory == expectedHistory);
75}
76
77TEST_CASE("Test valid classifier UINT8, average=2")
78{
79 int dimArray[] = {1, 5};
80 std::vector<std::string> labels(5);
81 std::vector<uint8_t> outputVec = {0, 1, 2, 3, 4};
82 std::vector<std::vector<float>> resultHistory = {{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}};
83 TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
84 TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor(
85 outputVec.data(), dims, 1, 0);
86 TfLiteTensor* outputTensor = &tfTensor;
87 std::vector<arm::app::ClassificationResult> resultVec;
88 arm::app::KwsClassifier classifier;
89
90 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
91 REQUIRE(resultVec[0].m_labelIdx == 4);
92 REQUIRE(resultVec[0].m_normalisedVal == 2);
93
94 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
95 REQUIRE(resultVec[0].m_labelIdx == 4);
96 REQUIRE(resultVec[0].m_normalisedVal == 4);
97
98 std::vector<std::vector<float>> expectedHistory = {{0, 1, 2, 3, 4}, {0, 1, 2, 3, 4}};
99 REQUIRE(resultHistory == expectedHistory);
100}
101
102TEST_CASE("Test valid classifier int8, average=0")
103{
104 int dimArray[] = {1, 5};
105 std::vector<std::string> labels(5);
106 std::vector<int8_t> outputVec = {-2, -1, 0, 2, 1};
107 std::vector<std::vector<float>> resultHistory = {};
108 TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
109 TfLiteTensor tfTensor = tflite::testing::CreateQuantizedTensor(
110 outputVec.data(), dims, 1, 0);
111 TfLiteTensor* outputTensor = &tfTensor;
112 std::vector<arm::app::ClassificationResult> resultVec;
113 arm::app::KwsClassifier classifier;
114
115 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
116 REQUIRE(resultVec[0].m_labelIdx == 3);
117 REQUIRE(resultVec[0].m_normalisedVal == 2);
118
119 REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 1, false, resultHistory));
120 REQUIRE(resultVec[0].m_labelIdx == 3);
121 REQUIRE(resultVec[0].m_normalisedVal == 2);
122
123 std::vector<std::vector<float>> expectedHistory = {};
124 REQUIRE(resultHistory == expectedHistory);
125}