blob: c109a62878f092493269658caa5f9682cfe9284f [file] [log] [blame]
Éanna Ó Catháin8f958872021-09-15 09:32:30 +01001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
#include <cstring>
#include <random>
#include <string>

#include <catch.hpp>

#include "hal.h"
#include "InputFiles.hpp"
#include "ImageUtils.hpp"
#include "TestData_vww.hpp"
#include "VisualWakeWordModel.hpp"
#include "TensorFlowLiteMicro.hpp"
26
27
28bool RunInference(arm::app::Model& model, const int8_t* imageData)
29{
30 TfLiteTensor* inputTensor = model.GetInputTensor(0);
31 REQUIRE(inputTensor);
32
33 return model.RunInference();
34}
35
36template<typename T>
37void TestInference(int imageIdx,arm::app::Model& model) {
38
39 auto image = test::get_ifm_data_array(imageIdx);
40 auto goldenFV = test::get_ofm_data_array(imageIdx);
41
42 REQUIRE(RunInference(model, image));
43
44 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
45
46 REQUIRE(outputTensor);
47 REQUIRE(outputTensor->bytes == OFM_DATA_SIZE);
48 auto tensorData = tflite::GetTensorData<T>(outputTensor);
49 REQUIRE(tensorData);
50
51 for (size_t i = 0; i < outputTensor->bytes; i++) {
52 auto testVal = static_cast<int>(tensorData[i]);
53 auto goldenVal = static_cast<int>(goldenFV[i]);
54 CHECK(testVal == goldenVal);
55 }
56}
57
58
59/**
60 * @brief Given an image name, get its index
61 * @param[in] imageName Name of the image expected
62 * @return index of the image if valid and (-1) if not found
63 */
64static int _GetImageIdx(std::string &imageName)
65{
66 int imgIdx = -1;
67 for (uint32_t i = 0 ; i < NUMBER_OF_FILES; ++i) {
68 if (imageName == std::string(get_filename(i))) {
69 info("Image %s exists at index %u\n", get_filename(i), i);
70 imgIdx = static_cast<int>(i);
71 break;
72 }
73 }
74
75 if (-1 == imgIdx) {
76 warn("Image %s not found!\n", imageName.c_str());
77 }
78
79 return imgIdx;
80}
81