blob: 1d6cf1d99bc0df6c75ea59890eff93da1fb20642 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include "TypeUtils.hpp"

#include <armnn/INetwork.hpp>

#include <backendsCommon/test/CommonTestUtils.hpp>

#include <boost/test/unit_test.hpp>

#include <map>
#include <utility>
#include <vector>
16
17namespace
18{
19
20template<typename armnn::DataType DataType>
21INetworkPtr CreateArithmeticNetwork(const std::vector<TensorShape>& inputShapes,
22 const TensorShape& outputShape,
23 const LayerType type,
24 const float qScale = 1.0f,
25 const int32_t qOffset = 0)
26{
27 using namespace armnn;
28
29 // Builds up the structure of the network.
30 INetworkPtr net(INetwork::Create());
31
32 IConnectableLayer* arithmeticLayer = nullptr;
33
34 switch(type){
35 case LayerType::Equal: arithmeticLayer = net->AddEqualLayer("equal"); break;
36 case LayerType::Greater: arithmeticLayer = net->AddGreaterLayer("greater"); break;
37 default: BOOST_TEST_FAIL("Non-Arithmetic layer type called.");
38 }
39
40 for (unsigned int i = 0; i < inputShapes.size(); ++i)
41 {
42 TensorInfo inputTensorInfo(inputShapes[i], DataType, qScale, qOffset);
43 IConnectableLayer* input = net->AddInputLayer(boost::numeric_cast<LayerBindingId>(i));
44 Connect(input, arithmeticLayer, inputTensorInfo, 0, i);
45 }
46
47 TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
48 IConnectableLayer* output = net->AddOutputLayer(0, "output");
49 Connect(arithmeticLayer, output, outputTensorInfo, 0, 0);
50
51 return net;
52}
53
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000054template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000055void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
56 const LayerType type,
57 const std::vector<T> expectedOutput)
58{
59 using namespace armnn;
60
61 const std::vector<TensorShape> inputShapes{{ 2, 2, 2, 2 }, { 2, 2, 2, 2 }};
62 const TensorShape& outputShape = { 2, 2, 2, 2 };
63
64 // Builds up the structure of the network
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000065 INetworkPtr net = CreateArithmeticNetwork<ArmnnType>(inputShapes, outputShape, type);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000066
67 BOOST_TEST_CHECKPOINT("create a network");
68
69 const std::vector<T> input0({ 1, 1, 1, 1, 5, 5, 5, 5,
70 3, 3, 3, 3, 4, 4, 4, 4 });
71
72 const std::vector<T> input1({ 1, 1, 1, 1, 3, 3, 3, 3,
73 5, 5, 5, 5, 4, 4, 4, 4 });
74
75 std::map<int, std::vector<T>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
76 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
77
78 EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
79}
80
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000081template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000082void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
83 const LayerType type,
84 const std::vector<T> expectedOutput)
85{
86 using namespace armnn;
87
88 const std::vector<TensorShape> inputShapes{{ 1, 2, 2, 3 }, { 1, 1, 1, 3 }};
89 const TensorShape& outputShape = { 1, 2, 2, 3 };
90
91 // Builds up the structure of the network
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000092 INetworkPtr net = CreateArithmeticNetwork<ArmnnType>(inputShapes, outputShape, type);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000093
94 BOOST_TEST_CHECKPOINT("create a network");
95
96 const std::vector<T> input0({ 1, 2, 3, 1, 0, 6,
97 7, 8, 9, 10, 11, 12 });
98
99 const std::vector<T> input1({ 1, 1, 3 });
100
101 std::map<int, std::vector<T>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
102 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
103
104 EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
105}
106
107} // anonymous namespace