blob: dac06676bf0f9da6803c57038c3da10ee219330a [file] [log] [blame]
//
// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "TosaRefLayerSupport.hpp"

#include <tosaCommon/TosaMappings.hpp>

#include <armnn/Types.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <graph_status.h>
#include <model_runner.h>

#include <cstdio>
#include <vector>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010017
18namespace armnn
19{
20
Francis Murtagh9270d9e2022-08-12 13:54:17 +010021bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
22 const std::vector<TensorInfo>& infos,
23 const BaseDescriptor& descriptor,
24 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
25 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
26 Optional<std::string&> reasonIfUnsupported) const
27{
28 IgnoreUnused(lstmParamsInfo);
29 IgnoreUnused(quantizedLstmInputParamsInfo);
Matthew Sloyan67fd5262022-12-07 19:28:18 +000030 IgnoreUnused(reasonIfUnsupported);
Francis Murtagh9270d9e2022-08-12 13:54:17 +010031
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010032 std::vector<const TensorInfo*> inputInfos;
33 std::vector<const TensorInfo*> outputInfos;
Francis Murtagh9270d9e2022-08-12 13:54:17 +010034
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010035 switch (type)
36 {
Tracy Narine10403ec2023-11-28 11:55:08 +000037 case LayerType::Activation:
38 inputInfos.push_back(&infos[0]);
39 outputInfos.push_back(&infos[1]);
40 break;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000041 case LayerType::Input:
42 case LayerType::Output:
43 return true;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010044 case LayerType::Addition:
Nikhil Raj9a339462022-12-05 11:24:35 +000045 case LayerType::Multiplication:
46 case LayerType::Subtraction:
Tianle Cheng7790dc62023-12-12 13:52:22 +000047 case LayerType::ElementwiseBinary:
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010048 // Setup inputs and outputs
49 inputInfos.push_back(&infos[0]);
50 inputInfos.push_back(&infos[1]);
51 outputInfos.push_back(&infos[2]);
52 break;
Kevin May5b58e312022-12-15 10:15:21 +000053 case LayerType::Concat:
54 for (unsigned int i = 0; i < infos.size() - 1; ++i)
55 {
56 inputInfos.push_back(&infos[i]);
57 }
58 outputInfos.push_back(&infos.back());
59 break;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000060 case LayerType::Constant:
61 outputInfos.push_back(&infos[0]);
62 break;
63 case LayerType::Convolution2d:
64 {
65 inputInfos.push_back(&infos[0]); // input
66 outputInfos.push_back(&infos[1]); // output
67 inputInfos.push_back(&infos[2]); // weights
68
69 auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
70 if(conv2dDesc->m_BiasEnabled)
71 {
72 inputInfos.push_back(&infos[3]); // bias
73 }
74 break;
75 }
David Monahand7fca092023-01-12 14:53:34 +000076 case LayerType::ElementwiseUnary:
Cathal Corbettbd18eab2022-11-15 12:56:16 +000077 case LayerType::Pooling2d:
Teresa Charlinca5a23a2023-12-15 14:20:47 +000078 case LayerType::Quantize:
Cathal Corbettb30e6552022-12-07 11:50:50 +000079 case LayerType::Reshape:
Teresa Charlince655882023-11-21 15:44:13 +000080 case LayerType::Resize:
Cathal Corbett3b9acd52022-12-09 12:17:27 +000081 case LayerType::Slice:
Cathal Corbett0bb096d2022-12-22 13:09:38 +000082 case LayerType::Transpose:
Kevin May1bea6be2023-12-12 11:18:46 +000083 {
Cathal Corbettbd18eab2022-11-15 12:56:16 +000084 inputInfos.push_back(&infos[0]);
85 outputInfos.push_back(&infos[1]);
86 break;
Kevin May1bea6be2023-12-12 11:18:46 +000087 }
88 case LayerType::Splitter:
89 {
90 inputInfos.push_back(&infos[0]);
91 for (unsigned int i = 1; i < infos.size(); ++i)
92 {
93 outputInfos.push_back(&infos[i]);
94 }
95 break;
96 }
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +000097 case LayerType::TransposeConvolution2d:
98 {
99 inputInfos.push_back(&infos[0]); // input
100 outputInfos.push_back(&infos[1]); // output
101 inputInfos.push_back(&infos[2]); // weights
102
103 auto conv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
104 if(conv2dDesc->m_BiasEnabled)
105 {
106 inputInfos.push_back(&infos[3]); // bias
107 }
108 break;
109 }
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100110 default:
David Monahand7fca092023-01-12 14:53:34 +0000111 // Default to false for all unsupported layers.
112 return false;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100113 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100114
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000115 auto mappings = GetTosaMapping(nullptr, type, inputInfos, outputInfos, descriptor);
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100116 if (mappings->GetName() == "")
117 {
118 // There currently isn't a TOSA mapping for this layer, as the default was returned.
119 return false;
120 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100121
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000122 TosaSerializationHandler handler;
123
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100124 // Add all mappings to main block.
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000125 auto* block = new TosaSerializationBasicBlock("main",
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100126 "main",
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000127 mappings->GetOperators(),
128 mappings->GetTensors(),
129 mappings->GetInputs(),
130 mappings->GetOutputs());
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100131
132 std::vector<TosaSerializationBasicBlock*> blocks;
133 blocks.emplace_back(block);
134
135 // Add blocks to the main region.
136 auto* region = new TosaSerializationRegion("main", blocks);
137 handler.GetRegions().emplace_back(region);
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000138
139 GraphStatus status;
140 TosaReference::IModelRunner runner;
141
142#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
143 // There currently isn't a way to disable the output from the TOSA Reference Model, but it does have a file pointer
144 // to write debug output to, so set this to /dev/null (if it exists on the system) to hide the output.
145 func_debug_t funcDebug;
146
147 FILE* file = fopen("/dev/null", "w");
148 funcDebug.func_debug_file = (file == nullptr) ? stderr : file;
149
150 runner.setFuncDebug(funcDebug);
151#endif
152
153 // Initialise the model runner with the TosaSerializationHandler, which runs validation on the mapping.
154 status = runner.initialize(handler);
155
156#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
157 // Reset FuncDebug as they can persist across multiple IModelRunner instances.
158 funcDebug.func_debug_file = stderr;
159 runner.setFuncDebug(funcDebug);
160#endif
161
162 if(status == GraphStatus::TOSA_ERROR || status == GraphStatus::TOSA_UNPREDICTABLE)
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100163 {
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000164 return false;
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100165 }
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000166 else
167 {
168 return true;
169 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100170}
171
172} // namespace armnn