/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_EXPERIMENTAL_IPOSTOP
#define ARM_COMPUTE_EXPERIMENTAL_IPOSTOP

#include <memory>
#include <numeric>
#include <vector>

namespace arm_compute
{
namespace experimental
{
/** Type of Post Op */
enum class PostOpType
{
    Activation,
    Eltwise_Add,
    Eltwise_PRelu
};
/** An ordered sequence of post op types */
using PostOpTypeSequence = std::vector<PostOpType>;
/** An elementwise n-ary operation that can be appended to and fused (at kernel level) with other operators.
 * It contains:
 *      1. The attributes of the original operator.
 *      2. Any additional tensor arguments.
 *      3. The position of the previous op's dst tensor in its argument list ( @ref prev_dst_pos )
 *
 * For example, a series of chained ops:
 *
 *          div(src1, relu(conv(src0, weights, bias, conv_info), act_info), div_info)
 *
 * translates to
 *
 *          dst = conv(src0, weights, bias, conv_info) // main op
 *          dst = relu(dst, act_info)                  // previous dst is placed in the first (and only) argument
 *          dst = div(src1, dst, div_info)             // previous dst is placed in the second argument
 *
 * which in turn translates to:
 *
 *          main op:  conv(src0, weights, bias, conv_info)
 *          post op1: relu(act_info, prev_dst_pos = 0)
 *          post op2: div(div_info, src1, prev_dst_pos = 1)
 *
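 * As a rough sketch only, the two post ops above could be collected into a @ref PostOpList at
 * configure time and handed to the main op alongside its own arguments. The concrete post op
 * types, their constructor signatures and the use of ITensorInfo below are illustrative
 * assumptions, not types defined in this header:
 *
 * @code
 * PostOpList<ITensorInfo *> post_ops{};
 * // relu: no extra tensor arguments, previous dst is the only (0th) argument
 * post_ops.push_back_op<ExamplePostOpAct<ITensorInfo *>>(act_info);
 * // div: one extra tensor argument (src1), previous dst goes in position 1
 * post_ops.push_back_op<ExamplePostOpEltwiseDiv<ITensorInfo *>>(&src1_info, 1, div_info);
 * @endcode
 *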
 * @note: On Broadcasting
 *      For n-ary post ops, the tensor arguments must not "widen" the dst tensor of the main op.
 *      For example, for a dst of shape [14, 1, 34]:
 *          * post_op_arg1 = [ 1,  1, 34] is allowed: broadcast in dim 0
 *          * post_op_arg1 = [14,  1, 34] is allowed: no broadcast
 *          * post_op_arg1 = [ 1,  1,  1] is allowed: broadcast in dims 0 and 2
 *          * post_op_arg1 = [14, 15, 34] is NOT allowed: broadcast widens the dst tensor
 *
 * @note: On Data layout
 *      All post ops are data-layout agnostic: a post op has no inherent notion of "width", "height" and so on.
 *      To fuse a post op involving two tensors whose data layouts differ (and matter to both), the necessary
 *      permute ops must be performed beforehand to unify their data layouts.
 *
 *      Although post ops themselves can support any data layout, the main op they fuse to may impose additional
 *      restrictions in the presence of post ops. For example, a gemm implementation may only allow the NHWC data
 *      layout when post ops are provided. Such restrictions are specific to the main op's implementation.
 *
 * @note: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type.
 * @note: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime
 *      and the lifetime of its copies. This is almost guaranteed as IPostOp is only meant to be used at configure time
 *      after the ITensor or ITensorInfo objects are already constructed.
 */
template <typename TensorRelatedT>
struct IPostOp
{
    /** Get the arity of the post op
     * @note This is one fewer than the arity of the original op, because the previous op's dst
     *       tensor is implicitly passed as one of the arguments
     */
    size_t arity() const
    {
        return arguments().size();
    }
    /** The position of the previous op's dst in the current op's argument list */
    virtual int prev_dst_pos() const = 0;
    /** The IPostOp type */
    virtual PostOpType type() const = 0;
    /** The argument tensors
     * @note The order of the argument tensors is strictly preserved
     */
    virtual std::vector<TensorRelatedT *> arguments()             = 0;
    virtual std::vector<const TensorRelatedT *> arguments() const = 0;
    /** Clone method used in cases where PostOps are owned by unique_ptr
     * @note This performs a shallow copy of TensorRelatedT if TensorRelatedT points to a resource
     */
    virtual std::unique_ptr<IPostOp<TensorRelatedT>> clone() const = 0;
    virtual ~IPostOp()
    {
    }
};
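
/* Sketch only: one way a concrete unary "activation" post op could implement this interface.
 * The type name "ExamplePostOpAct" and the use of ActivationLayerInfo as its attribute are
 * assumptions for illustration; the library's real concrete post ops are defined elsewhere.
 *
 *   template <typename TensorRelatedT>
 *   struct ExamplePostOpAct : public IPostOp<TensorRelatedT>
 *   {
 *       explicit ExamplePostOpAct(const ActivationLayerInfo &act_info)
 *           : _act_info{ act_info }
 *       {
 *       }
 *       // The previous op's dst is the only argument, so it sits at position 0
 *       int prev_dst_pos() const override
 *       {
 *           return 0;
 *       }
 *       PostOpType type() const override
 *       {
 *           return PostOpType::Activation;
 *       }
 *       // Unary op: no extra tensor arguments beyond the implicitly passed previous dst
 *       std::vector<TensorRelatedT *> arguments() override
 *       {
 *           return {};
 *       }
 *       std::vector<const TensorRelatedT *> arguments() const override
 *       {
 *           return {};
 *       }
 *       std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override
 *       {
 *           return std::make_unique<ExamplePostOpAct<TensorRelatedT>>(*this);
 *       }
 *       ActivationLayerInfo _act_info{};
 *   };
 */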

/** A sequence of PostOps that can be appended to the end of other operators */
template <typename TensorRelatedT>
class PostOpList
{
public:
    /** Constructor */
    PostOpList() = default;
    /** Destructor */
    ~PostOpList() = default;
    /** Copy constructor: deep-copies the list by cloning each post op */
    PostOpList(const PostOpList &other)
    {
        for(const auto &op : other._post_ops)
        {
            this->_post_ops.push_back(op->clone());
        }
    }
    /** Copy assignment operator */
    PostOpList &operator=(const PostOpList &other)
    {
        PostOpList tmp{ other };
        std::swap(tmp, *this);
        return *this;
    }
    /** Move constructor */
    PostOpList(PostOpList &&other) = default;
    /** Move assignment operator */
    PostOpList &operator=(PostOpList &&other) = default;

    /** Add a new post op at the end of the list */
    template <typename OpT, typename... Args>
    void push_back_op(Args &&... args)
    {
        _post_ops.push_back(std::make_unique<OpT>(std::forward<Args>(args)...));
    }

    /** Number of post ops */
    size_t size() const
    {
        return _post_ops.size();
    }

    /** Total number of tensor arguments across all post ops */
    size_t total_num_arguments() const
    {
        return std::accumulate(_post_ops.begin(), _post_ops.end(), size_t{ 0 }, [](size_t total, const auto & op)
        {
            return total + op->arity();
        });
    }

    /** Get the underlying post op list */
    std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> &get_list()
    {
        return _post_ops;
    }
    const std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> &get_list() const
    {
        return _post_ops;
    }

private:
    std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> _post_ops{};
};
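
/* Sketch only: building and querying a PostOpList at configure time. The concrete post op
 * types and their constructor arguments below are illustrative assumptions, not part of
 * this header.
 *
 *   PostOpList<ITensorInfo *> post_ops{};
 *   post_ops.push_back_op<ExamplePostOpAct<ITensorInfo *>>(act_info);                // arity 0, prev_dst_pos = 0
 *   post_ops.push_back_op<ExamplePostOpEltwiseAdd<ITensorInfo *>>(&addend_info, 1);  // arity 1, prev_dst_pos = 1
 *
 *   post_ops.size();                // 2 post ops in the list
 *   post_ops.total_num_arguments(); // 1 extra tensor argument in total
 *
 * Copying the list clones each post op, but the clone is shallow with respect to any resources
 * TensorRelatedT points to. The list would then be passed to the main op's configure() call;
 * whether and how it is fused is up to that op's implementation.
 */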

} // namespace experimental
} // namespace arm_compute
#endif //ARM_COMPUTE_EXPERIMENTAL_IPOSTOP