blob: fcee978deddf8d939b05cdec10f9e35a62e5d0c5 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "../InferenceTest.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "../ImagePreprocessor.hpp"
surmeh01bceff2f2018-03-29 16:29:27 +01007#include "armnnTfParser/ITfParser.hpp"
8
9int main(int argc, char* argv[])
10{
surmeh013537c2c2018-05-18 16:31:43 +010011 int retVal = EXIT_FAILURE;
12 try
surmeh01bceff2f2018-03-29 16:29:27 +010013 {
surmeh013537c2c2018-05-18 16:31:43 +010014 // Coverity fix: The following code may throw an exception of type std::length_error.
15 std::vector<ImageSet> imageSet =
16 {
17 { "Dog.jpg", 208 },
surmeh01c3b012e2018-09-12 16:00:26 +010018 // Top five predictions in tensorflow:
19 // -----------------------------------
20 // 208:golden retriever 0.57466376
21 // 209:Labrador retriever 0.30202731
22 // 853:tennis ball 0.0060001756
23 // 223:kuvasz 0.0053707925
24 // 160:Rhodesian ridgeback 0.0018179063
25
surmeh013537c2c2018-05-18 16:31:43 +010026 { "Cat.jpg", 283 },
surmeh01c3b012e2018-09-12 16:00:26 +010027 // Top five predictions in tensorflow:
28 // -----------------------------------
29 // 283:tiger cat 0.4667799
30 // 282:tabby, tabby cat 0.32511184
31 // 286:Egyptian cat 0.1038616
32 // 288:lynx, catamount 0.0017019814
33 // 284:Persian cat 0.0011340436
34
surmeh013537c2c2018-05-18 16:31:43 +010035 { "shark.jpg", 3 },
surmeh01c3b012e2018-09-12 16:00:26 +010036 // Top five predictions in tensorflow:
37 // -----------------------------------
38 // 3:great white shark, white shark, ... 0.98808634
39 // 148:grey whale, gray whale, ... 0.00070245547
40 // 234:Bouvier des Flandres, ... 0.00024639888
41 // 149:killer whale, killer, ... 0.00014115588
42 // 95:hummingbird 0.00011129203
surmeh013537c2c2018-05-18 16:31:43 +010043 };
44
45 armnn::TensorShape inputTensorShape({ 1, 299, 299, 3 });
46
telsoa01c577f2c2018-08-31 09:22:23 +010047 using DataType = float;
48 using DatabaseType = ImagePreprocessor<float>;
49 using ParserType = armnnTfParser::ITfParser;
50 using ModelType = InferenceModel<ParserType, DataType>;
51
surmeh013537c2c2018-05-18 16:31:43 +010052 // Coverity fix: InferenceTestMain() may throw uncaught exceptions.
telsoa01c577f2c2018-08-31 09:22:23 +010053 retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType, ParserType>(
surmeh01e82ef3f2018-09-13 10:23:42 +010054 argc, argv, "inception_v3_2016_08_28_frozen.pb", true,
surmeh013537c2c2018-05-18 16:31:43 +010055 "input", "InceptionV3/Predictions/Reshape_1", { 0, 1, 2, },
telsoa01c577f2c2018-08-31 09:22:23 +010056 [&imageSet](const char* dataDir, const ModelType&) {
57 return DatabaseType(dataDir, 299, 299, imageSet);
58 },
surmeh013537c2c2018-05-18 16:31:43 +010059 &inputTensorShape);
60 }
61 catch (const std::exception& e)
62 {
63 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
64 // exception of type std::length_error.
65 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
66 std::cerr << "WARNING: TfInceptionV3-Armnn: An error has occurred when running "
67 "the classifier inference tests: " << e.what() << std::endl;
68 }
69 return retVal;
surmeh01bceff2f2018-03-29 16:29:27 +010070}