blob: 82dd919de9dddc6a5bac6369b0e2157bd779272c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
#include "Activation.hpp"

#include <algorithm>
#include <cmath>
9
10namespace armnn
11{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010012float Activation(float in,
13 ActivationFunction function,
14 float a,
15 float b)
16{
17 float output;
18
19 // Compute the result of the activation function.
20 switch (function)
21 {
22 case ActivationFunction::Linear:
23 {
24 output = a * in + b;
25 break;
26 }
27 case ActivationFunction::Sigmoid:
28 {
29 output = 1.f / (1.f + expf(-in));
30 break;
31 }
32 case ActivationFunction::ReLu:
33 {
34 output = std::max(0.f, in);
35 break;
36 }
37 case ActivationFunction::BoundedReLu:
38 {
39 output = std::min(a, std::max(b, in));
40 break;
41 }
42 case ActivationFunction::SoftReLu:
43 {
44 output = logf(1.0f + expf(in));
45 break;
46 }
47 case ActivationFunction::LeakyReLu:
48 {
49 output = in > 0.0f ? in : (in * a);
50 break;
51 }
52 case ActivationFunction::Abs:
53 {
54 output = in < 0 ? -in : in;
55 break;
56 }
57 case ActivationFunction::Sqrt:
58 {
59 output = sqrtf(in);
60 break;
61 }
62 case ActivationFunction::Square:
63 {
64 output = in * in;
65 break;
66 }
67 case ActivationFunction::TanH:
68 {
69 output = a * tanhf(b * in);
70 break;
71 }
David Monahan3b3c3812020-02-25 09:03:29 +000072 case ActivationFunction::Elu:
73 {
74 output = (in >= 0) ? in : a * (expf(in) - 1);
75 break;
76 }
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010077 default:
78 {
79 throw InvalidArgumentException("Unsupported activation function");
80 }
81 }
82
83 return output;
84}
85
86
87void Activation(Decoder<float>& in,
88 Encoder<float>& out,
89 const TensorInfo& tensorInfo,
90 ActivationFunction function,
91 float a,
92 float b)
93{
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010094 unsigned int numElements = tensorInfo.GetNumElements();
95
96 for (unsigned int i = 0; i < numElements; i++)
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010097 {
98 out.Set(Activation(in.Get(), function, a, b));
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010099 ++in;
100 ++out;
101 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100102 in -= numElements;
103 out -= numElements;
telsoa014fcda012018-03-09 14:13:49 +0000104}
105
106} //namespace armnn