blob: 57871487ef25f130b909bd303ebeff34fbac0cb8 [file] [log] [blame]
Georgios Pinitasd8734b52017-12-22 15:27:52 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010024#include "arm_compute/graph/backends/CL/CLFunctionFactory.h"
Georgios Pinitasd8734b52017-12-22 15:27:52 +000025
26#include "arm_compute/core/utils/misc/Cast.h"
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/Graph.h"
Georgios Pinitasda2491f2018-06-01 17:49:09 +010028#include "arm_compute/graph/backends/FunctionHelpers.h"
Georgios Pinitasd8734b52017-12-22 15:27:52 +000029#include "arm_compute/runtime/CL/CLFunctions.h"
30
Georgios Pinitasd8734b52017-12-22 15:27:52 +000031using namespace arm_compute::utils::cast;
32
33namespace arm_compute
34{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010035namespace graph
Georgios Pinitasd8734b52017-12-22 15:27:52 +000036{
37namespace backends
38{
Georgios Pinitasda2491f2018-06-01 17:49:09 +010039/** Target specific information structure used to pass information to the layer templates */
40struct CLTargetInfo
Georgios Pinitasd8734b52017-12-22 15:27:52 +000041{
Georgios Pinitasda2491f2018-06-01 17:49:09 +010042 using TensorType = arm_compute::ICLTensor;
43 static Target TargetType;
44};
45
46Target CLTargetInfo::TargetType = Target::CL;
47
48/** Collection of CL convolution functions */
49struct CLConvolutionLayerFunctions
Georgios Pinitasd8734b52017-12-22 15:27:52 +000050{
Georgios Pinitasda2491f2018-06-01 17:49:09 +010051 using GenericConvolutionLayer = CLConvolutionLayer;
52 using GEMMConvolutionLayer = CLGEMMConvolutionLayer;
53 using DirectConvolutionLayer = CLDirectConvolutionLayer;
54 using WinogradConvolutionLayer = CLWinogradConvolutionLayer;
55};
Georgios Pinitasd8734b52017-12-22 15:27:52 +000056
Georgios Pinitasda2491f2018-06-01 17:49:09 +010057/** Collection of CL depthwise convolution functions */
58struct CLDepthwiseConvolutionLayerFunctions
Georgios Pinitasd8734b52017-12-22 15:27:52 +000059{
Georgios Pinitasda2491f2018-06-01 17:49:09 +010060 using GenericDepthwiseConvolutionLayer = CLDepthwiseConvolutionLayer;
61 using DepthwiseConvolutionLayer3x3 = CLDepthwiseConvolutionLayer3x3;
62};
Georgios Pinitasd8734b52017-12-22 15:27:52 +000063
Georgios Pinitasda2491f2018-06-01 17:49:09 +010064/** Collection of CL element-wise functions */
65struct CLEltwiseFunctions
Georgios Pinitasd8734b52017-12-22 15:27:52 +000066{
Georgios Pinitasda2491f2018-06-01 17:49:09 +010067 using Addition = CLArithmeticAddition;
68 using Subtraction = CLArithmeticSubtraction;
69 using Multiplication = CLPixelWiseMultiplication;
70};
Georgios Pinitasd8734b52017-12-22 15:27:52 +000071
72std::unique_ptr<IFunction> CLFunctionFactory::create(INode *node, GraphContext &ctx)
73{
74 if(node == nullptr)
75 {
76 return nullptr;
77 }
78
79 NodeType type = node->type();
80 switch(type)
81 {
82 case NodeType::ActivationLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010083 return detail::create_activation_layer<CLActivationLayer, CLTargetInfo>(*polymorphic_downcast<ActivationLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +000084 case NodeType::BatchNormalizationLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010085 return detail::create_batch_normalization_layer<CLBatchNormalizationLayer, CLTargetInfo>(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
Georgios Pinitas087eaf62018-05-16 15:52:35 +010086 case NodeType::ChannelShuffleLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010087 return detail::create_channel_shuffle_layer<CLChannelShuffleLayer, CLTargetInfo>(*polymorphic_downcast<ChannelShuffleLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +000088 case NodeType::ConvolutionLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010089 return detail::create_convolution_layer<CLConvolutionLayerFunctions, CLTargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
Georgios Pinitas087eaf62018-05-16 15:52:35 +010090 case NodeType::DeconvolutionLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010091 return detail::create_deconvolution_layer<CLDeconvolutionLayer, CLTargetInfo>(*polymorphic_downcast<DeconvolutionLayerNode *>(node), ctx);
Georgios Pinitase2220552018-07-20 13:23:44 +010092 case NodeType::ConcatenateLayer:
93 return detail::create_concatenate_layer<CLConcatenateLayer, CLTargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +000094 case NodeType::DepthwiseConvolutionLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010095 return detail::create_depthwise_convolution_layer<CLDepthwiseConvolutionLayerFunctions, CLTargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +000096 case NodeType::EltwiseLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010097 return detail::create_eltwise_layer<CLEltwiseFunctions, CLTargetInfo>(*polymorphic_downcast<EltwiseLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +000098 case NodeType::FlattenLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +010099 return detail::create_flatten_layer<CLFlattenLayer, CLTargetInfo>(*polymorphic_downcast<FlattenLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000100 case NodeType::FullyConnectedLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +0100101 return detail::create_fully_connected_layer<CLFullyConnectedLayer, CLTargetInfo>(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000102 case NodeType::NormalizationLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +0100103 return detail::create_normalization_layer<CLNormalizationLayer, CLTargetInfo>(*polymorphic_downcast<NormalizationLayerNode *>(node), ctx);
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000104 case NodeType::PoolingLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +0100105 return detail::create_pooling_layer<CLPoolingLayer, CLTargetInfo>(*polymorphic_downcast<PoolingLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000106 case NodeType::ReshapeLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +0100107 return detail::create_reshape_layer<CLReshapeLayer, CLTargetInfo>(*polymorphic_downcast<ReshapeLayerNode *>(node));
Georgios Pinitas087eaf62018-05-16 15:52:35 +0100108 case NodeType::ResizeLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +0100109 return detail::create_resize_layer<CLScale, CLTargetInfo>(*polymorphic_downcast<ResizeLayerNode *>(node));
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000110 case NodeType::SoftmaxLayer:
Georgios Pinitasda2491f2018-06-01 17:49:09 +0100111 return detail::create_softmax_layer<CLSoftmaxLayer, CLTargetInfo>(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000112 default:
113 return nullptr;
114 }
115}
116} // namespace backends
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100117} // namespace graph
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000118} // namespace arm_compute