blob: 7396b7672c280e7a15e010d776e741f4d886b4b3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "../YoloInferenceTest.hpp"
#include "armnnCaffeParser/ICaffeParser.hpp"
#include "armnn/TypesUtils.hpp"

#include <cstdlib>
#include <exception>
#include <iostream>
#include <memory>
8
9int main(int argc, char* argv[])
10{
11 armnn::TensorShape inputTensorShape{ { 1, 3, YoloImageHeight, YoloImageWidth } };
12
13 using YoloInferenceModel = InferenceModel<armnnCaffeParser::ICaffeParser,
14 float>;
15
surmeh013537c2c2018-05-18 16:31:43 +010016 int retVal = EXIT_FAILURE;
17 try
18 {
19 // Coverity fix: InferenceTestMain() may throw uncaught exceptions.
20 retVal = InferenceTestMain(argc, argv, { 0 },
21 [&inputTensorShape]()
22 {
23 return make_unique<YoloTestCaseProvider<YoloInferenceModel>>(
24 [&]
25 (typename YoloInferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000026 {
surmeh013537c2c2018-05-18 16:31:43 +010027 if (!ValidateDirectory(modelOptions.m_ModelDir))
28 {
29 return std::unique_ptr<YoloInferenceModel>();
30 }
telsoa014fcda012018-03-09 14:13:49 +000031
surmeh013537c2c2018-05-18 16:31:43 +010032 typename YoloInferenceModel::Params modelParams;
33 modelParams.m_ModelPath = modelOptions.m_ModelDir + "yolov1_tiny_voc2007_model.caffemodel";
34 modelParams.m_InputBinding = "data";
35 modelParams.m_OutputBinding = "fc12";
36 modelParams.m_InputTensorShape = &inputTensorShape;
37 modelParams.m_IsModelBinary = true;
38 modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
39 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +010040 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +000041
surmeh013537c2c2018-05-18 16:31:43 +010042 return std::make_unique<YoloInferenceModel>(modelParams);
43 });
telsoa014fcda012018-03-09 14:13:49 +000044 });
surmeh013537c2c2018-05-18 16:31:43 +010045 }
46 catch (const std::exception& e)
47 {
48 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
49 // exception of type std::length_error.
50 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
51 std::cerr << "WARNING: CaffeYolo-Armnn: An error has occurred when running "
52 "the classifier inference tests: " << e.what() << std::endl;
53 }
54 return retVal;
telsoa014fcda012018-03-09 14:13:49 +000055}