/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MODEL_HPP
#define MODEL_HPP

#include "TensorFlowLiteMicro.hpp"
#include "BufAttributes.hpp"

#include <cstddef>
#include <cstdint>
#include <vector>
24
25namespace arm {
26namespace app {
27
28 /**
29 * @brief NN model class wrapping the underlying TensorFlow-Lite-Micro API.
30 */
31 class Model {
32 public:
33 /** @brief Constructor. */
34 Model();
35
36 /** @brief Destructor. */
37 ~Model();
38
39 /** @brief Gets the pointer to the model's input tensor at given input index. */
40 TfLiteTensor* GetInputTensor(size_t index) const;
41
42 /** @brief Gets the pointer to the model's output tensor at given output index. */
43 TfLiteTensor* GetOutputTensor(size_t index) const;
44
45 /** @brief Gets the model's data type. */
46 TfLiteType GetType() const;
47
48 /** @brief Gets the pointer to the model's input shape. */
49 TfLiteIntArray* GetInputShape(size_t index) const;
50
51 /** @brief Gets the pointer to the model's output shape at given output index. */
52 TfLiteIntArray* GetOutputShape(size_t index) const;
53
54 /** @brief Gets the number of input tensors the model has. */
55 size_t GetNumInputs() const;
56
57 /** @brief Gets the number of output tensors the model has. */
58 size_t GetNumOutputs() const;
59
60 /** @brief Logs the tensor information to stdout. */
61 void LogTensorInfo(TfLiteTensor* tensor);
62
63 /** @brief Logs the interpreter information to stdout. */
64 void LogInterpreterInfo();
65
66 /** @brief Initialise the model class object.
67 * @param[in] allocator Optional: a pre-initialised micro allocator pointer,
68 * if available. If supplied, this allocator will be used
69 * to create the interpreter instance.
70 * @return true if initialisation succeeds, false otherwise.
71 **/
72 bool Init(tflite::MicroAllocator* allocator = nullptr);
73
74 /**
75 * @brief Gets the allocator pointer for this instance.
76 * @return Pointer to a tflite::MicroAllocator object, if
77 * available; nullptr otherwise.
78 **/
79 tflite::MicroAllocator* GetAllocator();
80
81 /** @brief Checks if this object has been initialised. */
82 bool IsInited() const;
83
84 /** @brief Checks if the model uses signed data. */
85 bool IsDataSigned() const;
86
Cisco Cervellera02101092021-09-07 11:34:43 +010087 /** @brief Checks if the model uses Ethos-U operator */
88 bool ContainsEthosUOperator() const;
89
90 /** @brief Runs the inference (invokes the interpreter). */
91 virtual bool RunInference();
alexander3c798932021-03-26 21:42:19 +000092
93 /** @brief Model information handler common to all models.
94 * @return true or false based on execution success.
95 **/
96 bool ShowModelInfoHandler();
97
98 /** @brief Gets a pointer to the tensor arena. */
99 uint8_t* GetTensorArena();
100
101 protected:
102 /** @brief Gets the pointer to the NN model data array.
103 * @return Pointer of uint8_t type.
104 **/
105 virtual const uint8_t* ModelPointer() = 0;
106
107 /** @brief Gets the model size.
108 * @return size_t, size in bytes.
109 **/
110 virtual size_t ModelSize() = 0;
111
112 /**
113 * @brief Gets the op resolver for the model instance.
114 * @return const reference to a tflite::MicroOpResolver object.
115 **/
116 virtual const tflite::MicroOpResolver& GetOpResolver() = 0;
117
118 /**
119 * @brief Add all the operators required for the given model.
120 * Implementation of this should come from the use case.
121 * @return true is ops are successfully added, false otherwise.
122 **/
123 virtual bool EnlistOperations() = 0;
124
125 /** @brief Gets the total size of tensor arena available for use. */
126 size_t GetActivationBufferSize();
127
128 private:
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100129 tflite::MicroErrorReporter m_uErrorReporter; /* Error reporter object. */
130 tflite::ErrorReporter* m_pErrorReporter = nullptr; /* Pointer to the error reporter. */
131 const tflite::Model* m_pModel = nullptr; /* Tflite model pointer. */
132 tflite::MicroInterpreter* m_pInterpreter = nullptr; /* Tflite interpreter. */
133 tflite::MicroAllocator* m_pAllocator = nullptr; /* Tflite micro allocator. */
134 bool m_inited = false; /* Indicates whether this object has been initialised. */
alexander3c798932021-03-26 21:42:19 +0000135
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100136 std::vector<TfLiteTensor*> m_input = {}; /* Model's input tensor pointers. */
137 std::vector<TfLiteTensor*> m_output = {}; /* Model's output tensor pointers. */
138 TfLiteType m_type = kTfLiteNoType;/* Model's data type. */
alexander3c798932021-03-26 21:42:19 +0000139
140 };
141
142} /* namespace app */
143} /* namespace arm */
144
145#endif /* MODEL_HPP */