blob: 760c9a0ccd99fe94a518fefc63081303e48b9b33 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "Activation.hpp"
7
8#include <boost/log/trivial.hpp>
9
10#include <cmath>
11
12namespace armnn
13{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010014float Activation(float in,
15 ActivationFunction function,
16 float a,
17 float b)
18{
19 float output;
20
21 // Compute the result of the activation function.
22 switch (function)
23 {
24 case ActivationFunction::Linear:
25 {
26 output = a * in + b;
27 break;
28 }
29 case ActivationFunction::Sigmoid:
30 {
31 output = 1.f / (1.f + expf(-in));
32 break;
33 }
34 case ActivationFunction::ReLu:
35 {
36 output = std::max(0.f, in);
37 break;
38 }
39 case ActivationFunction::BoundedReLu:
40 {
41 output = std::min(a, std::max(b, in));
42 break;
43 }
44 case ActivationFunction::SoftReLu:
45 {
46 output = logf(1.0f + expf(in));
47 break;
48 }
49 case ActivationFunction::LeakyReLu:
50 {
51 output = in > 0.0f ? in : (in * a);
52 break;
53 }
54 case ActivationFunction::Abs:
55 {
56 output = in < 0 ? -in : in;
57 break;
58 }
59 case ActivationFunction::Sqrt:
60 {
61 output = sqrtf(in);
62 break;
63 }
64 case ActivationFunction::Square:
65 {
66 output = in * in;
67 break;
68 }
69 case ActivationFunction::TanH:
70 {
71 output = a * tanhf(b * in);
72 break;
73 }
74 default:
75 {
76 throw InvalidArgumentException("Unsupported activation function");
77 }
78 }
79
80 return output;
81}
82
83
84void Activation(Decoder<float>& in,
85 Encoder<float>& out,
86 const TensorInfo& tensorInfo,
87 ActivationFunction function,
88 float a,
89 float b)
90{
91 for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
92 {
93 out.Set(Activation(in.Get(), function, a, b));
94
95 ++in;
96 ++out;
97 }
98}
telsoa014fcda012018-03-09 14:13:49 +000099
100void Activation(const float* in,
101 float* out,
102 const TensorInfo& tensorInfo,
103 ActivationFunction function,
104 float a,
105 float b)
106{
107 for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
108 {
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100109 out[i] = Activation(in[i], function, a, b);
telsoa014fcda012018-03-09 14:13:49 +0000110 }
111}
112
113} //namespace armnn