blob: e5427ebc939dcbe7355821ffb6af4b79764cc0d1 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefLayerSupport.hpp"
Matthew Sloyan67fd5262022-12-07 19:28:18 +00007
Francis Murtagh9270d9e2022-08-12 13:54:17 +01008#include <tosaCommon/TosaMappings.hpp>
9
10#include <armnn/Types.hpp>
11#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan67fd5262022-12-07 19:28:18 +000012
13#include <graph_status.h>
14#include <model_runner.h>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010015
16#include <vector>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010017
18namespace armnn
19{
20
Francis Murtagh9270d9e2022-08-12 13:54:17 +010021bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
22 const std::vector<TensorInfo>& infos,
23 const BaseDescriptor& descriptor,
24 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
25 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
26 Optional<std::string&> reasonIfUnsupported) const
27{
28 IgnoreUnused(lstmParamsInfo);
29 IgnoreUnused(quantizedLstmInputParamsInfo);
Matthew Sloyan67fd5262022-12-07 19:28:18 +000030 IgnoreUnused(reasonIfUnsupported);
Francis Murtagh9270d9e2022-08-12 13:54:17 +010031
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010032 std::vector<const TensorInfo*> inputInfos;
33 std::vector<const TensorInfo*> outputInfos;
Francis Murtagh9270d9e2022-08-12 13:54:17 +010034
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010035 switch (type)
36 {
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000037 case LayerType::Input:
38 case LayerType::Output:
39 return true;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010040 case LayerType::Addition:
41 // Setup inputs and outputs
42 inputInfos.push_back(&infos[0]);
43 inputInfos.push_back(&infos[1]);
44 outputInfos.push_back(&infos[2]);
45 break;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000046 case LayerType::Constant:
47 outputInfos.push_back(&infos[0]);
48 break;
49 case LayerType::Convolution2d:
50 {
51 inputInfos.push_back(&infos[0]); // input
52 outputInfos.push_back(&infos[1]); // output
53 inputInfos.push_back(&infos[2]); // weights
54
55 auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
56 if(conv2dDesc->m_BiasEnabled)
57 {
58 inputInfos.push_back(&infos[3]); // bias
59 }
60 break;
61 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +000062 case LayerType::Pooling2d:
Cathal Corbettb30e6552022-12-07 11:50:50 +000063 case LayerType::Reshape:
Cathal Corbett3b9acd52022-12-09 12:17:27 +000064 case LayerType::Slice:
Cathal Corbettbd18eab2022-11-15 12:56:16 +000065 // Setup inputs and outputs
66 inputInfos.push_back(&infos[0]);
67 outputInfos.push_back(&infos[1]);
68 break;
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +000069 case LayerType::TransposeConvolution2d:
70 {
71 inputInfos.push_back(&infos[0]); // input
72 outputInfos.push_back(&infos[1]); // output
73 inputInfos.push_back(&infos[2]); // weights
74
75 auto conv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
76 if(conv2dDesc->m_BiasEnabled)
77 {
78 inputInfos.push_back(&infos[3]); // bias
79 }
80 break;
81 }
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010082 default:
83 break;
84 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010085
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000086 auto mappings = GetTosaMapping(nullptr, type, inputInfos, outputInfos, descriptor);
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010087 if (mappings->GetName() == "")
88 {
89 // There currently isn't a TOSA mapping for this layer, as the default was returned.
90 return false;
91 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +010092
Matthew Sloyan67fd5262022-12-07 19:28:18 +000093 TosaSerializationHandler handler;
94
95 // Add mappings to main block as the TOSA Reference Model requires the graph to be in one block called main.
96 auto* block = new TosaSerializationBasicBlock("main",
97 mappings->GetOperators(),
98 mappings->GetTensors(),
99 mappings->GetInputs(),
100 mappings->GetOutputs());
101 handler.GetBlocks().emplace_back(block);
102
103 GraphStatus status;
104 TosaReference::IModelRunner runner;
105
106#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
107 // There currently isn't a way to disable the output from the TOSA Reference Model, but it does have a file pointer
108 // to write debug output to, so set this to /dev/null (if it exists on the system) to hide the output.
109 func_debug_t funcDebug;
110
111 FILE* file = fopen("/dev/null", "w");
112 funcDebug.func_debug_file = (file == nullptr) ? stderr : file;
113
114 runner.setFuncDebug(funcDebug);
115#endif
116
117 // Initialise the model runner with the TosaSerializationHandler, which runs validation on the mapping.
118 status = runner.initialize(handler);
119
120#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
121 // Reset FuncDebug as they can persist across multiple IModelRunner instances.
122 funcDebug.func_debug_file = stderr;
123 runner.setFuncDebug(funcDebug);
124#endif
125
126 if(status == GraphStatus::TOSA_ERROR || status == GraphStatus::TOSA_UNPREDICTABLE)
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100127 {
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000128 return false;
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100129 }
Matthew Sloyan67fd5262022-12-07 19:28:18 +0000130 else
131 {
132 return true;
133 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100134}
135
136} // namespace armnn