blob: eb01d1027ac411c3ef4ce2eff80e41317e6afb93 [file] [log] [blame]
Davide Grohmann144b2d22022-05-31 15:24:02 +02001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#ifndef NETWORKS_H
20#define NETWORKS_H
21
22#include <ethosu_core_interface.h>
23
24#include <cstdio>
25#include <inttypes.h>
26
27using namespace EthosU;
28
29namespace MessageHandler {
30
31class Networks {
32public:
33 virtual ~Networks() {}
34 virtual bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &size) = 0;
35};
36
37template <typename T>
38class BaseNetworks : public Networks {
39public:
40 bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &size) override {
41 switch (buffer.type) {
42 case ETHOSU_CORE_NETWORK_BUFFER:
43 data = reinterpret_cast<void *>(buffer.buffer.ptr);
44 size = buffer.buffer.size;
45 return false;
46 case ETHOSU_CORE_NETWORK_INDEX:
47 return T::getIndexedNetwork(buffer.index, data, size);
48 default:
49 printf("Error: Unsupported network model type. type=%" PRIu32 "\n", buffer.type);
50 return true;
51 }
52 }
53};
54
55class NoIndexedNetworks : public BaseNetworks<NoIndexedNetworks> {
56 static bool getIndexedNetwork(const uint32_t index, void *&data, size_t &size) {
57 printf("Error: Network model index out of range. index=%" PRIu32 "\n", index);
58 return true;
59 }
60};
61
62} // namespace MessageHandler
63
64#endif