blob: 92db569d978b388cd59227a6fd7217a5d00a1c30 [file] [log] [blame]
Michalis Spyroue9362622018-11-23 17:41:37 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25#include "warp_helpers.h"
26
Michalis Spyrou7bb6fb82018-12-07 11:24:36 +000027#if defined(DATA_TYPE) && defined(OPERATION)
28
29#if defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
/** Compute the reciprocal square root of a vector of elements.
 *
 * Vector variant, compiled when both VEC_SIZE and LAST_ACCESSED_X are defined.
 *
 * @param[in] input Value of type VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), passed by value.
 *
 * @return Result of the OpenCL built-in rsqrt() applied to @p input.
 */
inline VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) inverse_sqrt(const VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) input)
{
    const VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = rsqrt(input);
    return result;
}
40
/** Compute the base-e exponential of a vector of elements.
 *
 * Vector variant, compiled when both VEC_SIZE and LAST_ACCESSED_X are defined.
 *
 * @param[in] input Value of type VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), passed by value.
 *
 * @return Result of the OpenCL built-in exp() applied to @p input.
 */
inline VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) exponential(const VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) input)
{
    const VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = exp(input);
    return result;
}
Michalis Spyrou7bb6fb82018-12-07 11:24:36 +000051#else // !defined(VEC_SIZE) || !defined(LAST_ACCESSED_X)
/** Compute the reciprocal square root of a single element.
 *
 * Scalar variant, compiled when VEC_SIZE or LAST_ACCESSED_X is not defined.
 *
 * @param[in] input Single element of type DATA_TYPE.
 *
 * @return Result of the OpenCL built-in rsqrt() applied to @p input.
 */
inline DATA_TYPE inverse_sqrt(const DATA_TYPE input)
{
    const DATA_TYPE result = rsqrt(input);
    return result;
}
62
/** Compute the base-e exponential of a single element.
 *
 * Scalar variant, compiled when VEC_SIZE or LAST_ACCESSED_X is not defined.
 *
 * @param[in] input Single element of type DATA_TYPE.
 *
 * @return Result of the OpenCL built-in exp() applied to @p input.
 */
inline DATA_TYPE exponential(const DATA_TYPE input)
{
    const DATA_TYPE result = exp(input);
    return result;
}
73#endif // defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
Michalis Spyroue9362622018-11-23 17:41:37 +000074
/** Applies an element-wise unary operator (OPERATION) to a tensor.
 *
 * @note This kernel is only compiled when DATA_TYPE and OPERATION are defined.
 * @note When both VEC_SIZE and LAST_ACCESSED_X are defined, each work item loads, transforms and
 *       stores VEC_SIZE elements at once; otherwise each work item handles a single element.
 *
 * @param[in]  in_ptr                            Pointer to the source image. Supported data types: F16/32.
 * @param[in]  in_stride_x                       Stride of the source image in X dimension (in bytes)
 * @param[in]  in_step_x                         in_stride_x * number of elements along X processed per work item (in bytes)
 * @param[in]  in_offset_first_element_in_bytes  Offset of the first element in the source image
 * @param[out] out_ptr                           Pointer to the destination image. Supported data types: F16/32.
 * @param[in]  out_stride_x                      Stride of the destination image in X dimension (in bytes)
 * @param[in]  out_step_x                        out_stride_x * number of elements along X processed per work item (in bytes)
 * @param[in]  out_offset_first_element_in_bytes Offset of the first element in the destination image
 */
__kernel void elementwise_unary(
    VECTOR_DECLARATION(in),
    VECTOR_DECLARATION(out))
{
    Vector in  = CONVERT_TO_VECTOR_STRUCT(in);
    Vector out = CONVERT_TO_VECTOR_STRUCT(out);

#if defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
    // X index of the first element handled by this work item
    const int first_x = (int)(get_global_id(0) * VEC_SIZE);

    // Number of elements by which this work item starts past LAST_ACCESSED_X (0 if it does not).
    // Both pointers are moved back by that amount so the vector access stays within bounds.
    const int overshoot = max(first_x - (int)LAST_ACCESSED_X, 0);
    in.ptr -= overshoot * in_stride_x;
    out.ptr -= overshoot * out_stride_x;

    // Load VEC_SIZE elements, apply the operator and write the result back
    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
    data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);

    VSTORE(VEC_SIZE)
    (OPERATION(data), 0, (__global DATA_TYPE *)out.ptr);
#else  // !defined(VEC_SIZE) || !defined(LAST_ACCESSED_X)
    // Scalar path: one element per work item
    __global DATA_TYPE *src = (__global DATA_TYPE *)in.ptr;
    __global DATA_TYPE *dst = (__global DATA_TYPE *)(out.ptr);

    *dst = (DATA_TYPE)(OPERATION(*src));
#endif // defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
}
Michalis Spyrou7bb6fb82018-12-07 11:24:36 +0000109#endif // defined(DATA_TYPE) && defined(OPERATION)