blob: 9c39c30f1f15ae82f4b8fd74aec21743df895ad1 [file] [log] [blame]
Matthew Sloyana35b40b2021-02-05 17:22:28 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
Matthew Sloyana35b40b2021-02-05 17:22:28 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Matthew Sloyana35b40b2021-02-05 17:22:28 +000012
13#include <flatbuffers/flatbuffers.h>
Matthew Sloyana35b40b2021-02-05 17:22:28 +000014#include <tensorflow/lite/kernels/register.h>
Matthew Sloyana35b40b2021-02-05 17:22:28 +000015#include <tensorflow/lite/version.h>
16
17#include <doctest/doctest.h>
18
19namespace
20{
21
// Serialises a one-operator TFLite model containing either BATCH_TO_SPACE_ND or
// SPACE_TO_BATCH_ND, returning the finished flatbuffer as raw bytes.
//
// Parameters:
//   batchSpaceOperatorCode - BuiltinOperator_BATCH_TO_SPACE_ND or BuiltinOperator_SPACE_TO_BATCH_ND;
//                            selects the builtin options variant and the name of tensor 2.
//   tensorType             - element type of the input and output tensors.
//   inputTensorShape       - shape of tensor 0 (the data input).
//   outputTensorShape      - shape of tensor 3 (the output).
//   blockData              - block-shape values, serialised into buffer 2 as int32 data.
//   cropsPadData           - crops (BatchToSpace) or paddings (SpaceToBatch) pairs,
//                            serialised into buffer 3.
//   quantScale/quantOffset - quantization parameters shared by all four tensors.
std::vector<char> CreateBatchSpaceTfLiteModel(tflite::BuiltinOperator batchSpaceOperatorCode,
                                              tflite::TensorType tensorType,
                                              std::vector<int32_t>& inputTensorShape,
                                              std::vector <int32_t>& outputTensorShape,
                                              std::vector<unsigned int>& blockData,
                                              std::vector<std::pair<unsigned int, unsigned int>>& cropsPadData,
                                              float quantScale = 1.0f,
                                              int quantOffset = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

    // Buffer 0 is the conventional empty sentinel buffer; buffers 1 and 4 back the
    // input/output tensors (no constant data); buffers 2 and 3 carry the constant
    // block-shape and crops/paddings data respectively.
    std::array<flatbuffers::Offset<tflite::Buffer>, 5> buffers;
    buffers[0] = CreateBuffer(flatBufferBuilder);
    buffers[1] = CreateBuffer(flatBufferBuilder);
    buffers[2] = CreateBuffer(flatBufferBuilder,
                              flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(blockData.data()),
                                                             sizeof(int32_t) * blockData.size()));
    // NOTE(review): each std::pair<unsigned int, unsigned int> is copied byte-wise as
    // 8 bytes (sizeof(int64_t)), i.e. two consecutive int32 values per pair — this
    // relies on the pair having no padding, which holds on the supported platforms.
    buffers[3] = CreateBuffer(flatBufferBuilder,
                              flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cropsPadData.data()),
                                                             sizeof(int64_t) * cropsPadData.size()));
    buffers[4] = CreateBuffer(flatBufferBuilder);

    // One quantization description shared by every tensor in the model.
    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    // Tensor 2 is called "crops" for BatchToSpace and "padding" for SpaceToBatch.
    std::string cropsOrPadding =
        batchSpaceOperatorCode == tflite::BuiltinOperator_BATCH_TO_SPACE_ND ? "crops" : "padding";

    // Fixed shapes of the two constant inputs: block shape is a vector of 2,
    // crops/paddings is a 2x2 matrix (one (before, after) pair per spatial dim).
    std::vector<int32_t> blockShape { 2 };
    std::vector<int32_t> cropsOrPaddingShape { 2, 2 };

    // Tensor layout: 0 = data input, 1 = block shape, 2 = crops/paddings, 3 = output.
    // Each tensor's fourth argument is its buffer index (1..4, matching `buffers` above).
    std::array<flatbuffers::Offset<Tensor>, 4> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input"),
                              quantizationParameters);

    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(blockShape.data(),
                                                                      blockShape.size()),
                              ::tflite::TensorType_INT32,
                              2,
                              flatBufferBuilder.CreateString("block"),
                              quantizationParameters);

    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(cropsOrPaddingShape.data(),
                                                                      cropsOrPaddingShape.size()),
                              ::tflite::TensorType_INT32,
                              3,
                              flatBufferBuilder.CreateString(cropsOrPadding),
                              quantizationParameters);

    // Create output tensor
    tensors[3] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              tensorType,
                              4,
                              flatBufferBuilder.CreateString("output"),
                              quantizationParameters);

    // Create operator: pick the builtin-options variant that matches the operator code.
    // An unrecognised code leaves BuiltinOptions_NONE, producing an operator with no options.
    tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
    flatbuffers::Offset<void> operatorBuiltinOptions = 0;
    switch (batchSpaceOperatorCode)
    {
        case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
        {
            operatorBuiltinOptionsType = tflite::BuiltinOptions_BatchToSpaceNDOptions;
            operatorBuiltinOptions = CreateBatchToSpaceNDOptions(flatBufferBuilder).Union();
            break;
        }
        case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
        {
            operatorBuiltinOptionsType = tflite::BuiltinOptions_SpaceToBatchNDOptions;
            operatorBuiltinOptions = CreateSpaceToBatchNDOptions(flatBufferBuilder).Union();
            break;
        }
        default:
            break;
    }

    // The operator consumes tensors 0-2 and produces tensor 3; the subgraph exposes
    // the same tensors as its inputs/outputs.
    const std::vector<int> operatorInputs{ {0, 1, 2} };
    const std::vector<int> operatorOutputs{ 3 };
    flatbuffers::Offset <Operator> batchSpaceOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType,
                       operatorBuiltinOptions);

    const std::vector<int> subgraphInputs{ {0, 1, 2} };
    const std::vector<int> subgraphOutputs{ 3 };
    flatbuffers::Offset <SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&batchSpaceOperator, 1));

    flatbuffers::Offset <flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: BatchSpace Operator Model");
    flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, batchSpaceOperatorCode);

    flatbuffers::Offset <Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    // Finish with the TFLite file identifier so the buffer is recognised as a model.
    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    // Copy the finished flatbuffer out of the builder.
    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}
149
150template <typename T>
151void BatchSpaceTest(tflite::BuiltinOperator controlOperatorCode,
152 tflite::TensorType tensorType,
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000153 std::vector<int32_t>& inputShape,
154 std::vector<int32_t>& expectedOutputShape,
155 std::vector<T>& inputValues,
156 std::vector<unsigned int>& blockShapeValues,
157 std::vector<std::pair<unsigned int, unsigned int>>& cropsPaddingValues,
158 std::vector<T>& expectedOutputValues,
159 float quantScale = 1.0f,
Colm Donelaneff204a2023-11-28 15:46:09 +0000160 int quantOffset = 0,
161 const std::vector<armnn::BackendId>& backends = {})
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000162{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100163 using namespace delegateTestInterpreter;
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000164 std::vector<char> modelBuffer = CreateBatchSpaceTfLiteModel(controlOperatorCode,
165 tensorType,
166 inputShape,
167 expectedOutputShape,
168 blockShapeValues,
169 cropsPaddingValues,
170 quantScale,
171 quantOffset);
172
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100173 // Setup interpreter with just TFLite Runtime.
174 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
175 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
176 CHECK(tfLiteInterpreter.FillInputTensor(inputValues, 0) == kTfLiteOk);
177 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
178 std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
179 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000180
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100181 // Setup interpreter with Arm NN Delegate applied.
Colm Donelaneff204a2023-11-28 15:46:09 +0000182 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100183 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
184 CHECK(armnnInterpreter.FillInputTensor(inputValues, 0) == kTfLiteOk);
185 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
186 std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
187 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000188
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100189 armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
190 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000191
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100192 tfLiteInterpreter.Cleanup();
193 armnnInterpreter.Cleanup();
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000194}
195
196} // anonymous namespace