blob: 073ffd413def64eca2ee14648a94437cdbcadcc8 [file] [log] [blame]
Freddie Liardetdd23f2a2021-06-17 13:30:11 +01001/*
Jakub Sujak0d27b2e2023-08-24 14:01:20 +01002 * Copyright (c) 2021, 2023 Arm Limited.
Freddie Liardetdd23f2a2021-06-17 13:30:11 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "arm_compute/graph/DataLayerVisitor.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/TypePrinter.h"
#include "arm_compute/graph/nodes/Nodes.h"

#include <sstream>
#include <string>
#include <utility>
30
31namespace arm_compute
32{
33namespace graph
34{
35namespace
36{
37template <typename T>
38void add_convolution_layer_data(DataLayerVisitor::LayerData &layer_data, T &node)
39{
40 PadStrideInfo ps_info = node.convolution_info();
41 DataLayout layout = node.output(0)->desc().layout;
42 // Add data layout
43 layer_data["data_layout"] = to_string(layout);
44 // Add padding info
45 std::ostringstream padding;
46 padding << "[" << to_string(ps_info.pad_left()) << ","
47 << to_string(ps_info.pad_top()) << ","
48 << to_string(ps_info.pad_bottom()) << ","
49 << to_string(ps_info.pad_right()) << "]";
50
51 layer_data["pad"] = padding.str();
52
53 // Add stride info
54 std::ostringstream stride;
55 stride << "[" << to_string(ps_info.stride().first) << ","
56 << to_string(ps_info.stride().second) << "]";
57
58 layer_data["stride"] = stride.str();
59
60 // Add dilation info
61 // graph api does not support dilation > 1
62 layer_data["dilation"] = "[1,1]";
63
64 // Add bias enabled?
65 // Assumes three inputs (input, weights, bias)
66 std::string bias_enabled = node.input(2) == nullptr ? "0" : "1";
67 layer_data["bias_enabled"] = bias_enabled;
68
69 // Change input names for weights / bias (if applicable)
70 // Assumes input(1) is weights and input(2) is bias
71 if(layer_data.count("input_shape1"))
72 {
73 layer_data["weights_shape"] = layer_data["input_shape1"];
74 layer_data.erase("input_shape1");
75 }
76 if(layer_data.count("input_shape2"))
77 {
78 layer_data["bias_shape"] = layer_data["input_shape2"];
79 layer_data.erase("input_shape2");
80 }
81}
82
83template <typename T>
84void add_convolution_layer_method(DataLayerVisitor::LayerData &layer_data, T &node)
85{
86 std::ostringstream method;
87 method << node.convolution_method();
88 layer_data["convolution_method"] = method.str();
89}
90
91template <typename T>
92void add_generic_layer_data(DataLayerVisitor::LayerData &layer_data, T &node)
93{
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010094 // Loop over each input tensor
95 for(size_t tensor_no = 0; tensor_no < node.num_inputs(); ++tensor_no)
96 {
97 // Add input tensor shapes
98 if(node.input(tensor_no) != nullptr)
99 {
100 layer_data["input_shape" + to_string(tensor_no)] = "[" + to_string(node.input(tensor_no)->desc().shape) + "]";
101 }
102 }
103 // Add output tensor shape
104 if(node.output(0) != nullptr)
105 {
106 layer_data["output_shape0"] = "[" + to_string(node.output(0)->desc().shape) + "]";
107 }
108}
109} // namespace
110
111void DataLayerVisitor::visit(ConvolutionLayerNode &n)
112{
113 _layer_data.clear();
114 add_generic_layer_data<ConvolutionLayerNode>(_layer_data, n);
115 add_convolution_layer_data<ConvolutionLayerNode>(_layer_data, n);
116 add_convolution_layer_method<ConvolutionLayerNode>(_layer_data, n);
117}
118
119void DataLayerVisitor::visit(DepthwiseConvolutionLayerNode &n)
120{
121 _layer_data.clear();
122 add_generic_layer_data<DepthwiseConvolutionLayerNode>(_layer_data, n);
123 add_convolution_layer_data<DepthwiseConvolutionLayerNode>(_layer_data, n);
124}
125
126void DataLayerVisitor::visit(FusedConvolutionBatchNormalizationNode &n)
127{
128 _layer_data.clear();
129 add_generic_layer_data<FusedConvolutionBatchNormalizationNode>(_layer_data, n);
130 add_convolution_layer_data<FusedConvolutionBatchNormalizationNode>(_layer_data, n);
131 add_convolution_layer_method<FusedConvolutionBatchNormalizationNode>(_layer_data, n);
132}
133
Freddie Liardetdd23f2a2021-06-17 13:30:11 +0100134void DataLayerVisitor::visit(FusedDepthwiseConvolutionBatchNormalizationNode &n)
135{
136 _layer_data.clear();
137 add_generic_layer_data<FusedDepthwiseConvolutionBatchNormalizationNode>(_layer_data, n);
138 add_convolution_layer_data<FusedDepthwiseConvolutionBatchNormalizationNode>(_layer_data, n);
139}
140
141void DataLayerVisitor::visit(OutputNode &n)
142{
143 _layer_data.clear();
144 ARM_COMPUTE_UNUSED(n);
145}
146
147void DataLayerVisitor::default_visit(INode &n)
148{
149 _layer_data.clear();
150 add_generic_layer_data<INode>(_layer_data, n);
151}
152
153const DataLayerVisitor::LayerData &DataLayerVisitor::layer_data() const
154{
155 return _layer_data;
156}
157} // namespace graph
158} // namespace arm_compute