blob: 3b796a0c212fb842b2dc316e23e92e785ffae731 [file] [log] [blame]
//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <CommonTestUtils.hpp>
9
10#include <armnn/INetwork.hpp>
11#include <ResolveType.hpp>
12
13#include <doctest/doctest.h>
14
15using namespace armnn;
16
17namespace
18{
19
20template<DataType ArmnnType, typename T = ResolveType<ArmnnType>>
21INetworkPtr CreateScatterNdNetwork(const TensorInfo& shapeInfo,
22 const TensorInfo& indicesInfo,
23 const TensorInfo& updatesInfo,
24 const TensorInfo& outputInfo,
25 const std::vector<int32_t>& indicesData,
26 const std::vector<T>& updatesData,
27 const ScatterNdDescriptor& descriptor)
28{
29 INetworkPtr net(INetwork::Create());
30
31 IConnectableLayer* shapeLayer = net->AddInputLayer(0);
32 IConnectableLayer* indicesLayer = net->AddConstantLayer(ConstTensor(indicesInfo, indicesData));
33 IConnectableLayer* updatesLayer = net->AddConstantLayer(ConstTensor(updatesInfo, updatesData));
34 IConnectableLayer* scatterNdLayer = net->AddScatterNdLayer(descriptor, "scatterNd");
35 IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
36 Connect(shapeLayer, scatterNdLayer, shapeInfo, 0, 0);
37 Connect(indicesLayer, scatterNdLayer, indicesInfo, 0, 1);
38 Connect(updatesLayer, scatterNdLayer, updatesInfo, 0, 2);
39 Connect(scatterNdLayer, outputLayer, outputInfo, 0, 0);
40
41 return net;
42}
43
44template<DataType ArmnnType, typename T = ResolveType<ArmnnType>>
45void ScatterNd1DimUpdateWithInputEndToEnd(const std::vector<BackendId>& backends)
46{
47 float_t qScale = 1.f;
48 int32_t qOffset = 0;
49
50 TensorInfo inputInfo({ 5 }, ArmnnType, qScale, qOffset, true);
51 TensorInfo indicesInfo({ 3, 1 }, DataType::Signed32, 1.0f, 0, true);
52 TensorInfo updatesInfo({ 3 }, ArmnnType, qScale, qOffset, true);
53 TensorInfo outputInfo({ 5 }, ArmnnType, qScale, qOffset, false);
54
55 std::vector<T> inputData = armnnUtils::QuantizedVector<T>({ 0, 0, 0, 0, 0 }, qScale, qOffset);
56 std::vector<int32_t> indicesData{0, 1, 2};
57 std::vector<T> updatesData = armnnUtils::QuantizedVector<T>({ 1, 2, 3 }, qScale, qOffset);
58 std::vector<T> expectedOutput = armnnUtils::QuantizedVector<T>({ 1, 2, 3, 0, 0 }, qScale, qOffset);
59
60 armnn::ScatterNdDescriptor descriptor(armnn::ScatterNdFunction::Update, true);
61
62 INetworkPtr net = CreateScatterNdNetwork<ArmnnType>(inputInfo, indicesInfo, updatesInfo, outputInfo,
63 indicesData, updatesData, descriptor);
64
65 CHECK(net);
66
67 std::map<int, std::vector<T>> inputTensorData = {{ 0, inputData }};
68 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
69
70 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
71}
72
73template<DataType ArmnnType, typename T = ResolveType<ArmnnType>>
74void ScatterNd1DimUpdateNoInputEndToEnd(const std::vector<BackendId>& backends)
75{
76 float_t qScale = 1.f;
77 int32_t qOffset = 0;
78
79 TensorInfo shapeInfo({ 1 }, DataType::Signed32, 1.0f, 0, true);
80 TensorInfo indicesInfo({ 3, 1 }, DataType::Signed32, 1.0f, 0, true);
81 TensorInfo updatesInfo({ 3 }, ArmnnType, qScale, qOffset, true);
82 TensorInfo outputInfo({ 5 }, ArmnnType, qScale, qOffset, false);
83
84 std::vector<int32_t> shapeData{ 5 };
85 std::vector<int32_t> indicesData{ 0, 1, 2 };
86 std::vector<T> updatesData = armnnUtils::QuantizedVector<T>({ 1, 2, 3 }, qScale, qOffset);
87 std::vector<T> expectedOutput = armnnUtils::QuantizedVector<T>({ 1, 2, 3, 0, 0 }, qScale, qOffset);
88
89 armnn::ScatterNdDescriptor descriptor(armnn::ScatterNdFunction::Update, false);
90
91 INetworkPtr net = CreateScatterNdNetwork<ArmnnType>(shapeInfo, indicesInfo, updatesInfo, outputInfo,
92 indicesData, updatesData, descriptor);
93
94 CHECK(net);
95
96 std::map<int, std::vector<int32_t>> inputTensorData = {{ 0, shapeData }};
97 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
98
99 EndToEndLayerTestImpl<DataType::Signed32, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
100}
101
102template<DataType ArmnnType, typename T = ResolveType<ArmnnType>>
103void ScatterNd2DimUpdateWithInputEndToEnd(const std::vector<BackendId>& backends)
104{
105 float_t qScale = 1.f;
106 int32_t qOffset = 0;
107
108 TensorInfo inputInfo({ 3, 3 }, ArmnnType, qScale, qOffset, true);
109 TensorInfo indicesInfo({ 3, 2 }, DataType::Signed32, 1.0f, 0, true);
110 TensorInfo updatesInfo({ 3 }, ArmnnType, qScale, qOffset, true);
111 TensorInfo outputInfo({ 3, 3 }, ArmnnType, qScale, qOffset, false);
112
113 std::vector<T> inputData = armnnUtils::QuantizedVector<T>({ 1, 1, 1, 1, 1, 1, 1, 1, 1 }, qScale, qOffset);
114 std::vector<int32_t> indicesData{0, 0, 1, 1, 2, 2};
115 std::vector<T> updatesData = armnnUtils::QuantizedVector<T>({ 1, 2, 3 }, qScale, qOffset);
116 std::vector<T> expectedOutput = armnnUtils::QuantizedVector<T>({ 1, 1, 1, 1, 2, 1, 1, 1, 3 }, qScale, qOffset);
117
118 armnn::ScatterNdDescriptor descriptor(armnn::ScatterNdFunction::Update, true);
119
120 INetworkPtr net = CreateScatterNdNetwork<ArmnnType>(inputInfo, indicesInfo, updatesInfo, outputInfo,
121 indicesData, updatesData, descriptor);
122
123 CHECK(net);
124
125 std::map<int, std::vector<T>> inputTensorData = {{ 0, inputData }};
126 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
127
128 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
129}
130
131template<DataType ArmnnType, typename T = ResolveType<ArmnnType>>
132void ScatterNd2DimUpdateNoInputEndToEnd(const std::vector<BackendId>& backends)
133{
134 float_t qScale = 1.f;
135 int32_t qOffset = 0;
136
137 TensorInfo shapeInfo({ 2 }, DataType::Signed32, 1.0f, 0, true);
138 TensorInfo indicesInfo({ 3, 2 }, DataType::Signed32, 1.0f, 0, true);
139 TensorInfo updatesInfo({ 3 }, ArmnnType, qScale, qOffset, true);
140 TensorInfo outputInfo({ 3, 3 }, ArmnnType, qScale, qOffset, false);
141
142 std::vector<int32_t> shapeData{ 3, 3 };
143 std::vector<int32_t> indicesData{0, 0, 1, 1, 2, 2};
144 std::vector<T> updatesData = armnnUtils::QuantizedVector<T>({ 1, 2, 3 }, qScale, qOffset);
145 std::vector<T> expectedOutput = armnnUtils::QuantizedVector<T>({ 1, 0, 0, 0, 2, 0, 0, 0, 3 }, qScale, qOffset);
146
147 armnn::ScatterNdDescriptor descriptor(armnn::ScatterNdFunction::Update, false);
148
149 INetworkPtr net = CreateScatterNdNetwork<ArmnnType>(shapeInfo, indicesInfo, updatesInfo, outputInfo,
150 indicesData, updatesData, descriptor);
151
152 CHECK(net);
153
154 std::map<int, std::vector<int32_t>> inputTensorData = {{ 0, shapeData }};
155 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
156
157 EndToEndLayerTestImpl<DataType::Signed32, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
158}
159
160} // anonymous namespace