/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
| 24 | #ifndef ARM_COMPUTE_EXPERIMENTAL_IPOSTOP |
| 25 | #define ARM_COMPUTE_EXPERIMENTAL_IPOSTOP |
| 26 | |
| 27 | #include <memory> |
| 28 | #include <numeric> |
| 29 | #include <vector> |
| 30 | |
| 31 | namespace arm_compute |
| 32 | { |
| 33 | namespace experimental |
| 34 | { |
/** Type of Post Op
 *
 * Identifies which kind of operation a post op represents.
 */
enum class PostOpType
{
    Activation,   /**< Activation post op */
    Eltwise_Add,  /**< Elementwise addition post op */
    Eltwise_PRelu /**< Elementwise PRelu post op */
};
/** An ordered sequence of type of Post Ops */
using PostOpTypeSequence = std::vector<PostOpType>;
| 44 | /** An elementwise n-ary operation that can be appended to and fused with (at kernel-level) other operators |
| 45 | * It contains: |
| 46 | * 1. The attributes of the original operator. |
| 47 | * 2. Any additional tensor argument. |
SiCongLi | eb8bd81 | 2021-10-29 15:05:49 +0100 | [diff] [blame] | 48 | * 3. The position of the previous op's dst tensor in its argument list ( @ref prev_dst_pos ) |
SiCongLi | 1af5416 | 2021-10-06 15:25:57 +0100 | [diff] [blame] | 49 | * |
| 50 | * For example, a series of chained ops: |
| 51 | * |
| 52 | * div(src1, relu(conv(src0, weights, bias, conv_info), act_info), div_info) |
| 53 | * |
| 54 | * translates to |
| 55 | * |
| 56 | * dst = conv(src0, weights, bias, conv_info) // main op |
| 57 | * dst = relu(dst, act_info) // previous dst is placed in the first (and only) argument |
| 58 | * dst = div(src1, dst, div_info) // previous dst is placed in the second argument |
| 59 | * |
| 60 | * which in turn translates to: |
| 61 | * |
| 62 | * main op: conv(src0, weights, bias, conv_info) |
| 63 | * post op1: relu(act_info, prev_dst_pos = 0) |
| 64 | * post op2: div(div_info, src1, prev_dst_pos = 1) |
| 65 | * |
SiCongLi | eb8bd81 | 2021-10-29 15:05:49 +0100 | [diff] [blame] | 66 | * @note: On Broadcasting |
| 67 | * For n-ary post ops, the tensor arguments must not "widen" the dst tensor of the main op |
| 68 | * For example, for a dst of shape [14, 1, 34]: |
| 69 | * * post_op_arg1 = [1, 1, 34] is allowed: broadcast in dim 0 |
| 70 | * * post_op_arg1 = [14, 1, 34] is allowed: no broadcast |
| 71 | * * post_op_arg1 = [1, 1, 34] is allowed: broadcast in dims 0 and 1 |
| 72 | * * post_op_arg1 = [14, 15, 34] is NOT allowed: broadcast widens the dst tensor |
| 73 | * |
SiCongLi | d928735 | 2021-11-03 19:01:22 +0000 | [diff] [blame^] | 74 | * @note: On Data layout |
| 75 | * All post ops are data layout agnostic. This means post ops do not have an inherent idea of "width", "height" and so on. |
| 76 | * Should we want to perform a post op with 2 tensors of different data layouts (where data layouts are significant to both), |
| 77 | * then we need to perform necessary permutation op beforehand to unify their data layout before they can be fused with a post op |
| 78 | * |
| 79 | * Note although post ops themselves should be able to support any data layout, the main op they fuse to may impose |
| 80 | * additional restrictions in the presence of post ops. For example, the implementation of a gemm op may only allow |
| 81 | * NHWC data layout if post ops are provided. Such restrictions are main op implementation specific. |
| 82 | * |
SiCongLi | eb8bd81 | 2021-10-29 15:05:49 +0100 | [diff] [blame] | 83 | * @note: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type |
| 84 | * @note: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime |
SiCongLi | 1af5416 | 2021-10-06 15:25:57 +0100 | [diff] [blame] | 85 | * and the lifetime of its copies. This is almost guaranteed as IPostOp is only meant to be used at configure time |
| 86 | * after the ITensor or ITensorInfo objects are already constructed |
| 87 | */ |
| 88 | template <typename TensorRelatedT> |
| 89 | struct IPostOp |
| 90 | { |
| 91 | /** Get the arity of the post op |
SiCongLi | eb8bd81 | 2021-10-29 15:05:49 +0100 | [diff] [blame] | 92 | * @note: that this is one fewer than the arity of the original op, because we implicitly pass the previous op's dst |
SiCongLi | 1af5416 | 2021-10-06 15:25:57 +0100 | [diff] [blame] | 93 | * tensor as one of the arguments |
| 94 | */ |
| 95 | size_t arity() const |
| 96 | { |
| 97 | return arguments().size(); |
| 98 | } |
| 99 | /** The position of previous op's dst in current op's argument list */ |
| 100 | virtual int prev_dst_pos() const = 0; |
| 101 | /** The IPostOp type */ |
| 102 | virtual PostOpType type() const = 0; |
| 103 | /** The argument tensors |
| 104 | * The order of the argument tensor is strictly preserved |
| 105 | */ |
| 106 | virtual std::vector<TensorRelatedT *> arguments() = 0; |
| 107 | virtual std::vector<const TensorRelatedT *> arguments() const = 0; |
| 108 | /** Clone method used in cases where PostOps are owned by unique_ptr |
SiCongLi | eb8bd81 | 2021-10-29 15:05:49 +0100 | [diff] [blame] | 109 | * @note: This performs a shallow copy of the TensorRelatedT if TensorRelatedT points to a resource |
SiCongLi | 1af5416 | 2021-10-06 15:25:57 +0100 | [diff] [blame] | 110 | */ |
| 111 | virtual std::unique_ptr<IPostOp<TensorRelatedT>> clone() const = 0; |
| 112 | virtual ~IPostOp() |
| 113 | { |
| 114 | } |
| 115 | }; |
| 116 | |
| 117 | /** A sequence of PostOps that can be appended to the end of other operators */ |
| 118 | template <typename TensorRelatedT> |
| 119 | class PostOpList |
| 120 | { |
| 121 | public: |
| 122 | /** Constructor */ |
| 123 | PostOpList() = default; |
| 124 | /** Destructor */ |
| 125 | ~PostOpList() = default; |
| 126 | PostOpList(const PostOpList &other) |
| 127 | { |
| 128 | for(const auto &op : other._post_ops) |
| 129 | { |
| 130 | this->_post_ops.push_back(op->clone()); |
| 131 | } |
| 132 | } |
| 133 | PostOpList &operator=(const PostOpList &other) |
| 134 | { |
| 135 | PostOpList tmp{ other }; |
| 136 | std::swap(tmp, *this); |
| 137 | return *this; |
| 138 | } |
| 139 | PostOpList(PostOpList &&other) = default; |
| 140 | PostOpList &operator=(PostOpList &&other) = default; |
| 141 | |
| 142 | /** Add a new post op at the end of the list */ |
| 143 | template <typename OpT, typename... Args> |
| 144 | void push_back_op(Args &&... args) |
| 145 | { |
| 146 | _post_ops.push_back(std::make_unique<OpT>(std::forward<Args>(args)...)); |
| 147 | } |
| 148 | |
| 149 | /** Number of post ops */ |
| 150 | size_t size() const |
| 151 | { |
| 152 | return _post_ops.size(); |
| 153 | } |
| 154 | |
| 155 | /** Total number of post ops */ |
| 156 | size_t total_num_arguments() const |
| 157 | { |
| 158 | return std::accumulate(_post_ops.begin(), _post_ops.end(), 0, [](size_t op1_arity, const auto & op2) |
| 159 | { |
| 160 | return op1_arity + op2->arity(); |
| 161 | }); |
| 162 | } |
| 163 | |
| 164 | /** Get the underlying post op list */ |
| 165 | std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> &get_list() |
| 166 | { |
| 167 | return _post_ops; |
| 168 | } |
| 169 | const std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> &get_list() const |
| 170 | { |
| 171 | return _post_ops; |
| 172 | } |
| 173 | |
| 174 | private: |
| 175 | std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> _post_ops{}; |
| 176 | }; |
| 177 | |
| 178 | } // namespace experimental |
| 179 | } // namespace arm_compute |
ramelg01 | 6049eda | 2021-10-29 10:52:53 +0100 | [diff] [blame] | 180 | #endif //ARM_COMPUTE_EXPERIMENTAL_IPOSTOP |