blob: 6c8218374190bff7947d6d408a895994c74f5de2 [file] [log] [blame]
Sadik Armaganac97c8c2019-03-04 17:44:21 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Sadik Armaganac97c8c2019-03-04 17:44:21 +00006#include "ParserFlatbuffersSerializeFixture.hpp"
Finn Williams85d36712021-01-26 22:30:06 +00007#include <armnnDeserializer/IDeserializer.hpp>
Sadik Armaganac97c8c2019-03-04 17:44:21 +00008
9#include <string>
Sadik Armaganac97c8c2019-03-04 17:44:21 +000010
Sadik Armagan1625efc2021-06-10 18:24:34 +010011TEST_SUITE("Deserializer_Mean")
12{
Sadik Armaganac97c8c2019-03-04 17:44:21 +000013struct MeanFixture : public ParserFlatbuffersSerializeFixture
14{
15 explicit MeanFixture(const std::string &inputShape,
16 const std::string &outputShape,
17 const std::string &axis,
18 const std::string &dataType)
19 {
20 m_JsonString = R"(
21 {
22 inputIds: [0],
23 outputIds: [2],
24 layers: [
25 {
26 layer_type: "InputLayer",
27 layer: {
28 base: {
29 layerBindingId: 0,
30 base: {
31 index: 0,
32 layerName: "InputLayer",
33 layerType: "Input",
34 inputSlots: [{
35 index: 0,
36 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
37 }],
38 outputSlots: [{
39 index: 0,
40 tensorInfo: {
41 dimensions: )" + inputShape + R"(,
42 dataType: )" + dataType + R"(
43 }
44 }]
45 }
46 }
47 }
48 },
49 {
50 layer_type: "MeanLayer",
51 layer: {
52 base: {
53 index: 1,
54 layerName: "MeanLayer",
55 layerType: "Mean",
56 inputSlots: [{
57 index: 0,
58 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
59 }],
60 outputSlots: [{
61 index: 0,
62 tensorInfo: {
63 dimensions: )" + outputShape + R"(,
64 dataType: )" + dataType + R"(
65 }
66 }]
67 },
68 descriptor: {
69 axis: )" + axis + R"(,
70 keepDims: true
71 }
72 }
73 },
74 {
75 layer_type: "OutputLayer",
76 layer: {
77 base:{
78 layerBindingId: 2,
79 base: {
80 index: 2,
81 layerName: "OutputLayer",
82 layerType: "Output",
83 inputSlots: [{
84 index: 0,
85 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
86 }],
87 outputSlots: [{
88 index: 0,
89 tensorInfo: {
90 dimensions: )" + outputShape + R"(,
91 dataType: )" + dataType + R"(
92 },
93 }],
94 }
95 }
96 },
97 }
98 ]
99 }
100 )";
101 Setup();
102 }
103};
104
105struct SimpleMeanFixture : MeanFixture
106{
107 SimpleMeanFixture()
108 : MeanFixture("[ 1, 1, 3, 2 ]", // inputShape
109 "[ 1, 1, 1, 2 ]", // outputShape
110 "[ 2 ]", // axis
111 "Float32") // dataType
112 {}
113};
114
Sadik Armagan1625efc2021-06-10 18:24:34 +0100115TEST_CASE_FIXTURE(SimpleMeanFixture, "SimpleMean")
Sadik Armaganac97c8c2019-03-04 17:44:21 +0000116{
117 RunTest<4, armnn::DataType::Float32>(
118 0,
119 {{"InputLayer", { 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f }}},
120 {{"OutputLayer", { 2.0f, 2.0f }}});
121}
122
Sadik Armagan1625efc2021-06-10 18:24:34 +0100123}