blob: 4ea745628ceb43e4ba654d9834b42a84bf1bea62 [file] [log] [blame]
Sadik Armaganac97c8c2019-03-04 17:44:21 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersSerializeFixture.hpp"
8#include "../Deserializer.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(Deserializer)
14
15struct MeanFixture : public ParserFlatbuffersSerializeFixture
16{
17 explicit MeanFixture(const std::string &inputShape,
18 const std::string &outputShape,
19 const std::string &axis,
20 const std::string &dataType)
21 {
22 m_JsonString = R"(
23 {
24 inputIds: [0],
25 outputIds: [2],
26 layers: [
27 {
28 layer_type: "InputLayer",
29 layer: {
30 base: {
31 layerBindingId: 0,
32 base: {
33 index: 0,
34 layerName: "InputLayer",
35 layerType: "Input",
36 inputSlots: [{
37 index: 0,
38 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
39 }],
40 outputSlots: [{
41 index: 0,
42 tensorInfo: {
43 dimensions: )" + inputShape + R"(,
44 dataType: )" + dataType + R"(
45 }
46 }]
47 }
48 }
49 }
50 },
51 {
52 layer_type: "MeanLayer",
53 layer: {
54 base: {
55 index: 1,
56 layerName: "MeanLayer",
57 layerType: "Mean",
58 inputSlots: [{
59 index: 0,
60 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
61 }],
62 outputSlots: [{
63 index: 0,
64 tensorInfo: {
65 dimensions: )" + outputShape + R"(,
66 dataType: )" + dataType + R"(
67 }
68 }]
69 },
70 descriptor: {
71 axis: )" + axis + R"(,
72 keepDims: true
73 }
74 }
75 },
76 {
77 layer_type: "OutputLayer",
78 layer: {
79 base:{
80 layerBindingId: 2,
81 base: {
82 index: 2,
83 layerName: "OutputLayer",
84 layerType: "Output",
85 inputSlots: [{
86 index: 0,
87 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
88 }],
89 outputSlots: [{
90 index: 0,
91 tensorInfo: {
92 dimensions: )" + outputShape + R"(,
93 dataType: )" + dataType + R"(
94 },
95 }],
96 }
97 }
98 },
99 }
100 ]
101 }
102 )";
103 Setup();
104 }
105};
106
107struct SimpleMeanFixture : MeanFixture
108{
109 SimpleMeanFixture()
110 : MeanFixture("[ 1, 1, 3, 2 ]", // inputShape
111 "[ 1, 1, 1, 2 ]", // outputShape
112 "[ 2 ]", // axis
113 "Float32") // dataType
114 {}
115};
116
117BOOST_FIXTURE_TEST_CASE(SimpleMean, SimpleMeanFixture)
118{
119 RunTest<4, armnn::DataType::Float32>(
120 0,
121 {{"InputLayer", { 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f }}},
122 {{"OutputLayer", { 2.0f, 2.0f }}});
123}
124
125BOOST_AUTO_TEST_SUITE_END()