blob: 99f1a29af03ed74d1ab828a6895196c36346902b [file] [log] [blame]
Idriss Chaouchcbf79292023-09-08 11:18:16 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BroadcastToTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9#include <flatbuffers/flatbuffers.h>
10#include <tensorflow/lite/interpreter.h>
11#include <tensorflow/lite/kernels/register.h>
12#include <tensorflow/lite/model.h>
Idriss Chaouchcbf79292023-09-08 11:18:16 +010013#include <tensorflow/lite/version.h>
14#include <doctest/doctest.h>
15
16namespace armnnDelegate
17{
18template<typename T>
19void BroadcastToTest(std::vector<armnn::BackendId> &backends, tflite::TensorType inputTensorType)
20{
21 // Set input data
22 std::vector<T> inputValues = {
23 0, 1, 2, 3
24 };
25 // Set output data
26 std::vector<T> expectedOutputValues = {
27 0, 1, 2, 3,
28 0, 1, 2, 3,
29 0, 1, 2, 3
30 };
31
32 // The shape data
33 const std::vector<int32_t> shapeData = {3, 4};
34
35 // Set shapes
36 const std::vector<int32_t> inputShape = {1, 4};
37 const std::vector<int32_t> shapeShape = {2};
38 const std::vector<int32_t> expectedOutputShape = {3, 4};
39
40 BroadcastToTestImpl<T>(inputTensorType,
41 tflite::BuiltinOperator_BROADCAST_TO,
Idriss Chaouchcbf79292023-09-08 11:18:16 +010042 inputValues,
43 inputShape,
44 shapeShape,
45 shapeData,
46 expectedOutputValues,
Colm Donelaneff204a2023-11-28 15:46:09 +000047 expectedOutputShape,
48 backends);
Idriss Chaouchcbf79292023-09-08 11:18:16 +010049}
50
Colm Donelaneff204a2023-11-28 15:46:09 +000051TEST_SUITE("BroadcastToTests_Tests")
Idriss Chaouchcbf79292023-09-08 11:18:16 +010052{
53
Colm Donelaneff204a2023-11-28 15:46:09 +000054 /**
55 * Only CpuRef is supported for these tests.
56 */
57 TEST_CASE ("BroadcastTo_int_Test")
Idriss Chaouchcbf79292023-09-08 11:18:16 +010058 {
59 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
60 BroadcastToTest<int32_t>(backends, ::tflite::TensorType::TensorType_INT32);
61 }
62
Colm Donelaneff204a2023-11-28 15:46:09 +000063 TEST_CASE ("BroadcastTo_Float32_Test")
Idriss Chaouchcbf79292023-09-08 11:18:16 +010064 {
65 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
66 BroadcastToTest<float>(backends, ::tflite::TensorType::TensorType_FLOAT32);
67 }
68
Colm Donelaneff204a2023-11-28 15:46:09 +000069 TEST_CASE ("BroadcastTo_Uint8_t_Test")
Idriss Chaouchcbf79292023-09-08 11:18:16 +010070 {
71 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
72 BroadcastToTest<uint8_t>(backends, ::tflite::TensorType::TensorType_UINT8);
73 }
74
Colm Donelaneff204a2023-11-28 15:46:09 +000075 TEST_CASE ("BroadcastTo_Int8_t_Test")
Idriss Chaouchcbf79292023-09-08 11:18:16 +010076 {
77 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
78 BroadcastToTest<int8_t>(backends, ::tflite::TensorType::TensorType_INT8);
79 }
80
81} // TEST_SUITE("BroadcastToTests_CpuRefTests")
82}