blob: 27907f1df332e6ffc5d0d0419d7b16cbe0688d93 [file] [log] [blame]
Teresa Charlin43baf502021-09-27 10:10:39 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <CommonTestUtils.hpp>

#include <armnn/INetwork.hpp>
#include <ResolveType.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>
14
15namespace{
16
17armnn::INetworkPtr CreateChannelShuffleNetwork(const armnn::TensorInfo& inputInfo,
18 const armnn::TensorInfo& outputInfo,
19 const armnn::ChannelShuffleDescriptor& descriptor)
20{
21 armnn::INetworkPtr net(armnn::INetwork::Create());
22
23 armnn::IConnectableLayer* inputLayer = net->AddInputLayer(0);
24 armnn::IConnectableLayer* channelShuffleLayer = net->AddChannelShuffleLayer(descriptor, "channelShuffle");
25 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
26 Connect(inputLayer, channelShuffleLayer, inputInfo, 0, 0);
27 Connect(channelShuffleLayer, outputLayer, outputInfo, 0, 0);
28
29 return net;
30}
31
32template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
33void ChannelShuffleEndToEnd(const std::vector<BackendId>& backends)
34{
35 armnn::TensorInfo inputInfo({ 3,12 }, ArmnnType);
36 armnn::TensorInfo outputInfo({ 3,12 }, ArmnnType);
37
38 inputInfo.SetQuantizationScale(1.0f);
39 inputInfo.SetQuantizationOffset(0);
Cathal Corbett5b8093c2021-10-22 11:12:07 +010040 inputInfo.SetConstant(true);
Teresa Charlin43baf502021-09-27 10:10:39 +010041 outputInfo.SetQuantizationScale(1.0f);
42 outputInfo.SetQuantizationOffset(0);
43
44 // Creates structures for input & output.
45 std::vector<T> inputData{
46 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
47 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
48 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35
49 };
50
51 std::vector<T> expectedOutput{
52 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11,
53 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23,
54 24, 28, 32, 25, 29, 33, 26, 30, 34, 27, 31, 35
55 };
56 ChannelShuffleDescriptor descriptor;
57 descriptor.m_Axis = 1;
58 descriptor.m_NumGroups = 3;
59
60 // Builds up the structure of the network
61 armnn::INetworkPtr net = CreateChannelShuffleNetwork(inputInfo, outputInfo, descriptor);
62
63 CHECK(net);
64
65 std::map<int, std::vector<T>> inputTensorData = {{ 0, inputData }};
66 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
67
68 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
69}
70
71} // anonymous namespace