blob: 60d0f7c6d6004903e4fb68e4ff14b5bd727a337e [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
David Monahand7fca092023-01-12 14:53:34 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Francis Murtagh9270d9e2022-08-12 13:54:17 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefLayerSupport.hpp"
Matthew Sloyan67fd5262022-12-07 19:28:18 +00007
Francis Murtagh9270d9e2022-08-12 13:54:17 +01008#include <tosaCommon/TosaMappings.hpp>
9
10#include <armnn/Types.hpp>
11#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan67fd5262022-12-07 19:28:18 +000012
13#include <graph_status.h>
14#include <model_runner.h>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015
16#include <vector>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010017
18namespace armnn
19{
20
Francis Murtagh9270d9e2022-08-12 13:54:17 +010021bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
22 const std::vector<TensorInfo>& infos,
23 const BaseDescriptor& descriptor,
24 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
25 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
26 Optional<std::string&> reasonIfUnsupported) const
27{
28 IgnoreUnused(lstmParamsInfo);
29 IgnoreUnused(quantizedLstmInputParamsInfo);
Matthew Sloyan67fd5262022-12-07 19:28:18 +000030 IgnoreUnused(reasonIfUnsupported);
Francis Murtagh9270d9e2022-08-12 13:54:17 +010031
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010032 std::vector<const TensorInfo*> inputInfos;
33 std::vector<const TensorInfo*> outputInfos;
Francis Murtagh9270d9e2022-08-12 13:54:17 +010034
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010035 switch (type)
36 {
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000037 case LayerType::Input:
38 case LayerType::Output:
39 return true;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010040 case LayerType::Addition:
Nikhil Raj9a339462022-12-05 11:24:35 +000041 case LayerType::Multiplication:
42 case LayerType::Subtraction:
Tianle Cheng7790dc62023-12-12 13:52:22 +000043 case LayerType::ElementwiseBinary:
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010044 // Setup inputs and outputs
45 inputInfos.push_back(&infos[0]);
46 inputInfos.push_back(&infos[1]);
47 outputInfos.push_back(&infos[2]);
48 break;
Kevin May5b58e312022-12-15 10:15:21 +000049 case LayerType::Concat:
50 for (unsigned int i = 0; i < infos.size() - 1; ++i)
51 {
52 inputInfos.push_back(&infos[i]);
53 }
54 outputInfos.push_back(&infos.back());
55 break;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000056 case LayerType::Constant:
57 outputInfos.push_back(&infos[0]);
58 break;
59 case LayerType::Convolution2d:
60 {
61 inputInfos.push_back(&infos[0]); // input
62 outputInfos.push_back(&infos[1]); // output
63 inputInfos.push_back(&infos[2]); // weights
64
65 auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
66 if(conv2dDesc->m_BiasEnabled)
67 {
68 inputInfos.push_back(&infos[3]); // bias
69 }
70 break;
71 }
David Monahand7fca092023-01-12 14:53:34 +000072 case LayerType::ElementwiseUnary:
Cathal Corbettbd18eab2022-11-15 12:56:16 +000073 case LayerType::Pooling2d:
Cathal Corbettb30e6552022-12-07 11:50:50 +000074 case LayerType::Reshape:
Teresa Charlince655882023-11-21 15:44:13 +000075 case LayerType::Resize:
Cathal Corbett3b9acd52022-12-09 12:17:27 +000076 case LayerType::Slice:
Cathal Corbett0bb096d2022-12-22 13:09:38 +000077 case LayerType::Transpose:
Kevin May1bea6be2023-12-12 11:18:46 +000078 {
Cathal Corbettbd18eab2022-11-15 12:56:16 +000079 inputInfos.push_back(&infos[0]);
80 outputInfos.push_back(&infos[1]);
81 break;
Kevin May1bea6be2023-12-12 11:18:46 +000082 }
83 case LayerType::Splitter:
84 {
85 inputInfos.push_back(&infos[0]);
86 for (unsigned int i = 1; i < infos.size(); ++i)
87 {
88 outputInfos.push_back(&infos[i]);
89 }
90 break;
91 }
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +000092 case LayerType::TransposeConvolution2d:
93 {
94 inputInfos.push_back(&infos[0]); // input
95 outputInfos.push_back(&infos[1]); // output
96 inputInfos.push_back(&infos[2]); // weights
97
98 auto conv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
99 if(conv2dDesc->m_BiasEnabled)
100 {
101 inputInfos.push_back(&infos[3]); // bias
102 }
103 break;
104 }
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100105 default:
David Monahand7fca092023-01-12 14:53:34 +0000106 // Default to false for all unsupported layers.
107 return false;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100108 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100109
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000110 auto mappings = GetTosaMapping(nullptr, type, inputInfos, outputInfos, descriptor);
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100111 if (mappings->GetName() == "")
112 {
113 // There currently isn't a TOSA mapping for this layer, as the default was returned.
114 return false;
115 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100116
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000117 TosaSerializationHandler handler;
118
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100119 // Add all mappings to main block.
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000120 auto* block = new TosaSerializationBasicBlock("main",
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100121 "main",
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000122 mappings->GetOperators(),
123 mappings->GetTensors(),
124 mappings->GetInputs(),
125 mappings->GetOutputs());
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100126
127 std::vector<TosaSerializationBasicBlock*> blocks;
128 blocks.emplace_back(block);
129
130 // Add blocks to the main region.
131 auto* region = new TosaSerializationRegion("main", blocks);
132 handler.GetRegions().emplace_back(region);
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000133
134 GraphStatus status;
135 TosaReference::IModelRunner runner;
136
137#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
138 // There currently isn't a way to disable the output from the TOSA Reference Model, but it does have a file pointer
139 // to write debug output to, so set this to /dev/null (if it exists on the system) to hide the output.
140 func_debug_t funcDebug;
141
142 FILE* file = fopen("/dev/null", "w");
143 funcDebug.func_debug_file = (file == nullptr) ? stderr : file;
144
145 runner.setFuncDebug(funcDebug);
146#endif
147
148 // Initialise the model runner with the TosaSerializationHandler, which runs validation on the mapping.
149 status = runner.initialize(handler);
150
151#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
152 // Reset FuncDebug as they can persist across multiple IModelRunner instances.
153 funcDebug.func_debug_file = stderr;
154 runner.setFuncDebug(funcDebug);
155#endif
156
157 if(status == GraphStatus::TOSA_ERROR || status == GraphStatus::TOSA_UNPREDICTABLE)
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100158 {
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000159 return false;
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100160 }
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000161 else
162 {
163 return true;
164 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100165}
166
167} // namespace armnn