//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

namespace
{
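    // Builds a self-contained FlatBuffer model holding a single BROADCAST_TO operator
    // with one input tensor and one constant INT32 shape tensor.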
    std::vector<char> CreateBroadcastToTfLiteModel(tflite::BuiltinOperator operatorCode,
                                                   tflite::TensorType inputTensorType,
                                                   const std::vector<int32_t>& inputTensorShape,
                                                   const std::vector<int32_t>& shapeTensorShape,
                                                   const std::vector<int32_t>& shapeTensorData,
                                                   const std::vector<int32_t>& outputTensorShape)
    {
        using namespace tflite;
        flatbuffers::FlatBufferBuilder flatBufferBuilder;

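        // Buffer 0 is the conventional empty buffer; buffers 1-3 back the input,
        // shape and output tensors respectively. Only the shape tensor carries
        // constant data.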
        std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
        buffers.push_back(CreateBuffer(flatBufferBuilder));
        buffers.push_back(CreateBuffer(flatBufferBuilder));
        buffers.push_back(CreateBuffer(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector(
                                               reinterpret_cast<const uint8_t*>(shapeTensorData.data()),
                                               sizeof(int32_t) * shapeTensorData.size())));
        buffers.push_back(CreateBuffer(flatBufferBuilder));

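        // Identity quantization (scale 1, offset 0) is attached to every tensor so the
        // same model builder serves both float and quantized test cases.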
        float qScale = 1.0f;
        int32_t qOffset = 0;

        auto quantizationParameters =
                CreateQuantizationParameters(flatBufferBuilder,
                                             0,
                                             0,
                                             flatBufferBuilder.CreateVector<float>({ qScale }),
                                             flatBufferBuilder.CreateVector<int64_t>({ qOffset }));

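        // Tensor 0: input, tensor 1: constant INT32 shape, tensor 2: output.
        // Buffer indices are offset by one because buffer 0 is the empty buffer.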
        std::array<flatbuffers::Offset<Tensor>, 3> tensors;
        tensors[0] = CreateTensor(flatBufferBuilder,
                                  flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                          inputTensorShape.size()),
                                  inputTensorType,
                                  1,
                                  flatBufferBuilder.CreateString("input_tensor"),
                                  quantizationParameters);

        tensors[1] = CreateTensor(flatBufferBuilder,
                                  flatBufferBuilder.CreateVector<int32_t>(shapeTensorShape.data(),
                                                                          shapeTensorShape.size()),
                                  TensorType_INT32,
                                  2,
                                  flatBufferBuilder.CreateString("shape_input_tensor"),
                                  quantizationParameters);

        tensors[2] = CreateTensor(flatBufferBuilder,
                                  flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                          outputTensorShape.size()),
                                  inputTensorType,
                                  3,
                                  flatBufferBuilder.CreateString("output_tensor"),
                                  quantizationParameters);

        // Create Operator
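        // BroadcastToOptions defines no fields, so the builtin options offset stays null.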
        tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_BroadcastToOptions;
        flatbuffers::Offset<void> operatorBuiltinOption = 0;

        const std::vector<int> operatorInputs {0, 1};
        const std::vector<int> operatorOutputs {2};

        flatbuffers::Offset<Operator> broadcastOperator =
                CreateOperator(flatBufferBuilder,
                               0,
                               flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                               flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                               operatorBuiltinOptionsType,
                               operatorBuiltinOption);

        const std::vector<int> subgraphInputs{0, 1};
        const std::vector<int> subgraphOutputs{2};
        flatbuffers::Offset<SubGraph> subgraph =
                CreateSubGraph(flatBufferBuilder,
                               flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                               flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                               flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                               flatBufferBuilder.CreateVector(&broadcastOperator, 1));

        flatbuffers::Offset<flatbuffers::String> modelDescription =
                flatBufferBuilder.CreateString("ArmnnDelegate: BroadcastTo Operator Model");
        flatbuffers::Offset<OperatorCode> opCode = CreateOperatorCode(flatBufferBuilder,
                                                                      0,
                                                                      0,
                                                                      2,
                                                                      operatorCode);

        flatbuffers::Offset<Model> flatbufferModel =
                CreateModel(flatBufferBuilder,
                            TFLITE_SCHEMA_VERSION,
                            flatBufferBuilder.CreateVector(&opCode, 1),
                            flatBufferBuilder.CreateVector(&subgraph, 1),
                            modelDescription,
                            flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

        flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

        return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                                 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
    }

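    // Runs the generated model on the plain TFLite runtime and again with the Arm NN
    // delegate applied, then checks that both runs agree with each other and with the
    // expected reference values and shape.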
    template<typename T>
    void BroadcastToTestImpl(tflite::TensorType inputTensorType,
                             tflite::BuiltinOperator operatorCode,
                             std::vector<T>& inputValues,
                             std::vector<int32_t> inputShape,
                             std::vector<int32_t> shapeShapes,
                             std::vector<int32_t> shapeData,
                             std::vector<T>& expectedOutputValues,
                             std::vector<int32_t> expectedOutputShape,
                             const std::vector<armnn::BackendId>& backends)
    {
        using namespace delegateTestInterpreter;

        std::vector<char> modelBuffer = CreateBroadcastToTfLiteModel(operatorCode,
                                                                     inputTensorType,
                                                                     inputShape,
                                                                     shapeShapes,
                                                                     shapeData,
                                                                     expectedOutputShape);

        // Setup interpreter with just TFLite Runtime.
        auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
        CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
        CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
        CHECK(tfLiteInterpreter.FillInputTensor<int32_t>(shapeData, 1) == kTfLiteOk);
        CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
        std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
        std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

        // Setup interpreter with Arm NN Delegate applied.
        auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
        CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
        CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
        CHECK(armnnInterpreter.FillInputTensor<int32_t>(shapeData, 1) == kTfLiteOk);
        CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
        std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
        std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

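        // Compare both runs against each other and against the expected reference output.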
        armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
        armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);

        tfLiteInterpreter.Cleanup();
        armnnInterpreter.Cleanup();
    }

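    // Example call (illustrative only; the parameter values here are hypothetical and
    // the real test cases live alongside this helper): broadcast a {1, 3} float input
    // to a {2, 3} output.
    //
    //     std::vector<float> input    { 1.f, 2.f, 3.f };
    //     std::vector<float> expected { 1.f, 2.f, 3.f,
    //                                   1.f, 2.f, 3.f };
    //     BroadcastToTestImpl<float>(tflite::TensorType_FLOAT32,
    //                                tflite::BuiltinOperator_BROADCAST_TO,
    //                                input, { 1, 3 }, { 2 }, { 2, 3 },
    //                                expected, { 2, 3 }, { armnn::Compute::CpuRef });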
} // anonymous namespace