Éanna Ó Catháin | 919c14e | 2020-09-14 17:36:49 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
#include <stdexcept>
#include <string>

#include <catch.hpp>
#include <opencv2/opencv.hpp>

#include "ObjectDetectionPipeline.hpp"
#include "Types.hpp"
| 9 | |
| 10 | static std::string GetResourceFilePath(const std::string& filename) |
| 11 | { |
| 12 | std::string testResources = TEST_RESOURCE_DIR; |
| 13 | if (0 == testResources.size()) |
| 14 | { |
| 15 | throw "Invalid test resources directory provided"; |
| 16 | } |
| 17 | else |
| 18 | { |
| 19 | if(testResources.back() != '/') |
| 20 | { |
| 21 | return testResources + "/" + filename; |
| 22 | } |
| 23 | else |
| 24 | { |
| 25 | return testResources + filename; |
| 26 | } |
| 27 | } |
| 28 | } |
| 29 | |
| 30 | TEST_CASE("Test Network Execution SSD_MOBILE") |
| 31 | { |
| 32 | std::string testResources = TEST_RESOURCE_DIR; |
| 33 | REQUIRE(testResources != ""); |
| 34 | // Create the network options |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 35 | common::PipelineOptions options; |
Éanna Ó Catháin | 919c14e | 2020-09-14 17:36:49 +0100 | [diff] [blame] | 36 | options.m_ModelFilePath = GetResourceFilePath("detect.tflite"); |
| 37 | options.m_ModelName = "SSD_MOBILE"; |
| 38 | options.m_backends = {"CpuAcc", "CpuRef"}; |
| 39 | |
| 40 | od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options); |
| 41 | |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 42 | common::InferenceResults<float> results; |
Éanna Ó Catháin | 919c14e | 2020-09-14 17:36:49 +0100 | [diff] [blame] | 43 | cv::Mat processed; |
| 44 | cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR); |
| 45 | cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB); |
| 46 | |
| 47 | objectDetectionPipeline->PreProcessing(inputFrame, processed); |
| 48 | |
| 49 | CHECK(processed.type() == CV_8UC3); |
| 50 | CHECK(processed.cols == 300); |
| 51 | CHECK(processed.rows == 300); |
| 52 | |
| 53 | objectDetectionPipeline->Inference(processed, results); |
| 54 | objectDetectionPipeline->PostProcessing(results, |
| 55 | [](od::DetectedObjects detects) -> void { |
| 56 | CHECK(detects.size() == 2); |
| 57 | CHECK(detects[0].GetLabel() == "0"); |
| 58 | }); |
| 59 | |
| 60 | } |