blob: bd862d244584fe0e0b063767a709676d47ddf7db [file] [log] [blame]
//
// Copyright © 2020, 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000012
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000013#include <tensorflow/lite/version.h>
14
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000015namespace
16{
17
18std::vector<char> CreateLogicalBinaryTfLiteModel(tflite::BuiltinOperator logicalOperatorCode,
19 tflite::TensorType tensorType,
20 const std::vector <int32_t>& input0TensorShape,
21 const std::vector <int32_t>& input1TensorShape,
22 const std::vector <int32_t>& outputTensorShape,
23 float quantScale = 1.0f,
24 int quantOffset = 0)
25{
26 using namespace tflite;
27 flatbuffers::FlatBufferBuilder flatBufferBuilder;
28
29 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +000030 buffers.push_back(CreateBuffer(flatBufferBuilder));
31 buffers.push_back(CreateBuffer(flatBufferBuilder));
32 buffers.push_back(CreateBuffer(flatBufferBuilder));
33 buffers.push_back(CreateBuffer(flatBufferBuilder));
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000034
35 auto quantizationParameters =
36 CreateQuantizationParameters(flatBufferBuilder,
37 0,
38 0,
39 flatBufferBuilder.CreateVector<float>({ quantScale }),
40 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
41
42
43 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
44 tensors[0] = CreateTensor(flatBufferBuilder,
45 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
46 input0TensorShape.size()),
47 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000048 1,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000049 flatBufferBuilder.CreateString("input_0"),
50 quantizationParameters);
51 tensors[1] = CreateTensor(flatBufferBuilder,
52 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
53 input1TensorShape.size()),
54 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000055 2,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000056 flatBufferBuilder.CreateString("input_1"),
57 quantizationParameters);
58 tensors[2] = CreateTensor(flatBufferBuilder,
59 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
60 outputTensorShape.size()),
61 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000062 3,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000063 flatBufferBuilder.CreateString("output"),
64 quantizationParameters);
65
66 // create operator
67 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
68 flatbuffers::Offset<void> operatorBuiltinOptions = 0;
69 switch (logicalOperatorCode)
70 {
71 case BuiltinOperator_LOGICAL_AND:
72 {
73 operatorBuiltinOptionsType = BuiltinOptions_LogicalAndOptions;
74 operatorBuiltinOptions = CreateLogicalAndOptions(flatBufferBuilder).Union();
75 break;
76 }
77 case BuiltinOperator_LOGICAL_OR:
78 {
79 operatorBuiltinOptionsType = BuiltinOptions_LogicalOrOptions;
80 operatorBuiltinOptions = CreateLogicalOrOptions(flatBufferBuilder).Union();
81 break;
82 }
83 default:
84 break;
85 }
86 const std::vector<int32_t> operatorInputs{ {0, 1} };
87 const std::vector<int32_t> operatorOutputs{ 2 };
88 flatbuffers::Offset <Operator> logicalBinaryOperator =
89 CreateOperator(flatBufferBuilder,
90 0,
91 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
92 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
93 operatorBuiltinOptionsType,
94 operatorBuiltinOptions);
95
96 const std::vector<int> subgraphInputs{ {0, 1} };
97 const std::vector<int> subgraphOutputs{ 2 };
98 flatbuffers::Offset <SubGraph> subgraph =
99 CreateSubGraph(flatBufferBuilder,
100 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
101 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
102 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
103 flatBufferBuilder.CreateVector(&logicalBinaryOperator, 1));
104
105 flatbuffers::Offset <flatbuffers::String> modelDescription =
106 flatBufferBuilder.CreateString("ArmnnDelegate: Logical Binary Operator Model");
107 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, logicalOperatorCode);
108
109 flatbuffers::Offset <Model> flatbufferModel =
110 CreateModel(flatBufferBuilder,
111 TFLITE_SCHEMA_VERSION,
112 flatBufferBuilder.CreateVector(&operatorCode, 1),
113 flatBufferBuilder.CreateVector(&subgraph, 1),
114 modelDescription,
115 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
116
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100117 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000118
119 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
120 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
121}
122
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000123void LogicalBinaryTest(tflite::BuiltinOperator logicalOperatorCode,
124 tflite::TensorType tensorType,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000125 std::vector<int32_t>& input0Shape,
126 std::vector<int32_t>& input1Shape,
127 std::vector<int32_t>& expectedOutputShape,
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100128 std::vector<bool>& input0Values,
129 std::vector<bool>& input1Values,
130 std::vector<bool>& expectedOutputValues,
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000131 float quantScale = 1.0f,
Colm Donelaneff204a2023-11-28 15:46:09 +0000132 int quantOffset = 0,
133 const std::vector<armnn::BackendId>& backends = {})
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000134{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100135 using namespace delegateTestInterpreter;
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000136 std::vector<char> modelBuffer = CreateLogicalBinaryTfLiteModel(logicalOperatorCode,
137 tensorType,
138 input0Shape,
139 input1Shape,
140 expectedOutputShape,
141 quantScale,
142 quantOffset);
143
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100144 // Setup interpreter with just TFLite Runtime.
145 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
146 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
147 CHECK(tfLiteInterpreter.FillInputTensor(input0Values, 0) == kTfLiteOk);
148 CHECK(tfLiteInterpreter.FillInputTensor(input1Values, 1) == kTfLiteOk);
149 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
150 std::vector<bool> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult(0);
151 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000152
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100153 // Setup interpreter with Arm NN Delegate applied.
Colm Donelaneff204a2023-11-28 15:46:09 +0000154 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100155 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
156 CHECK(armnnInterpreter.FillInputTensor(input0Values, 0) == kTfLiteOk);
157 CHECK(armnnInterpreter.FillInputTensor(input1Values, 1) == kTfLiteOk);
158 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
159 std::vector<bool> armnnOutputValues = armnnInterpreter.GetOutputResult(0);
160 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000161
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100162 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000163
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100164 armnnDelegate::CompareData(expectedOutputValues, armnnOutputValues, expectedOutputValues.size());
165 armnnDelegate::CompareData(expectedOutputValues, tfLiteOutputValues, expectedOutputValues.size());
166 armnnDelegate::CompareData(tfLiteOutputValues, armnnOutputValues, expectedOutputValues.size());
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000167
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100168 tfLiteInterpreter.Cleanup();
169 armnnInterpreter.Cleanup();
Matthew Sloyanc8eb9552020-11-26 10:54:22 +0000170}
171
172} // anonymous namespace