blob: 6dd015173c0ee7dbf43c326a70af829331360212 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "GatherTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9
10#include <flatbuffers/flatbuffers.h>
11#include <tensorflow/lite/schema/schema_generated.h>
12
13#include <doctest/doctest.h>
14
15namespace armnnDelegate
16{
17
// GATHER Operator
19void GatherUint8Test(std::vector<armnn::BackendId>& backends)
20{
21
22 std::vector<int32_t> paramsShape{8};
23 std::vector<int32_t> indicesShape{3};
24 std::vector<int32_t> expectedOutputShape{3};
25
26 int32_t axis = 0;
27 std::vector<uint8_t> paramsValues{1, 2, 3, 4, 5, 6, 7, 8};
28 std::vector<int32_t> indicesValues{7, 6, 5};
29 std::vector<uint8_t> expectedOutputValues{8, 7, 6};
30
31 GatherTest<uint8_t>(::tflite::TensorType_UINT8,
32 backends,
33 paramsShape,
34 indicesShape,
35 expectedOutputShape,
36 axis,
37 paramsValues,
38 indicesValues,
39 expectedOutputValues);
40}
41
42void GatherFp32Test(std::vector<armnn::BackendId>& backends)
43{
44 std::vector<int32_t> paramsShape{8};
45 std::vector<int32_t> indicesShape{3};
46 std::vector<int32_t> expectedOutputShape{3};
47
48 int32_t axis = 0;
49 std::vector<float> paramsValues{1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f};
50 std::vector<int32_t> indicesValues{7, 6, 5};
51 std::vector<float> expectedOutputValues{8.8f, 7.7f, 6.6f};
52
53 GatherTest<float>(::tflite::TensorType_FLOAT32,
54 backends,
55 paramsShape,
56 indicesShape,
57 expectedOutputShape,
58 axis,
59 paramsValues,
60 indicesValues,
61 expectedOutputValues);
62}
63
// GATHER Test Suite
65TEST_SUITE("GATHER_CpuRefTests")
66{
67
68TEST_CASE ("GATHER_Uint8_CpuRef_Test")
69{
70 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
71 GatherUint8Test(backends);
72}
73
74TEST_CASE ("GATHER_Fp32_CpuRef_Test")
75{
76 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
77 GatherFp32Test(backends);
78}
79
80}
81
82TEST_SUITE("GATHER_CpuAccTests")
83{
84
85TEST_CASE ("GATHER_Uint8_CpuAcc_Test")
86{
87 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
88 GatherUint8Test(backends);
89}
90
91TEST_CASE ("GATHER_Fp32_CpuAcc_Test")
92{
93 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
94 GatherFp32Test(backends);
95}
96
97}
98
99TEST_SUITE("GATHER_GpuAccTests")
100{
101
102TEST_CASE ("GATHER_Uint8_GpuAcc_Test")
103{
104 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
105 GatherUint8Test(backends);
106}
107
108TEST_CASE ("GATHER_Fp32_GpuAcc_Test")
109{
110 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
111 GatherFp32Test(backends);
112}
113
114}
// End of GATHER Test Suite
116
117} // namespace armnnDelegate