//
// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

namespace
{

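// Builds an in-memory TfLite flatbuffer model containing a single comparison
// operator with two input tensors of the given type and shapes and one
// boolean output tensor. quantScale/quantOffset are attached to both input
// tensors for use with quantized tensor types.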
std::vector<char> CreateComparisonTfLiteModel(tflite::BuiltinOperator comparisonOperatorCode,
                                              tflite::TensorType tensorType,
                                              const std::vector<int32_t>& input0TensorShape,
                                              const std::vector<int32_t>& input1TensorShape,
                                              const std::vector<int32_t>& outputTensorShape,
                                              float quantScale = 1.0f,
                                              int quantOffset = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));

    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    std::array<flatbuffers::Offset<Tensor>, 3> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
                                                                      input0TensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input_0"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
                                                                      input1TensorShape.size()),
                              tensorType,
                              2,
                              flatBufferBuilder.CreateString("input_1"),
                              quantizationParameters);
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              ::tflite::TensorType_BOOL,
                              3);

    // Create the operator. Default to Equal; the switch below selects the
    // builtin options matching the requested comparison operator.
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_EqualOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions = CreateEqualOptions(flatBufferBuilder).Union();
    switch (comparisonOperatorCode)
    {
        case BuiltinOperator_EQUAL:
        {
            operatorBuiltinOptionsType = BuiltinOptions_EqualOptions;
            operatorBuiltinOptions = CreateEqualOptions(flatBufferBuilder).Union();
            break;
        }
        case BuiltinOperator_NOT_EQUAL:
        {
            operatorBuiltinOptionsType = BuiltinOptions_NotEqualOptions;
            operatorBuiltinOptions = CreateNotEqualOptions(flatBufferBuilder).Union();
            break;
        }
        case BuiltinOperator_GREATER:
        {
            operatorBuiltinOptionsType = BuiltinOptions_GreaterOptions;
            operatorBuiltinOptions = CreateGreaterOptions(flatBufferBuilder).Union();
            break;
        }
        case BuiltinOperator_GREATER_EQUAL:
        {
            operatorBuiltinOptionsType = BuiltinOptions_GreaterEqualOptions;
            operatorBuiltinOptions = CreateGreaterEqualOptions(flatBufferBuilder).Union();
            break;
        }
        case BuiltinOperator_LESS:
        {
            operatorBuiltinOptionsType = BuiltinOptions_LessOptions;
            operatorBuiltinOptions = CreateLessOptions(flatBufferBuilder).Union();
            break;
        }
        case BuiltinOperator_LESS_EQUAL:
        {
            operatorBuiltinOptionsType = BuiltinOptions_LessEqualOptions;
            operatorBuiltinOptions = CreateLessEqualOptions(flatBufferBuilder).Union();
            break;
        }
        default:
            break;
    }
    const std::vector<int32_t> operatorInputs{0, 1};
    const std::vector<int32_t> operatorOutputs{2};
    flatbuffers::Offset<Operator> comparisonOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType,
                       operatorBuiltinOptions);

    const std::vector<int> subgraphInputs{0, 1};
    const std::vector<int> subgraphOutputs{2};
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&comparisonOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: Comparison Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, comparisonOperatorCode);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

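// Runs the generated comparison model twice - once on the plain TfLite
// runtime and once with the Arm NN delegate applied - and checks that both
// runs succeed and produce matching output values and shapes.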
template <typename T>
void ComparisonTest(tflite::BuiltinOperator comparisonOperatorCode,
                    tflite::TensorType tensorType,
                    std::vector<int32_t>& input0Shape,
                    std::vector<int32_t>& input1Shape,
                    std::vector<int32_t>& outputShape,
                    std::vector<T>& input0Values,
                    std::vector<T>& input1Values,
                    std::vector<bool>& expectedOutputValues,
                    float quantScale = 1.0f,
                    int quantOffset = 0,
                    const std::vector<armnn::BackendId>& backends = {})
{
    using namespace delegateTestInterpreter;
    std::vector<char> modelBuffer = CreateComparisonTfLiteModel(comparisonOperatorCode,
                                                                tensorType,
                                                                input0Shape,
                                                                input1Shape,
                                                                outputShape,
                                                                quantScale,
                                                                quantOffset);

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(input0Values, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(input1Values, 1) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<bool> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult(0);
    std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(input0Values, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(input1Values, 1) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<bool> armnnOutputValues = armnnInterpreter.GetOutputResult(0);
    std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

    armnnDelegate::CompareData(expectedOutputValues, armnnOutputValues, expectedOutputValues.size());
    armnnDelegate::CompareData(expectedOutputValues, tfLiteOutputValues, expectedOutputValues.size());
    armnnDelegate::CompareData(tfLiteOutputValues, armnnOutputValues, expectedOutputValues.size());

    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
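
// Example usage (a minimal sketch, not part of the shipped test suite; the
// test case name and the input/expected values below are illustrative
// assumptions):
//
//     TEST_CASE("EqualFP32ExampleTest")
//     {
//         std::vector<int32_t> shape    { 2, 2 };
//         std::vector<float>   input0   { 1.0f, 2.0f, 3.0f, 4.0f };
//         std::vector<float>   input1   { 1.0f, 0.0f, 3.0f, 5.0f };
//         std::vector<bool>    expected { true, false, true, false };
//
//         ComparisonTest<float>(tflite::BuiltinOperator_EQUAL,
//                               ::tflite::TensorType_FLOAT32,
//                               shape, shape, shape,
//                               input0, input1, expected);
//     }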

} // anonymous namespace