blob: 1251648b47b43299984b2800ecbe128e7880ee09 [file] [log] [blame]
Davide Grohmann37fd8a32022-04-07 15:02:12 +02001/*
Jonny Svärd8788ab32023-04-27 18:05:34 +02002 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Davide Grohmann37fd8a32022-04-07 15:02:12 +02003 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#pragma once
20
#include "tensorflow/lite/schema/schema_generated.h"

#include <cstdio>
#include <cstring>
#include <limits>
#include <stdlib.h>
#include <string>
25
26namespace InferenceProcess {
27
/**
 * Non-owning, fixed-capacity view over a caller-provided buffer.
 *
 * The element counter lives outside the object: @p size is held by
 * reference, so every successful push_back() is immediately visible through
 * the caller's counter variable.
 *
 * @tparam T element type stored in the buffer
 * @tparam U type of the external element counter
 */
template <typename T, typename U>
class Array {
public:
    Array() = delete;

    /**
     * @param data     buffer able to hold @p capacity elements (not owned)
     * @param size     external element counter; also the next write index
     * @param capacity maximum number of elements @p data can hold
     */
    Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {}

    /// Current number of stored elements (the value of the external counter).
    auto size() const {
        return _size;
    }

    /// Maximum number of elements the buffer can hold.
    auto capacity() const {
        return _capacity;
    }

    /**
     * Append an element if there is room for it.
     *
     * @param data element to copy into the next free slot
     * @return true if the element was stored; false if the array was already
     *         full, in which case neither the buffer nor the counter changes.
     */
    bool push_back(const T &data) {
        // Compare in size_t so a signed counter type does not produce a
        // signed/unsigned comparison; a negative counter converts to a huge
        // value and is therefore treated as "full" rather than as an index.
        if (static_cast<size_t>(_size) >= _capacity) {
            return false;
        }

        _data[_size++] = data;
        return true;
    }

private:
    T *const _data;
    U &_size;
    const size_t _capacity{};
};
51
52template <typename T, typename U>
53Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
54 return Array<T, U>{data, size, capacity};
55}
56
57class InferenceParser {
58public:
Davide Grohmann30b17b92022-06-14 15:17:18 +020059 const tflite::Model *getModel(const void *buffer, size_t size) {
60 // Verify buffer
61 flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t *>(buffer), size);
62 if (!tflite::VerifyModelBuffer(base_verifier)) {
63 printf("Warning: the model is not valid\n");
64 return nullptr;
65 }
66
Davide Grohmann37fd8a32022-04-07 15:02:12 +020067 // Create model handle
68 const tflite::Model *model = tflite::GetModel(buffer);
69 if (model->subgraphs() == nullptr) {
70 printf("Warning: nullptr subgraph\n");
Davide Grohmann30b17b92022-06-14 15:17:18 +020071 return nullptr;
Davide Grohmann37fd8a32022-04-07 15:02:12 +020072 }
73
Davide Grohmann30b17b92022-06-14 15:17:18 +020074 return model;
75 }
76
77 template <typename T, typename U, size_t S>
78 bool parseModel(const void *buffer, size_t size, char (&description)[S], T &&ifmDims, U &&ofmDims) {
79 const tflite::Model *model = getModel(buffer, size);
80 if (model == nullptr) {
81 return true;
82 }
Mikael Olsson74c514a2023-08-07 17:42:18 +020083
84 // Depending on the src string, strncpy may not add a null-terminator
85 // so one is manually added at the end.
86 strncpy(description, model->description()->c_str(), S - 1);
87 description[S - 1] = '\0';
Davide Grohmann37fd8a32022-04-07 15:02:12 +020088
89 // Get input dimensions for first subgraph
90 auto *subgraph = *model->subgraphs()->begin();
91 bool failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims);
92 if (failed) {
93 return true;
94 }
95
96 // Get output dimensions for last subgraph
97 subgraph = *model->subgraphs()->rbegin();
98 failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims);
99 if (failed) {
100 return true;
101 }
102
103 return false;
104 }
105
106private:
107 bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) {
108 size = 1;
109
110 if (shape == nullptr) {
111 printf("Warning: nullptr shape size.\n");
112 return true;
113 }
114
Jonny Svärd8788ab32023-04-27 18:05:34 +0200115 if (shape->size() == 0) {
116 printf("Warning: shape zero size.\n");
Davide Grohmann37fd8a32022-04-07 15:02:12 +0200117 return true;
118 }
119
120 for (auto it = shape->begin(); it != shape->end(); ++it) {
121 size *= *it;
122 }
123
124 return false;
125 }
126
127 bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) {
128 switch (type) {
129 case tflite::TensorType::TensorType_UINT8:
130 case tflite::TensorType::TensorType_INT8:
131 size = 1;
132 break;
133 case tflite::TensorType::TensorType_INT16:
134 size = 2;
135 break;
136 case tflite::TensorType::TensorType_INT32:
137 case tflite::TensorType::TensorType_FLOAT32:
138 size = 4;
139 break;
140 default:
141 printf("Warning: Unsupported tensor type\n");
142 return true;
143 }
144
145 return false;
146 }
147
148 template <typename T>
149 bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) {
150 if (subgraph == nullptr || tensorMap == nullptr) {
151 printf("Warning: nullptr subgraph or tensormap.\n");
152 return true;
153 }
154
155 if ((dims.capacity() - dims.size()) < tensorMap->size()) {
156 printf("Warning: tensormap size is larger than dimension capacity.\n");
157 return true;
158 }
159
160 for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
161 auto tensor = subgraph->tensors()->Get(*index);
162 size_t size;
163 size_t tensorSize;
164
165 bool failed = getShapeSize(tensor->shape(), size);
166 if (failed) {
167 return true;
168 }
169
170 failed = getTensorTypeSize(tensor->type(), tensorSize);
171 if (failed) {
172 return true;
173 }
174
175 size *= tensorSize;
176
177 if (size > 0) {
178 dims.push_back(size);
179 }
180 }
181
182 return false;
183 }
184};
185
186} // namespace InferenceProcess