/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18
19#pragma once
20
21#include "tensorflow/lite/schema/schema_generated.h"
22
23#include <stdlib.h>
24#include <string>
25
26namespace InferenceProcess {
27
/**
 * Non-owning, fixed-capacity array view over caller-provided storage.
 *
 * The element counter is a caller-owned variable held by reference, so every
 * element appended through push_back() is immediately reflected in that
 * variable.
 *
 * @tparam T Element type stored in the backing buffer.
 * @tparam U Integer type of the caller-owned element counter.
 */
template <typename T, typename U>
class Array {
public:
    Array() = delete;

    /**
     * @param data     Pointer to the first element of the backing buffer.
     * @param size     Caller-owned counter of elements currently stored.
     * @param capacity Maximum number of elements the backing buffer can hold.
     */
    Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {}

    /** @return The current element count (a copy of the caller's counter). */
    auto size() const {
        return _size;
    }

    /** @return The maximum number of elements the array can hold. */
    auto capacity() const {
        return _capacity;
    }

    /**
     * Append a copy of @p data and increment the element counter.
     *
     * If the array is already full (or has no backing buffer) the element is
     * dropped and the counter is left unchanged, so the caller's buffer is
     * never written out of bounds.
     *
     * @param data Element to append.
     */
    void push_back(const T &data) {
        // Casting to size_t also rejects a negative counter when U is signed.
        if (_data == nullptr || static_cast<size_t>(_size) >= _capacity) {
            printf("Warning: array capacity exceeded, element dropped.\n");
            return;
        }

        _data[_size++] = data;
    }

private:
    T *const _data;
    U &_size;
    const size_t _capacity{};
};
51
52template <typename T, typename U>
53Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
54 return Array<T, U>{data, size, capacity};
55}
56
57class InferenceParser {
58public:
Davide Grohmann30b17b92022-06-14 15:17:18 +020059 const tflite::Model *getModel(const void *buffer, size_t size) {
60 // Verify buffer
61 flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t *>(buffer), size);
62 if (!tflite::VerifyModelBuffer(base_verifier)) {
63 printf("Warning: the model is not valid\n");
64 return nullptr;
65 }
66
Davide Grohmann37fd8a32022-04-07 15:02:12 +020067 // Create model handle
68 const tflite::Model *model = tflite::GetModel(buffer);
69 if (model->subgraphs() == nullptr) {
70 printf("Warning: nullptr subgraph\n");
Davide Grohmann30b17b92022-06-14 15:17:18 +020071 return nullptr;
Davide Grohmann37fd8a32022-04-07 15:02:12 +020072 }
73
Davide Grohmann30b17b92022-06-14 15:17:18 +020074 return model;
75 }
76
77 template <typename T, typename U, size_t S>
78 bool parseModel(const void *buffer, size_t size, char (&description)[S], T &&ifmDims, U &&ofmDims) {
79 const tflite::Model *model = getModel(buffer, size);
80 if (model == nullptr) {
81 return true;
82 }
Davide Grohmann37fd8a32022-04-07 15:02:12 +020083 strncpy(description, model->description()->c_str(), sizeof(description));
84
85 // Get input dimensions for first subgraph
86 auto *subgraph = *model->subgraphs()->begin();
87 bool failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims);
88 if (failed) {
89 return true;
90 }
91
92 // Get output dimensions for last subgraph
93 subgraph = *model->subgraphs()->rbegin();
94 failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims);
95 if (failed) {
96 return true;
97 }
98
99 return false;
100 }
101
102private:
103 bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) {
104 size = 1;
105
106 if (shape == nullptr) {
107 printf("Warning: nullptr shape size.\n");
108 return true;
109 }
110
111 if (shape->Length() == 0) {
112 printf("Warning: shape zero length.\n");
113 return true;
114 }
115
116 for (auto it = shape->begin(); it != shape->end(); ++it) {
117 size *= *it;
118 }
119
120 return false;
121 }
122
123 bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) {
124 switch (type) {
125 case tflite::TensorType::TensorType_UINT8:
126 case tflite::TensorType::TensorType_INT8:
127 size = 1;
128 break;
129 case tflite::TensorType::TensorType_INT16:
130 size = 2;
131 break;
132 case tflite::TensorType::TensorType_INT32:
133 case tflite::TensorType::TensorType_FLOAT32:
134 size = 4;
135 break;
136 default:
137 printf("Warning: Unsupported tensor type\n");
138 return true;
139 }
140
141 return false;
142 }
143
144 template <typename T>
145 bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) {
146 if (subgraph == nullptr || tensorMap == nullptr) {
147 printf("Warning: nullptr subgraph or tensormap.\n");
148 return true;
149 }
150
151 if ((dims.capacity() - dims.size()) < tensorMap->size()) {
152 printf("Warning: tensormap size is larger than dimension capacity.\n");
153 return true;
154 }
155
156 for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
157 auto tensor = subgraph->tensors()->Get(*index);
158 size_t size;
159 size_t tensorSize;
160
161 bool failed = getShapeSize(tensor->shape(), size);
162 if (failed) {
163 return true;
164 }
165
166 failed = getTensorTypeSize(tensor->type(), tensorSize);
167 if (failed) {
168 return true;
169 }
170
171 size *= tensorSize;
172
173 if (size > 0) {
174 dims.push_back(size);
175 }
176 }
177
178 return false;
179 }
180};
181
182} // namespace InferenceProcess