/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorflow/lite/schema/schema_generated.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>

namespace InferenceProcess {

// Minimal fixed-capacity array view. Wraps an externally owned buffer and an
// externally owned size variable, so pushed elements and the element count
// remain visible to the caller after the Array goes out of scope.
template <typename T, typename U>
class Array {
public:
    Array() = delete;
    Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {}

    auto size() const {
        return _size;
    }

    auto capacity() const {
        return _capacity;
    }

    // Note: no bounds check. The caller is expected to compare size() against
    // capacity() before pushing, as getSubGraphDims() does below.
    void push_back(const T &data) {
        _data[_size++] = data;
    }

private:
    T *const _data;
    U &_size;
    const size_t _capacity{};
};

// Helper that deduces the template arguments for Array.
template <typename T, typename U>
Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
    return Array<T, U>{data, size, capacity};
}
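
// Example usage (illustrative sketch; `dims` and `dimsCount` are hypothetical
// caller-side names, not part of this API):
//
//     uint32_t dims[4];
//     uint32_t dimsCount = 0;
//     auto dimsArray = makeArray(dims, dimsCount, 4);
//     dimsArray.push_back(784); // dims[0] == 784, dimsCount == 1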

// Parses a TFLite model flatbuffer and reports the model description plus
// the byte sizes of the network's input (IFM) and output (OFM) tensors.
class InferenceParser {
public:
    // Returns true on failure, false on success.
    template <typename T, typename U, size_t S>
    bool parseModel(const void *buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) {
        // Create model handle
        const tflite::Model *model = tflite::GetModel(buffer);
        if (model->subgraphs() == nullptr) {
            printf("Warning: nullptr subgraph\n");
            return true;
        }

        // Copy the model description, guarding against a missing description
        // field and making sure the destination is always null terminated.
        description[0] = '\0';
        if (model->description() != nullptr) {
            strncpy(description, model->description()->c_str(), S - 1);
            description[S - 1] = '\0';
        }

        // Get input dimensions for first subgraph
        auto *subgraph = *model->subgraphs()->begin();
        bool failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims);
        if (failed) {
            return true;
        }

        // Get output dimensions for last subgraph
        subgraph = *model->subgraphs()->rbegin();
        failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims);
        if (failed) {
            return true;
        }

        return false;
    }

private:
    // Computes a tensor's element count as the product of its dimensions.
    // Returns true on failure, false on success.
    bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) {
        size = 1;

        if (shape == nullptr) {
            printf("Warning: nullptr shape size.\n");
            return true;
        }

        if (shape->Length() == 0) {
            printf("Warning: shape zero length.\n");
            return true;
        }

        for (auto it = shape->begin(); it != shape->end(); ++it) {
            size *= *it;
        }

        return false;
    }
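    // Example: getShapeSize() on a shape of {1, 28, 28, 1} computes
    // 1 * 28 * 28 * 1 = 784 elements.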

    // Maps a tensor type to its element size in bytes. Returns true for
    // unsupported types, false on success.
    bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) {
        switch (type) {
        case tflite::TensorType::TensorType_UINT8:
        case tflite::TensorType::TensorType_INT8:
            size = 1;
            break;
        case tflite::TensorType::TensorType_INT16:
            size = 2;
            break;
        case tflite::TensorType::TensorType_INT32:
        case tflite::TensorType::TensorType_FLOAT32:
            size = 4;
            break;
        default:
            printf("Warning: Unsupported tensor type\n");
            return true;
        }

        return false;
    }
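    // Example: getTensorTypeSize() maps TensorType_INT16 to 2 bytes, so a
    // 784 element INT16 tensor occupies 784 * 2 = 1568 bytes.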

    // Collects the byte size of every tensor referenced by tensorMap (the
    // subgraph's inputs or outputs) into dims. Returns true on failure,
    // false on success.
    template <typename T>
    bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) {
        if (subgraph == nullptr || tensorMap == nullptr) {
            printf("Warning: nullptr subgraph or tensormap.\n");
            return true;
        }

        // Verify up front that the remaining capacity can hold one entry per
        // mapped tensor, since Array::push_back() does not bounds check.
        if ((dims.capacity() - dims.size()) < tensorMap->size()) {
            printf("Warning: tensormap size is larger than dimension capacity.\n");
            return true;
        }

        for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
            auto tensor = subgraph->tensors()->Get(*index);
            size_t size;
            size_t tensorSize;

            bool failed = getShapeSize(tensor->shape(), size);
            if (failed) {
                return true;
            }

            failed = getTensorTypeSize(tensor->type(), tensorSize);
            if (failed) {
                return true;
            }

            // Total tensor size in bytes = element count * element size.
            size *= tensorSize;

            if (size > 0) {
                dims.push_back(size);
            }
        }

        return false;
    }
};
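
// Example usage (illustrative sketch; `modelData`, `ifmSizes`, `ofmSizes` and
// the counters are hypothetical caller-side names, not part of this API):
//
//     char description[128];
//     size_t ifmSizes[8];
//     size_t ofmSizes[8];
//     uint32_t ifmCount = 0;
//     uint32_t ofmCount = 0;
//     InferenceParser parser;
//     bool failed = parser.parseModel(modelData,
//                                     description,
//                                     makeArray(ifmSizes, ifmCount, 8),
//                                     makeArray(ofmSizes, ofmCount, 8));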

} // namespace InferenceProcess