blob: f4692cfb07903ef561680936f76e0869318de76b [file] [log] [blame]
Idriss Chaouchcbf79292023-09-08 11:18:16 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BroadcastToTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9#include <flatbuffers/flatbuffers.h>
10#include <tensorflow/lite/interpreter.h>
11#include <tensorflow/lite/kernels/register.h>
12#include <tensorflow/lite/model.h>
13#include <schema_generated.h>
14#include <tensorflow/lite/version.h>
15#include <doctest/doctest.h>
16
17namespace armnnDelegate
18{
19template<typename T>
20void BroadcastToTest(std::vector<armnn::BackendId> &backends, tflite::TensorType inputTensorType)
21{
22 // Set input data
23 std::vector<T> inputValues = {
24 0, 1, 2, 3
25 };
26 // Set output data
27 std::vector<T> expectedOutputValues = {
28 0, 1, 2, 3,
29 0, 1, 2, 3,
30 0, 1, 2, 3
31 };
32
33 // The shape data
34 const std::vector<int32_t> shapeData = {3, 4};
35
36 // Set shapes
37 const std::vector<int32_t> inputShape = {1, 4};
38 const std::vector<int32_t> shapeShape = {2};
39 const std::vector<int32_t> expectedOutputShape = {3, 4};
40
41 BroadcastToTestImpl<T>(inputTensorType,
42 tflite::BuiltinOperator_BROADCAST_TO,
43 backends,
44 inputValues,
45 inputShape,
46 shapeShape,
47 shapeData,
48 expectedOutputValues,
49 expectedOutputShape);
50}
51
52TEST_SUITE("BroadcastToTests_CpuRefTests")
53{
54
55 TEST_CASE ("BroadcastTo_int_CpuRef_Test")
56 {
57 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
58 BroadcastToTest<int32_t>(backends, ::tflite::TensorType::TensorType_INT32);
59 }
60
61 TEST_CASE ("BroadcastTo_Float32_CpuRef_Test")
62 {
63 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
64 BroadcastToTest<float>(backends, ::tflite::TensorType::TensorType_FLOAT32);
65 }
66
67 TEST_CASE ("BroadcastTo_Uint8_t_CpuRef_Test")
68 {
69 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
70 BroadcastToTest<uint8_t>(backends, ::tflite::TensorType::TensorType_UINT8);
71 }
72
73 TEST_CASE ("BroadcastTo_Int8_t_CpuRef_Test")
74 {
75 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
76 BroadcastToTest<int8_t>(backends, ::tflite::TensorType::TensorType_INT8);
77 }
78
79} // TEST_SUITE("BroadcastToTests_CpuRefTests")
80}