blob: 29dafd36b469b2fd2a75d07fd91e390d64baed7d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <backendsCommon/test/CommonTestUtils.hpp>

#include <boost/numeric/conversion/cast.hpp>
#include <boost/test/unit_test.hpp>

#include <cstdint>
#include <map>
#include <utility>
#include <vector>
16
17namespace
18{
19
kevmay012b4d88e2019-01-24 14:05:09 +000020template<armnn::DataType ArmnnTypeInput, armnn::DataType ArmnnTypeOutput>
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000021INetworkPtr CreateArithmeticNetwork(const std::vector<TensorShape>& inputShapes,
22 const TensorShape& outputShape,
23 const LayerType type,
24 const float qScale = 1.0f,
25 const int32_t qOffset = 0)
26{
27 using namespace armnn;
28
29 // Builds up the structure of the network.
30 INetworkPtr net(INetwork::Create());
31
32 IConnectableLayer* arithmeticLayer = nullptr;
33
34 switch(type){
35 case LayerType::Equal: arithmeticLayer = net->AddEqualLayer("equal"); break;
36 case LayerType::Greater: arithmeticLayer = net->AddGreaterLayer("greater"); break;
37 default: BOOST_TEST_FAIL("Non-Arithmetic layer type called.");
38 }
39
40 for (unsigned int i = 0; i < inputShapes.size(); ++i)
41 {
kevmay012b4d88e2019-01-24 14:05:09 +000042 TensorInfo inputTensorInfo(inputShapes[i], ArmnnTypeInput, qScale, qOffset);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000043 IConnectableLayer* input = net->AddInputLayer(boost::numeric_cast<LayerBindingId>(i));
44 Connect(input, arithmeticLayer, inputTensorInfo, 0, i);
45 }
46
kevmay012b4d88e2019-01-24 14:05:09 +000047 TensorInfo outputTensorInfo(outputShape, ArmnnTypeOutput, qScale, qOffset);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000048 IConnectableLayer* output = net->AddOutputLayer(0, "output");
49 Connect(arithmeticLayer, output, outputTensorInfo, 0, 0);
50
51 return net;
52}
53
kevmay012b4d88e2019-01-24 14:05:09 +000054template<armnn::DataType ArmnnInputType,
55 armnn::DataType ArmnnOutputType,
56 typename TInput = armnn::ResolveType<ArmnnInputType>,
57 typename TOutput = armnn::ResolveType<ArmnnOutputType>>
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000058void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
59 const LayerType type,
kevmay012b4d88e2019-01-24 14:05:09 +000060 const std::vector<TOutput> expectedOutput)
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000061{
62 using namespace armnn;
63
64 const std::vector<TensorShape> inputShapes{{ 2, 2, 2, 2 }, { 2, 2, 2, 2 }};
65 const TensorShape& outputShape = { 2, 2, 2, 2 };
66
67 // Builds up the structure of the network
kevmay012b4d88e2019-01-24 14:05:09 +000068 INetworkPtr net = CreateArithmeticNetwork<ArmnnInputType, ArmnnOutputType>(inputShapes, outputShape, type);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000069
70 BOOST_TEST_CHECKPOINT("create a network");
71
kevmay012b4d88e2019-01-24 14:05:09 +000072 const std::vector<TInput> input0({ 1, 1, 1, 1, 5, 5, 5, 5,
73 3, 3, 3, 3, 4, 4, 4, 4 });
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000074
kevmay012b4d88e2019-01-24 14:05:09 +000075 const std::vector<TInput> input1({ 1, 1, 1, 1, 3, 3, 3, 3,
76 5, 5, 5, 5, 4, 4, 4, 4 });
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000077
kevmay012b4d88e2019-01-24 14:05:09 +000078 std::map<int, std::vector<TInput>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
79 std::map<int, std::vector<TOutput>> expectedOutputData = {{ 0, expectedOutput }};
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000080
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +000081 EndToEndLayerTestImpl<ArmnnInputType, ArmnnOutputType>(move(net), inputTensorData, expectedOutputData, backends);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000082}
83
kevmay012b4d88e2019-01-24 14:05:09 +000084template<armnn::DataType ArmnnInputType,
85 armnn::DataType ArmnnOutputType,
86 typename TInput = armnn::ResolveType<ArmnnInputType>,
87 typename TOutput = armnn::ResolveType<ArmnnOutputType>>
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000088void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
89 const LayerType type,
kevmay012b4d88e2019-01-24 14:05:09 +000090 const std::vector<TOutput> expectedOutput)
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000091{
92 using namespace armnn;
93
94 const std::vector<TensorShape> inputShapes{{ 1, 2, 2, 3 }, { 1, 1, 1, 3 }};
95 const TensorShape& outputShape = { 1, 2, 2, 3 };
96
97 // Builds up the structure of the network
kevmay012b4d88e2019-01-24 14:05:09 +000098 INetworkPtr net = CreateArithmeticNetwork<ArmnnInputType, ArmnnOutputType>(inputShapes, outputShape, type);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +000099
100 BOOST_TEST_CHECKPOINT("create a network");
101
kevmay012b4d88e2019-01-24 14:05:09 +0000102 const std::vector<TInput> input0({ 1, 2, 3, 1, 0, 6,
103 7, 8, 9, 10, 11, 12 });
FrancisMurtagh2262bbd2018-12-20 16:09:45 +0000104
kevmay012b4d88e2019-01-24 14:05:09 +0000105 const std::vector<TInput> input1({ 1, 1, 3 });
FrancisMurtagh2262bbd2018-12-20 16:09:45 +0000106
kevmay012b4d88e2019-01-24 14:05:09 +0000107 std::map<int, std::vector<TInput>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
108 std::map<int, std::vector<TOutput>> expectedOutputData = {{ 0, expectedOutput }};
FrancisMurtagh2262bbd2018-12-20 16:09:45 +0000109
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000110 EndToEndLayerTestImpl<ArmnnInputType, ArmnnOutputType>(move(net), inputTensorData, expectedOutputData, backends);
FrancisMurtagh2262bbd2018-12-20 16:09:45 +0000111}
112
113} // anonymous namespace