blob: 48ac32c34966822ef7aa145d33af87c730b924b0 [file] [log] [blame]
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
#include <stdexcept>
#include <string>

#include <catch.hpp>
#include <opencv2/opencv.hpp>

#include "ObjectDetectionPipeline.hpp"
#include "Types.hpp"
9
10static std::string GetResourceFilePath(const std::string& filename)
11{
12 std::string testResources = TEST_RESOURCE_DIR;
13 if (0 == testResources.size())
14 {
15 throw "Invalid test resources directory provided";
16 }
17 else
18 {
19 if(testResources.back() != '/')
20 {
21 return testResources + "/" + filename;
22 }
23 else
24 {
25 return testResources + filename;
26 }
27 }
28}
29
30TEST_CASE("Test Network Execution SSD_MOBILE")
31{
32 std::string testResources = TEST_RESOURCE_DIR;
33 REQUIRE(testResources != "");
34 // Create the network options
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010035 common::PipelineOptions options;
George Gekov23c26272021-08-16 11:32:10 +010036 options.m_ModelFilePath = GetResourceFilePath("ssd_mobilenet_v1.tflite");
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010037 options.m_ModelName = "SSD_MOBILE";
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +000038 options.m_backends = {"CpuAcc", "CpuRef"};
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010039
40 od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options);
41
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010042 common::InferenceResults<float> results;
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010043 cv::Mat processed;
44 cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
45 cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB);
46
47 objectDetectionPipeline->PreProcessing(inputFrame, processed);
48
49 CHECK(processed.type() == CV_8UC3);
50 CHECK(processed.cols == 300);
51 CHECK(processed.rows == 300);
52
53 objectDetectionPipeline->Inference(processed, results);
54 objectDetectionPipeline->PostProcessing(results,
55 [](od::DetectedObjects detects) -> void {
56 CHECK(detects.size() == 2);
57 CHECK(detects[0].GetLabel() == "0");
58 });
59
60}