/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef MODEL_HPP
18#define MODEL_HPP
19
#include "TensorFlowLiteMicro.hpp"
#include "BufAttributes.hpp"

#include <cstddef>
#include <cstdint>
#include <vector>
24
25namespace arm {
26namespace app {
27
28 /**
29 * @brief NN model class wrapping the underlying TensorFlow-Lite-Micro API.
30 */
31 class Model {
32 public:
33 /** @brief Constructor. */
34 Model();
35
36 /** @brief Destructor. */
37 ~Model();
38
39 /** @brief Gets the pointer to the model's input tensor at given input index. */
40 TfLiteTensor* GetInputTensor(size_t index) const;
41
42 /** @brief Gets the pointer to the model's output tensor at given output index. */
43 TfLiteTensor* GetOutputTensor(size_t index) const;
44
45 /** @brief Gets the model's data type. */
46 TfLiteType GetType() const;
47
48 /** @brief Gets the pointer to the model's input shape. */
49 TfLiteIntArray* GetInputShape(size_t index) const;
50
51 /** @brief Gets the pointer to the model's output shape at given output index. */
52 TfLiteIntArray* GetOutputShape(size_t index) const;
53
54 /** @brief Gets the number of input tensors the model has. */
55 size_t GetNumInputs() const;
56
57 /** @brief Gets the number of output tensors the model has. */
58 size_t GetNumOutputs() const;
59
60 /** @brief Logs the tensor information to stdout. */
61 void LogTensorInfo(TfLiteTensor* tensor);
62
63 /** @brief Logs the interpreter information to stdout. */
64 void LogInterpreterInfo();
65
66 /** @brief Initialise the model class object.
67 * @param[in] allocator Optional: a pre-initialised micro allocator pointer,
68 * if available. If supplied, this allocator will be used
69 * to create the interpreter instance.
70 * @return true if initialisation succeeds, false otherwise.
71 **/
72 bool Init(tflite::MicroAllocator* allocator = nullptr);
73
74 /**
75 * @brief Gets the allocator pointer for this instance.
76 * @return Pointer to a tflite::MicroAllocator object, if
77 * available; nullptr otherwise.
78 **/
79 tflite::MicroAllocator* GetAllocator();
80
81 /** @brief Checks if this object has been initialised. */
82 bool IsInited() const;
83
84 /** @brief Checks if the model uses signed data. */
85 bool IsDataSigned() const;
86
87 /** @brief Runs the inference (invokes the interpreter). */
88 bool RunInference();
89
90 /** @brief Model information handler common to all models.
91 * @return true or false based on execution success.
92 **/
93 bool ShowModelInfoHandler();
94
95 /** @brief Gets a pointer to the tensor arena. */
96 uint8_t* GetTensorArena();
97
98 protected:
99 /** @brief Gets the pointer to the NN model data array.
100 * @return Pointer of uint8_t type.
101 **/
102 virtual const uint8_t* ModelPointer() = 0;
103
104 /** @brief Gets the model size.
105 * @return size_t, size in bytes.
106 **/
107 virtual size_t ModelSize() = 0;
108
109 /**
110 * @brief Gets the op resolver for the model instance.
111 * @return const reference to a tflite::MicroOpResolver object.
112 **/
113 virtual const tflite::MicroOpResolver& GetOpResolver() = 0;
114
115 /**
116 * @brief Add all the operators required for the given model.
117 * Implementation of this should come from the use case.
118 * @return true is ops are successfully added, false otherwise.
119 **/
120 virtual bool EnlistOperations() = 0;
121
122 /** @brief Gets the total size of tensor arena available for use. */
123 size_t GetActivationBufferSize();
124
125 private:
126 tflite::MicroErrorReporter _m_uErrorReporter; /* Error reporter object. */
127 tflite::ErrorReporter* _m_pErrorReporter = nullptr; /* Pointer to the error reporter. */
128 const tflite::Model* _m_pModel = nullptr; /* Tflite model pointer. */
129 tflite::MicroInterpreter* _m_pInterpreter = nullptr; /* Tflite interpreter. */
130 tflite::MicroAllocator* _m_pAllocator = nullptr; /* Tflite micro allocator. */
131 bool _m_inited = false; /* Indicates whether this object has been initialised. */
132
133 std::vector<TfLiteTensor*> _m_input = {}; /* Model's input tensor pointers. */
134 std::vector<TfLiteTensor*> _m_output = {}; /* Model's output tensor pointers. */
135 TfLiteType _m_type = kTfLiteNoType;/* Model's data type. */
136
137 };
138
139} /* namespace app */
140} /* namespace arm */
141
142#endif /* MODEL_HPP */