blob: 4988e745d0d249fb0cf9ce71f058adddb28199b4 [file] [log] [blame]
Nikhil Raj747f5862019-07-19 15:15:23 +01001//
Mike Kellya9c32672023-12-04 17:23:09 +00002// Copyright © 2019,2021,2023 Arm Ltd and Contributors. All rights reserved.
Nikhil Raj747f5862019-07-19 15:15:23 +01003// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <ResolveType.hpp>
8
9#include <armnn/INetwork.hpp>
10
Sadik Armagana097d2a2021-11-24 15:47:28 +000011#include <CommonTestUtils.hpp>
Nikhil Raj747f5862019-07-19 15:15:23 +010012
Sadik Armagan1625efc2021-06-10 18:24:34 +010013#include <doctest/doctest.h>
14
Nikhil Raj747f5862019-07-19 15:15:23 +010015namespace
16{
17template<typename armnn::DataType DataType>
18INetworkPtr CreatePreluNetwork(const armnn::TensorInfo& inputInfo,
19 const armnn::TensorInfo& alphaInfo,
20 const armnn::TensorInfo& outputInfo)
21{
22 using namespace armnn;
23
24 INetworkPtr net(INetwork::Create());
25
26 IConnectableLayer* input = net->AddInputLayer(0, "input");
27 IConnectableLayer* alpha = net->AddInputLayer(1, "alpha");
28 IConnectableLayer* prelu = net->AddPreluLayer("Prelu");
29 IConnectableLayer* output = net->AddOutputLayer(0, "output");
30
31 Connect(input, prelu, inputInfo, 0, 0);
32 Connect(alpha, prelu, alphaInfo, 0, 1);
33 Connect(prelu, output, outputInfo, 0, 0);
34
35 return net;
36}
37
38template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
39void PreluEndToEnd(const std::vector<BackendId>& backends,
40 const std::vector<T>& inputData,
41 const std::vector<T>& alphaData,
42 const std::vector<T>& expectedOutputData,
43 const float qScale ,
44 const int32_t qOffset)
45{
46 using namespace armnn;
47
48 armnn::TensorInfo inputInfo({ 2, 2, 2, 1 }, ArmnnType);
49 armnn::TensorInfo alphaInfo({ 1, 2, 2, 1 }, ArmnnType);
50 armnn::TensorInfo outputInfo({ 2, 2, 2, 1 }, ArmnnType);
51
52 inputInfo.SetQuantizationOffset(qOffset);
53 inputInfo.SetQuantizationScale(qScale);
Cathal Corbett5b8093c2021-10-22 11:12:07 +010054 inputInfo.SetConstant(true);
Nikhil Raj747f5862019-07-19 15:15:23 +010055 alphaInfo.SetQuantizationOffset(qOffset);
56 alphaInfo.SetQuantizationScale(qScale);
Cathal Corbett5b8093c2021-10-22 11:12:07 +010057 alphaInfo.SetConstant(true);
Nikhil Raj747f5862019-07-19 15:15:23 +010058 outputInfo.SetQuantizationOffset(qOffset);
59 outputInfo.SetQuantizationScale(qScale);
60
61 INetworkPtr net = CreatePreluNetwork<ArmnnType>(inputInfo, alphaInfo, outputInfo);
62
Sadik Armagan1625efc2021-06-10 18:24:34 +010063 CHECK(net);
Nikhil Raj747f5862019-07-19 15:15:23 +010064
65 std::map<int, std::vector<T>> inputTensorData = { { 0, inputData }, { 1, alphaData} };
66 std::map<int, std::vector<T>> expectedOutputTensorData = { { 0, expectedOutputData } };
67
Mike Kellya9c32672023-12-04 17:23:09 +000068 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net),
Nikhil Raj747f5862019-07-19 15:15:23 +010069 inputTensorData,
70 expectedOutputTensorData,
71 backends);
72}
73
74template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
75void PreluEndToEndPositiveTest(const std::vector<BackendId>& backends, const float qScale = 1.0f,
76 const int32_t qOffset = 2)
77{
78 std::vector<T> inputData{ 1, 2, 3, 4, 5, 6, 7, 8 };
79 std::vector<T> alphaData{ 2, 1, 1, 1 };
80
81 std::vector<T> expectedOutputData{ 2, 2, 3, 4, 5, 6, 7, 8 };
82
83 PreluEndToEnd<ArmnnType>(backends, inputData, alphaData, expectedOutputData, qScale, qOffset);
84}
85
86template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
87void PreluEndToEndNegativeTest(const std::vector<BackendId>& backends, const float qScale = 1.0f,
88 const int32_t qOffset = 0)
89{
90 std::vector<T> inputData{ 1, -2, 3, 4, 5, 6, 7, 8 };
91 std::vector<T> alphaData{ 1, 2, 1, 1 };
92
93 std::vector<T> expectedOutputData{ 1, -4, 3, 4, 5, 6, 7, 8 };
94
95 PreluEndToEnd<ArmnnType>(backends, inputData, alphaData, expectedOutputData, qScale, qOffset);
96}
97
98} // anonymous namespace