blob: 70c62453ba7fe6cb5abfffcac9a4b43516d7dbad [file] [log] [blame]
/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef MODEL_HPP
18#define MODEL_HPP
19
#include "TensorFlowLiteMicro.hpp"

#include <cstddef>
#include <cstdint>
#include <vector>
23
24namespace arm {
25namespace app {
26
27 /**
28 * @brief NN model class wrapping the underlying TensorFlow-Lite-Micro API.
29 */
30 class Model {
31 public:
32 /** @brief Constructor. */
33 Model();
34
35 /** @brief Destructor. */
36 ~Model();
37
38 /** @brief Gets the pointer to the model's input tensor at given input index. */
39 TfLiteTensor* GetInputTensor(size_t index) const;
40
41 /** @brief Gets the pointer to the model's output tensor at given output index. */
42 TfLiteTensor* GetOutputTensor(size_t index) const;
43
44 /** @brief Gets the model's data type. */
45 TfLiteType GetType() const;
46
47 /** @brief Gets the pointer to the model's input shape. */
48 TfLiteIntArray* GetInputShape(size_t index) const;
49
50 /** @brief Gets the pointer to the model's output shape at given output index. */
51 TfLiteIntArray* GetOutputShape(size_t index) const;
52
53 /** @brief Gets the number of input tensors the model has. */
54 size_t GetNumInputs() const;
55
56 /** @brief Gets the number of output tensors the model has. */
57 size_t GetNumOutputs() const;
58
59 /** @brief Logs the tensor information to stdout. */
60 void LogTensorInfo(TfLiteTensor* tensor);
61
62 /** @brief Logs the interpreter information to stdout. */
63 void LogInterpreterInfo();
64
65 /** @brief Initialise the model class object.
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010066 * @param[in] tensorArenaAddress Pointer to the tensor arena buffer.
67 * @param[in] tensorArenaAddress Size of the tensor arena buffer in bytes.
68 * @param[in] nnModelAddr Pointer to the model.
69 * @param[in] nnModelSize Size of the model in bytes, if known.
alexander3c798932021-03-26 21:42:19 +000070 * @param[in] allocator Optional: a pre-initialised micro allocator pointer,
71 * if available. If supplied, this allocator will be used
72 * to create the interpreter instance.
73 * @return true if initialisation succeeds, false otherwise.
74 **/
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010075 bool Init(uint8_t* tensorArenaAddr,
76 uint32_t tensorArenaSize,
Kshitij Sisodia937052d2022-05-13 16:44:16 +010077 const uint8_t* nnModelAddr,
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010078 uint32_t nnModelSize,
79 tflite::MicroAllocator* allocator = nullptr);
alexander3c798932021-03-26 21:42:19 +000080
81 /**
82 * @brief Gets the allocator pointer for this instance.
83 * @return Pointer to a tflite::MicroAllocator object, if
84 * available; nullptr otherwise.
85 **/
86 tflite::MicroAllocator* GetAllocator();
87
88 /** @brief Checks if this object has been initialised. */
89 bool IsInited() const;
90
91 /** @brief Checks if the model uses signed data. */
92 bool IsDataSigned() const;
93
Cisco Cervellera02101092021-09-07 11:34:43 +010094 /** @brief Checks if the model uses Ethos-U operator */
95 bool ContainsEthosUOperator() const;
96
97 /** @brief Runs the inference (invokes the interpreter). */
98 virtual bool RunInference();
alexander3c798932021-03-26 21:42:19 +000099
100 /** @brief Model information handler common to all models.
101 * @return true or false based on execution success.
102 **/
103 bool ShowModelInfoHandler();
104
105 /** @brief Gets a pointer to the tensor arena. */
106 uint8_t* GetTensorArena();
107
108 protected:
109 /** @brief Gets the pointer to the NN model data array.
110 * @return Pointer of uint8_t type.
111 **/
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100112 const uint8_t* ModelPointer();
alexander3c798932021-03-26 21:42:19 +0000113
114 /** @brief Gets the model size.
115 * @return size_t, size in bytes.
116 **/
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100117 uint32_t ModelSize();
alexander3c798932021-03-26 21:42:19 +0000118
119 /**
120 * @brief Gets the op resolver for the model instance.
121 * @return const reference to a tflite::MicroOpResolver object.
122 **/
123 virtual const tflite::MicroOpResolver& GetOpResolver() = 0;
124
125 /**
126 * @brief Add all the operators required for the given model.
127 * Implementation of this should come from the use case.
128 * @return true is ops are successfully added, false otherwise.
129 **/
130 virtual bool EnlistOperations() = 0;
131
132 /** @brief Gets the total size of tensor arena available for use. */
133 size_t GetActivationBufferSize();
134
135 private:
Kshitij Sisodia937052d2022-05-13 16:44:16 +0100136 tflite::ErrorReporter* m_pErrorReporter{nullptr}; /* Pointer to the error reporter. */
137 const tflite::Model* m_pModel{nullptr}; /* Tflite model pointer. */
138 tflite::MicroInterpreter* m_pInterpreter{nullptr}; /* Tflite interpreter. */
139 tflite::MicroAllocator* m_pAllocator{nullptr}; /* Tflite micro allocator. */
140 bool m_inited{false}; /* Indicates whether this object has been initialised. */
141 const uint8_t* m_modelAddr{nullptr}; /* Model address */
142 uint32_t m_modelSize{0}; /* Model size */
alexander3c798932021-03-26 21:42:19 +0000143
Kshitij Sisodia937052d2022-05-13 16:44:16 +0100144 std::vector<TfLiteTensor*> m_input{}; /* Model's input tensor pointers. */
145 std::vector<TfLiteTensor*> m_output{}; /* Model's output tensor pointers. */
146 TfLiteType m_type{kTfLiteNoType}; /* Model's data type. */
alexander3c798932021-03-26 21:42:19 +0000147 };
148
149} /* namespace app */
150} /* namespace arm */
151
152#endif /* MODEL_HPP */