blob: c33e1897970f2c73c493c4563573707929321b23 [file] [log] [blame]
/*
* Copyright (c) 2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
#error "This experimental feature must be enabled with -DENABLE_EXPERIMENTAL_DYNAMIC_FUSION"
#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */
#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
#include "arm_compute/core/experimental/ClWorkload.h"
#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ITensorDescPack.h"
#include "support/Cast.h"
#include "support/DeepCopy.h"
#include <map>
#include <tuple>
#include <type_traits>
namespace arm_compute
{
namespace experimental
{
namespace dynamic_fusion
{
/** Complexity class of an operator, as reported by OperatorContent::complexity().
 *
 * Within this header, Conv2dContent reports Complex and AddContent reports Simple.
 */
enum class OperatorComplexity
{
    Complex = 0, /**< Reported by Conv2dContent */
    Simple  = 1  /**< Reported by AddContent */
};
/** Forward declaration: this header only takes ClKernelGraph by reference (in the translate() signatures), so the full definition is not required here */
struct ClKernelGraph;
/** Tensor record of an operator graph: an id paired with a pointer to the tensor's info.
 *
 * The @ref ITensorInfo pointer is stored exactly as given; this struct neither copies nor frees the pointee.
 */
struct OpTensorContent
{
    using Id = DependencyGraph::Id;

    /** Default constructor: value-initialised id and a null tensor info pointer */
    OpTensorContent() = default;
    /** Construct a record that has an id but no tensor info attached
     *
     * @param[in] id Id of the tensor
     */
    OpTensorContent(Id id)
        : id{ id }, desc{ nullptr }
    {
    }
    /** Construct a record from an id and a tensor info pointer
     *
     * @param[in] id   Id of the tensor
     * @param[in] desc Tensor info pointer to store (not copied, not owned)
     */
    OpTensorContent(Id id, ITensorInfo *desc)
        : id{ id }, desc{ desc }
    {
    }

    /** Copy, move and destruction are all member-wise (compiler generated) */
    OpTensorContent(const OpTensorContent &) = default;
    OpTensorContent(OpTensorContent &&)      = default;
    OpTensorContent &operator=(const OpTensorContent &) = default;
    OpTensorContent &operator=(OpTensorContent &&) = default;
    ~OpTensorContent()                             = default;

    /** Mutable access to the stored tensor info pointer
     *
     * @return The stored pointer (may be null)
     */
    ITensorInfo *get_tensor_info()
    {
        return desc;
    }
    /** Read-only access to the stored tensor info pointer
     *
     * @return The stored pointer (may be null)
     */
    const ITensorInfo *get_tensor_info() const
    {
        return desc;
    }
    /** Equality is decided solely by the tensor info pointer values; @ref id does not take part
     *
     * @param[in] other Record to compare against
     *
     * @return true if both records hold the same tensor info pointer
     */
    bool operator==(const OpTensorContent &other) const
    {
        return other.desc == desc;
    }

    Id           id{};   /**< Id of the tensor */
    ITensorInfo *desc{}; /**< Tensor info pointer; never freed by this struct */
};
/** Abstract base of every operator stored in an operator graph.
 *
 * Holds a pointer to the graph implementation, the operator id and the pack of tensors attached to the operator.
 * Concrete operators must provide complexity(), operator==() and translate().
 */
struct OperatorContent
{
    using Id = DependencyGraph::Id;

    /** Default constructor: null graph pointer, value-initialised id and tensor pack */
    OperatorContent() = default;
    /** Construct an operator record
     *
     * @param[in] graph   Graph implementation pointer to store (not owned)
     * @param[in] id      Id of the operator
     * @param[in] tensors Tensor pack of the operator; copied into this object
     */
    OperatorContent(const OperatorGraph::Implementation *graph, Id id, const ITensorDescPack<OpTensorContent> &tensors)
        : _graph(graph), _id{ id }, _tensors{ tensors }
    {
    }
    /** Copy and move are member-wise (compiler generated) */
    OperatorContent(const OperatorContent &) = default;
    OperatorContent &operator=(const OperatorContent &) = default;
    OperatorContent(OperatorContent &&)                 = default;
    OperatorContent &operator=(OperatorContent &&) = default;
    /** Virtual destructor so derived operators can be destroyed through a base pointer */
    virtual ~OperatorContent() = default;
    /** Complexity class of the operator
     *
     * @return The @ref OperatorComplexity chosen by the concrete operator
     */
    virtual OperatorComplexity complexity() const = 0;
    /** Compare with another operator; the comparison rule is supplied by the concrete operator
     *
     * @param[in] other Operator to compare against
     *
     * @return true if the concrete operator considers both equal
     */
    virtual bool operator==(const OperatorContent &other) const = 0;
    /** Translate this operator into the given kernel graph; the concrete effect is defined by each subclass
     *
     * @param[in,out] kernel_graph Kernel graph handed to the concrete implementation
     *
     * @return Status reported by the concrete implementation
     */
    virtual Status translate(ClKernelGraph &kernel_graph) const = 0;

protected:
    const OperatorGraph::Implementation *_graph{};   /**< Graph implementation pointer given at construction (not owned) */
    Id                                   _id{};      /**< Id of the operator */
    ITensorDescPack<OpTensorContent>     _tensors{}; /**< Tensors attached to the operator */
};
struct Conv2dContent : public OperatorContent
{
public:
Conv2dContent() = default;
Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
: OperatorContent(graph, id, tensors), desc(desc), forced_method(), forced_method_enabled(false)
{
}
// Temporary. Do not need to pass ConvolutionMethod
Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors, ConvolutionMethod method)
: OperatorContent(graph, id, tensors), desc(desc), forced_method(method), forced_method_enabled(true)
{
}
~Conv2dContent() = default;
Conv2dContent(const Conv2dContent &) = default;
Conv2dContent &operator=(const Conv2dContent &) = default;
Conv2dContent(Conv2dContent &&) = default;
Conv2dContent &operator=(Conv2dContent &&) = default;
bool operator==(const OperatorContent &other) const override;
OperatorComplexity complexity() const override
{
return OperatorComplexity::Complex;
}
void set_method(ConvolutionMethod method)
{
forced_method_enabled = true;
forced_method = method;
}
Status translate(ClKernelGraph &kernel_graph) const override;
/** Replicate heuristics of @ref ClConv2d::get_convolution_method(), except that non-supported data types and data layouts are removed from the heuristics
*
* @param src
* @param weights
* @param dst
* @param conv2d_desc
* @param gpu_target
* @return ConvolutionMethod
*/
static ConvolutionMethod select_conv_method(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const Conv2dDescriptor &conv2d_desc, const GPUTarget gpu_target);
Conv2dDescriptor desc{};
ConvolutionMethod forced_method{ ConvolutionMethod::GEMM_CONV2D };
bool forced_method_enabled{ false };
private:
Status translate_direct_conv2d(ClKernelGraph &kernel_graph) const;
};
/** Operator graph node describing an addition.
 *
 * Reports @ref OperatorComplexity::Simple.
 */
class AddContent : public OperatorContent
{
public:
    /** Default constructor: value-initialised descriptor */
    AddContent() = default;
    /** Construct an addition node
     *
     * @param[in] graph   Graph implementation pointer forwarded to @ref OperatorContent
     * @param[in] id      Id of the operator
     * @param[in] desc    Addition descriptor; copied into this object
     * @param[in] tensors Tensor pack of the operator
     */
    AddContent(const OperatorGraph::Implementation *graph, Id id, const AddDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
        : OperatorContent(graph, id, tensors),
          desc(desc)
    {
    }
    /** Destruction, copy and move are member-wise (compiler generated) */
    ~AddContent()                  = default;
    AddContent(const AddContent &) = default;
    AddContent &operator=(const AddContent &) = default;
    AddContent(AddContent &&)                 = default;
    AddContent &operator=(AddContent &&) = default;

    /** Compare with another operator (defined out of line)
     *
     * @param[in] other Operator to compare against
     *
     * @return Result of the out-of-line comparison
     */
    bool operator==(const OperatorContent &other) const override;
    /** Complexity class of an addition
     *
     * @return Always @ref OperatorComplexity::Simple
     */
    OperatorComplexity complexity() const override
    {
        return OperatorComplexity::Simple;
    }
    /** Translate this operator into the given kernel graph (defined out of line)
     *
     * @param[in,out] kernel_graph Kernel graph handed to the implementation
     *
     * @return Status reported by the implementation
     */
    Status translate(ClKernelGraph &kernel_graph) const override;

private:
    AddDescriptor desc{}; /**< Addition descriptor */
};
/** Storage behind @ref OperatorGraph: a dependency graph plus id-keyed maps of operator and tensor contents.
 *
 * @note NOTE(review): add_node() hands `this` to every operator it creates. If an Implementation is copied,
 *       the copied operators presumably keep pointing at the source graph - confirm against deep_unique_ptr's copy behaviour.
 */
struct OperatorGraph::Implementation
{
    using Dependency  = DependencyGraph;
    using OperatorMap = std::map<Operator::Id, utils::memory::deep_unique_ptr<OperatorContent>>;
    using OpTensorMap = std::map<OpTensor::Id, utils::memory::deep_unique_ptr<OpTensorContent>>;

    Implementation()  = default;
    ~Implementation() = default;

    /** Create an operator content of type @p ContentT and store it under @p id, overwriting any entry already stored there
     *
     * The content is constructed from (this, id, args...).
     *
     * @tparam ContentT Concrete content type to create
     * @tparam Args     Types of the extra constructor arguments
     *
     * @param[in] id   Id of the operator; used both as map key and as constructor argument
     * @param[in] args Extra arguments perfectly forwarded to the @p ContentT constructor
     */
    template <typename ContentT, typename... Args>
    void add_node(Operator::Id id, Args &&... args)
    {
        operators[id] = utils::memory::make_deep_unique<OperatorContent, ContentT>(
            this, id, std::forward<Args>(args)...);
    }
    /** Create an @ref OpTensorContent and store it under @p id, overwriting any entry already stored there
     *
     * The content is constructed from (id, args...).
     *
     * @tparam Args Types of the extra constructor arguments
     *
     * @param[in] id   Id of the tensor; used both as map key and as constructor argument
     * @param[in] args Extra arguments perfectly forwarded to the @ref OpTensorContent constructor
     */
    template <typename... Args>
    void add_tensor(OpTensor::Id id, Args &&... args)
    {
        tensors[id] = utils::memory::make_deep_unique<OpTensorContent, OpTensorContent>(
            id, std::forward<Args>(args)...);
    }
    /** Two implementations are equal when their dependency graphs, operator maps and tensor maps all compare equal
     *
     * @ref status is not part of the comparison. Map elements are compared with deep_unique_ptr's own operator==.
     *
     * @param[in] graph0 Left-hand side
     * @param[in] graph1 Right-hand side
     *
     * @return true if all three members compare equal
     */
    friend bool operator==(const OperatorGraph::Implementation &graph0, const OperatorGraph::Implementation &graph1)
    {
        // Checked in the same order as a chained && would, stopping at the first mismatch
        if(!(graph0.graph == graph1.graph))
        {
            return false;
        }
        if(!(graph0.operators == graph1.operators))
        {
            return false;
        }
        return graph0.tensors == graph1.tensors;
    }

    Dependency  graph{};     /**< Dependency graph */
    OperatorMap operators{}; /**< Operator contents keyed by operator id */
    OpTensorMap tensors{};   /**< Tensor contents keyed by tensor id */
    Status      status{};    /**< Status attached to the graph; not compared by operator== */
};
/** Collect read-only pointers to the operator contents of a graph (defined out of line)
 *
 * NOTE(review): the name suggests the result follows the dependency order of the graph - confirm against the definition.
 *
 * @param[in] graph Operator graph implementation to traverse
 *
 * @return Pointers to the operator contents of @p graph
 */
std::vector<const OperatorContent *> traverse(const OperatorGraph::Implementation &graph);
/** Mutable counterpart of the overload above: collect pointers to the operator contents of a graph (defined out of line)
 *
 * @param[in] graph Operator graph implementation to traverse
 *
 * @return Pointers to the operator contents of @p graph
 */
std::vector<OperatorContent *> traverse(OperatorGraph::Implementation &graph);
/** Translate an operator graph into a kernel graph (defined out of line)
 *
 * NOTE(review): presumably calls OperatorContent::translate() on each operator of @p op_graph - confirm against the definition.
 *
 * @param[in,out] kernel_graph Kernel graph handed to the implementation
 * @param[in]     op_graph     Operator graph implementation to translate
 *
 * @return Status reported by the implementation
 */
Status translate(ClKernelGraph &kernel_graph, const OperatorGraph::Implementation &op_graph);
} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL