blob: 2f2ae7bf40eec4ade336e2e4c4efa0ac8592c4bc [file] [log] [blame]
Matthew Sloyanc8eb9552020-11-26 10:54:22 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
Matthew Sloyanc8eb9552020-11-26 10:54:22 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
Teresa Charlinad1b3d72023-03-14 12:10:28 +000016#include <schema_generated.h>
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000017#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21namespace
22{
23
24std::vector<char> CreateLogicalBinaryTfLiteModel(tflite::BuiltinOperator logicalOperatorCode,
25 tflite::TensorType tensorType,
26 const std::vector <int32_t>& input0TensorShape,
27 const std::vector <int32_t>& input1TensorShape,
28 const std::vector <int32_t>& outputTensorShape,
29 float quantScale = 1.0f,
30 int quantOffset = 0)
31{
32 using namespace tflite;
33 flatbuffers::FlatBufferBuilder flatBufferBuilder;
34
35 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +000036 buffers.push_back(CreateBuffer(flatBufferBuilder));
37 buffers.push_back(CreateBuffer(flatBufferBuilder));
38 buffers.push_back(CreateBuffer(flatBufferBuilder));
39 buffers.push_back(CreateBuffer(flatBufferBuilder));
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000040
41 auto quantizationParameters =
42 CreateQuantizationParameters(flatBufferBuilder,
43 0,
44 0,
45 flatBufferBuilder.CreateVector<float>({ quantScale }),
46 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
47
48
49 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
50 tensors[0] = CreateTensor(flatBufferBuilder,
51 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
52 input0TensorShape.size()),
53 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000054 1,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000055 flatBufferBuilder.CreateString("input_0"),
56 quantizationParameters);
57 tensors[1] = CreateTensor(flatBufferBuilder,
58 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
59 input1TensorShape.size()),
60 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000061 2,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000062 flatBufferBuilder.CreateString("input_1"),
63 quantizationParameters);
64 tensors[2] = CreateTensor(flatBufferBuilder,
65 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
66 outputTensorShape.size()),
67 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000068 3,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000069 flatBufferBuilder.CreateString("output"),
70 quantizationParameters);
71
72 // create operator
73 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
74 flatbuffers::Offset<void> operatorBuiltinOptions = 0;
75 switch (logicalOperatorCode)
76 {
77 case BuiltinOperator_LOGICAL_AND:
78 {
79 operatorBuiltinOptionsType = BuiltinOptions_LogicalAndOptions;
80 operatorBuiltinOptions = CreateLogicalAndOptions(flatBufferBuilder).Union();
81 break;
82 }
83 case BuiltinOperator_LOGICAL_OR:
84 {
85 operatorBuiltinOptionsType = BuiltinOptions_LogicalOrOptions;
86 operatorBuiltinOptions = CreateLogicalOrOptions(flatBufferBuilder).Union();
87 break;
88 }
89 default:
90 break;
91 }
92 const std::vector<int32_t> operatorInputs{ {0, 1} };
93 const std::vector<int32_t> operatorOutputs{ 2 };
94 flatbuffers::Offset <Operator> logicalBinaryOperator =
95 CreateOperator(flatBufferBuilder,
96 0,
97 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
98 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
99 operatorBuiltinOptionsType,
100 operatorBuiltinOptions);
101
102 const std::vector<int> subgraphInputs{ {0, 1} };
103 const std::vector<int> subgraphOutputs{ 2 };
104 flatbuffers::Offset <SubGraph> subgraph =
105 CreateSubGraph(flatBufferBuilder,
106 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
107 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
108 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
109 flatBufferBuilder.CreateVector(&logicalBinaryOperator, 1));
110
111 flatbuffers::Offset <flatbuffers::String> modelDescription =
112 flatBufferBuilder.CreateString("ArmnnDelegate: Logical Binary Operator Model");
113 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, logicalOperatorCode);
114
115 flatbuffers::Offset <Model> flatbufferModel =
116 CreateModel(flatBufferBuilder,
117 TFLITE_SCHEMA_VERSION,
118 flatBufferBuilder.CreateVector(&operatorCode, 1),
119 flatBufferBuilder.CreateVector(&subgraph, 1),
120 modelDescription,
121 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
122
123 flatBufferBuilder.Finish(flatbufferModel);
124
125 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
126 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
127}
128
129template <typename T>
130void LogicalBinaryTest(tflite::BuiltinOperator logicalOperatorCode,
131 tflite::TensorType tensorType,
132 std::vector<armnn::BackendId>& backends,
133 std::vector<int32_t>& input0Shape,
134 std::vector<int32_t>& input1Shape,
135 std::vector<int32_t>& expectedOutputShape,
136 std::vector<T>& input0Values,
137 std::vector<T>& input1Values,
138 std::vector<T>& expectedOutputValues,
139 float quantScale = 1.0f,
140 int quantOffset = 0)
141{
142 using namespace tflite;
143 std::vector<char> modelBuffer = CreateLogicalBinaryTfLiteModel(logicalOperatorCode,
144 tensorType,
145 input0Shape,
146 input1Shape,
147 expectedOutputShape,
148 quantScale,
149 quantOffset);
150
151 const Model* tfLiteModel = GetModel(modelBuffer.data());
152 // Create TfLite Interpreters
153 std::unique_ptr<Interpreter> armnnDelegateInterpreter;
154 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
155 (&armnnDelegateInterpreter) == kTfLiteOk);
156 CHECK(armnnDelegateInterpreter != nullptr);
157 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
158
159 std::unique_ptr<Interpreter> tfLiteInterpreter;
160 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
161 (&tfLiteInterpreter) == kTfLiteOk);
162 CHECK(tfLiteInterpreter != nullptr);
163 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
164
165 // Create the ArmNN Delegate
166 armnnDelegate::DelegateOptions delegateOptions(backends);
167 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
168 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
169 armnnDelegate::TfLiteArmnnDelegateDelete);
170 CHECK(theArmnnDelegate != nullptr);
171 // Modify armnnDelegateInterpreter to use armnnDelegate
172 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
173
174 // Set input data for the armnn interpreter
175 armnnDelegate::FillInput(armnnDelegateInterpreter, 0, input0Values);
176 armnnDelegate::FillInput(armnnDelegateInterpreter, 1, input1Values);
177
178 // Set input data for the tflite interpreter
179 armnnDelegate::FillInput(tfLiteInterpreter, 0, input0Values);
180 armnnDelegate::FillInput(tfLiteInterpreter, 1, input1Values);
181
182 // Run EnqueWorkload
183 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
184 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
185
186 // Compare output data, comparing Boolean values is handled differently and needs to call the CompareData function
187 // directly. This is because Boolean types get converted to a bit representation in a vector.
188 auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
189 auto tfLiteDelegateOutputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateOutputId);
190 auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
191 auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateOutputId);
192
193 armnnDelegate::CompareData(expectedOutputValues, armnnDelegateOutputData, expectedOutputValues.size());
194 armnnDelegate::CompareData(expectedOutputValues, tfLiteDelegateOutputData, expectedOutputValues.size());
195 armnnDelegate::CompareData(tfLiteDelegateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
196
197 armnnDelegateInterpreter.reset(nullptr);
198 tfLiteInterpreter.reset(nullptr);
199}
200
201} // anonymous namespace