blob: 2b0c84e226a7f0368363884ba5c926539d5b4c77 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "Activation.hpp"
7
8#include <boost/log/trivial.hpp>
9
10#include <cmath>
11
12namespace armnn
13{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010014float Activation(float in,
15 ActivationFunction function,
16 float a,
17 float b)
18{
19 float output;
20
21 // Compute the result of the activation function.
22 switch (function)
23 {
24 case ActivationFunction::Linear:
25 {
26 output = a * in + b;
27 break;
28 }
29 case ActivationFunction::Sigmoid:
30 {
31 output = 1.f / (1.f + expf(-in));
32 break;
33 }
34 case ActivationFunction::ReLu:
35 {
36 output = std::max(0.f, in);
37 break;
38 }
39 case ActivationFunction::BoundedReLu:
40 {
41 output = std::min(a, std::max(b, in));
42 break;
43 }
44 case ActivationFunction::SoftReLu:
45 {
46 output = logf(1.0f + expf(in));
47 break;
48 }
49 case ActivationFunction::LeakyReLu:
50 {
51 output = in > 0.0f ? in : (in * a);
52 break;
53 }
54 case ActivationFunction::Abs:
55 {
56 output = in < 0 ? -in : in;
57 break;
58 }
59 case ActivationFunction::Sqrt:
60 {
61 output = sqrtf(in);
62 break;
63 }
64 case ActivationFunction::Square:
65 {
66 output = in * in;
67 break;
68 }
69 case ActivationFunction::TanH:
70 {
71 output = a * tanhf(b * in);
72 break;
73 }
74 default:
75 {
76 throw InvalidArgumentException("Unsupported activation function");
77 }
78 }
79
80 return output;
81}
82
83
84void Activation(Decoder<float>& in,
85 Encoder<float>& out,
86 const TensorInfo& tensorInfo,
87 ActivationFunction function,
88 float a,
89 float b)
90{
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010091 unsigned int numElements = tensorInfo.GetNumElements();
92
93 for (unsigned int i = 0; i < numElements; i++)
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010094 {
95 out.Set(Activation(in.Get(), function, a, b));
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010096 ++in;
97 ++out;
98 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010099 in -= numElements;
100 out -= numElements;
telsoa014fcda012018-03-09 14:13:49 +0000101}
102
103} //namespace armnn