blob: 10abb65cceb49ce8e89c1ba5a97985d282cbdbf0 [file] [log] [blame]
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "CvVideoFrameReader.hpp"
#include "CvWindowOutput.hpp"
#include "CvVideoFileWriter.hpp"
#include "NetworkPipeline.hpp"
#include "CmdArgsParser.hpp"

#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
16
17/*
18 * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector
19 */
20std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
21{
22 std::vector<armnn::BackendId> backends;
23 std::stringstream ss(preferredBackends);
24
25 while(ss.good())
26 {
27 std::string backend;
28 std::getline( ss, backend, ',' );
29 backends.emplace_back(backend);
30 }
31 return backends;
32}
33
34/*
35 * Assigns a color to each label in the label set
36 */
37std::vector<std::tuple<std::string, od::BBoxColor>> AssignColourToLabel(const std::string& pathToLabelFile)
38{
39 std::ifstream in(pathToLabelFile);
40 std::vector<std::tuple<std::string, od::BBoxColor>> labels;
41
42 std::string str;
43 std::default_random_engine generator;
44 std::uniform_int_distribution<int> distribution(0,255);
45
46 while (std::getline(in, str))
47 {
48 if(!str.empty())
49 {
50 od::BBoxColor c{
51 .colorCode = std::make_tuple(distribution(generator),
52 distribution(generator),
53 distribution(generator))
54 };
55 auto bboxInfo = std::make_tuple (str, c);
56
57 labels.emplace_back(bboxInfo);
58 }
59 }
60 return labels;
61}
62
63std::tuple<std::unique_ptr<od::IFrameReader<cv::Mat>>,
64 std::unique_ptr<od::IFrameOutput<cv::Mat>>>
65 GetFrameSourceAndSink(const std::map<std::string, std::string>& options) {
66
67 std::unique_ptr<od::IFrameReader<cv::Mat>> readerPtr;
68
69 std::unique_ptr<od::CvVideoFrameReader> reader = std::make_unique<od::CvVideoFrameReader>();
70 reader->Init(GetSpecifiedOption(options, VIDEO_FILE_PATH));
71
72 auto enc = reader->GetSourceEncodingInt();
73 auto fps = reader->GetSourceFps();
74 auto w = reader->GetSourceWidth();
75 auto h = reader->GetSourceHeight();
76 if (!reader->ConvertToRGB())
77 {
78 readerPtr = std::move(std::make_unique<od::CvVideoFrameReaderRgbWrapper>(std::move(reader)));
79 }
80 else
81 {
82 readerPtr = std::move(reader);
83 }
84
85 if(CheckOptionSpecified(options, OUTPUT_VIDEO_FILE_PATH))
86 {
87 std::string outputVideo = GetSpecifiedOption(options, OUTPUT_VIDEO_FILE_PATH);
88 auto writer = std::make_unique<od::CvVideoFileWriter>();
89 writer->Init(outputVideo, enc, fps, w, h);
90
91 return std::make_tuple<>(std::move(readerPtr), std::move(writer));
92 }
93 else
94 {
95 auto writer = std::make_unique<od::CvWindowOutput>();
96 writer->Init("Processed Video");
97 return std::make_tuple<>(std::move(readerPtr), std::move(writer));
98 }
99}
100
101int main(int argc, char *argv[])
102{
103 std::map<std::string, std::string> options;
104
105 int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
106 if (result != 0)
107 {
108 return result;
109 }
110
111 // Create the network options
112 od::ODPipelineOptions pipelineOptions;
113 pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
114 pipelineOptions.m_ModelName = GetSpecifiedOption(options, MODEL_NAME);
115
116 if(CheckOptionSpecified(options, PREFERRED_BACKENDS))
117 {
118 pipelineOptions.m_backends = GetPreferredBackendList((GetSpecifiedOption(options, PREFERRED_BACKENDS)));
119 }
120 else
121 {
122 pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
123 }
124
125 auto labels = AssignColourToLabel(GetSpecifiedOption(options, LABEL_PATH));
126
127 od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(pipelineOptions);
128
129 auto inputAndOutput = GetFrameSourceAndSink(options);
130 std::unique_ptr<od::IFrameReader<cv::Mat>> reader = std::move(std::get<0>(inputAndOutput));
131 std::unique_ptr<od::IFrameOutput<cv::Mat>> sink = std::move(std::get<1>(inputAndOutput));
132
133 if (!sink->IsReady())
134 {
135 std::cerr << "Failed to open video writer.";
136 return 1;
137 }
138
139 od::InferenceResults results;
140
141 std::shared_ptr<cv::Mat> frame = reader->ReadFrame();
142
143 //pre-allocate frames
144 cv::Mat processed;
145
146 while(!reader->IsExhausted(frame))
147 {
148 objectDetectionPipeline->PreProcessing(*frame, processed);
149 objectDetectionPipeline->Inference(processed, results);
150 objectDetectionPipeline->PostProcessing(results,
151 [&frame, &labels](od::DetectedObjects detects) -> void {
152 AddInferenceOutputToFrame(detects, *frame, labels);
153 });
154
155 sink->WriteFrame(frame);
156 frame = reader->ReadFrame();
157 }
158 sink->Close();
159 return 0;
160}