/*
 * Copyright (c) 2021, 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/graph/DataLayerVisitor.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/nodes/Nodes.h"
#include "arm_compute/graph/TypePrinter.h"

#include <sstream>
#include <string>
#include <utility>

31namespace arm_compute
32{
33namespace graph
34{
35namespace
36{
37template <typename T>
38void add_convolution_layer_data(DataLayerVisitor::LayerData &layer_data, T &node)
39{
40 PadStrideInfo ps_info = node.convolution_info();
41 DataLayout layout = node.output(0)->desc().layout;
42 // Add data layout
43 layer_data["data_layout"] = to_string(layout);
44 // Add padding info
45 std::ostringstream padding;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010046 padding << "[" << to_string(ps_info.pad_left()) << "," << to_string(ps_info.pad_top()) << ","
47 << to_string(ps_info.pad_bottom()) << "," << to_string(ps_info.pad_right()) << "]";
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010048
49 layer_data["pad"] = padding.str();
50
51 // Add stride info
52 std::ostringstream stride;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010053 stride << "[" << to_string(ps_info.stride().first) << "," << to_string(ps_info.stride().second) << "]";
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010054
55 layer_data["stride"] = stride.str();
56
57 // Add dilation info
58 // graph api does not support dilation > 1
59 layer_data["dilation"] = "[1,1]";
60
61 // Add bias enabled?
62 // Assumes three inputs (input, weights, bias)
63 std::string bias_enabled = node.input(2) == nullptr ? "0" : "1";
64 layer_data["bias_enabled"] = bias_enabled;
65
66 // Change input names for weights / bias (if applicable)
67 // Assumes input(1) is weights and input(2) is bias
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010068 if (layer_data.count("input_shape1"))
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010069 {
70 layer_data["weights_shape"] = layer_data["input_shape1"];
71 layer_data.erase("input_shape1");
72 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010073 if (layer_data.count("input_shape2"))
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010074 {
75 layer_data["bias_shape"] = layer_data["input_shape2"];
76 layer_data.erase("input_shape2");
77 }
78}
79
80template <typename T>
81void add_convolution_layer_method(DataLayerVisitor::LayerData &layer_data, T &node)
82{
83 std::ostringstream method;
84 method << node.convolution_method();
85 layer_data["convolution_method"] = method.str();
86}
87
88template <typename T>
89void add_generic_layer_data(DataLayerVisitor::LayerData &layer_data, T &node)
90{
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010091 // Loop over each input tensor
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010092 for (size_t tensor_no = 0; tensor_no < node.num_inputs(); ++tensor_no)
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010093 {
94 // Add input tensor shapes
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010095 if (node.input(tensor_no) != nullptr)
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010096 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010097 layer_data["input_shape" + to_string(tensor_no)] =
98 "[" + to_string(node.input(tensor_no)->desc().shape) + "]";
Freddie Liardetdd23f2a2021-06-17 13:30:11 +010099 }
100 }
101 // Add output tensor shape
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100102 if (node.output(0) != nullptr)
Freddie Liardetdd23f2a2021-06-17 13:30:11 +0100103 {
104 layer_data["output_shape0"] = "[" + to_string(node.output(0)->desc().shape) + "]";
105 }
106}
107} // namespace
108
109void DataLayerVisitor::visit(ConvolutionLayerNode &n)
110{
111 _layer_data.clear();
112 add_generic_layer_data<ConvolutionLayerNode>(_layer_data, n);
113 add_convolution_layer_data<ConvolutionLayerNode>(_layer_data, n);
114 add_convolution_layer_method<ConvolutionLayerNode>(_layer_data, n);
115}
116
117void DataLayerVisitor::visit(DepthwiseConvolutionLayerNode &n)
118{
119 _layer_data.clear();
120 add_generic_layer_data<DepthwiseConvolutionLayerNode>(_layer_data, n);
121 add_convolution_layer_data<DepthwiseConvolutionLayerNode>(_layer_data, n);
122}
123
124void DataLayerVisitor::visit(FusedConvolutionBatchNormalizationNode &n)
125{
126 _layer_data.clear();
127 add_generic_layer_data<FusedConvolutionBatchNormalizationNode>(_layer_data, n);
128 add_convolution_layer_data<FusedConvolutionBatchNormalizationNode>(_layer_data, n);
129 add_convolution_layer_method<FusedConvolutionBatchNormalizationNode>(_layer_data, n);
130}
131
Freddie Liardetdd23f2a2021-06-17 13:30:11 +0100132void DataLayerVisitor::visit(FusedDepthwiseConvolutionBatchNormalizationNode &n)
133{
134 _layer_data.clear();
135 add_generic_layer_data<FusedDepthwiseConvolutionBatchNormalizationNode>(_layer_data, n);
136 add_convolution_layer_data<FusedDepthwiseConvolutionBatchNormalizationNode>(_layer_data, n);
137}
138
139void DataLayerVisitor::visit(OutputNode &n)
140{
141 _layer_data.clear();
142 ARM_COMPUTE_UNUSED(n);
143}
144
145void DataLayerVisitor::default_visit(INode &n)
146{
147 _layer_data.clear();
148 add_generic_layer_data<INode>(_layer_data, n);
149}
150
151const DataLayerVisitor::LayerData &DataLayerVisitor::layer_data() const
152{
153 return _layer_data;
154}
155} // namespace graph
156} // namespace arm_compute