//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "Activation.hpp"

#include <boost/log/trivial.hpp>

#include <algorithm>
#include <cmath>

namespace armnn
{

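// Applies the selected activation function element-wise over the input buffer.
// The meaning of the extra parameters depends on the function:
//   Linear:      output = a * x + b
//   BoundedReLu: a is the upper bound, b the lower bound
//   LeakyReLu:   a is the slope applied to negative inputs
//   TanH:        output = a * tanh(b * x)
// All other functions ignore a and b.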
void Activation(const float* in,
                float* out,
                const TensorInfo& tensorInfo,
                ActivationFunction function,
                float a,
                float b)
{
    for (size_t i = 0; i < tensorInfo.GetNumElements(); i++)
    {
        float input = in[i];
        float output;

        // Compute the result of the activation function.
        switch (function)
        {
            case ActivationFunction::Linear:
            {
                output = a * input + b;
                break;
            }
            case ActivationFunction::Sigmoid:
            {
                output = 1.f / (1.f + expf(-input));
                break;
            }
            case ActivationFunction::ReLu:
            {
                output = std::max(0.f, input);
                break;
            }
            case ActivationFunction::BoundedReLu:
            {
                // Clamps the input to the range [b, a].
                output = std::min(a, std::max(b, input));
                break;
            }
            case ActivationFunction::SoftReLu:
            {
                // Softplus: ln(1 + e^x).
                output = logf(1.0f + expf(input));
                break;
            }
            case ActivationFunction::LeakyReLu:
            {
                // Passes positive values through; scales negative values by a.
                output = input > 0.0f ? input : (input * a);
                break;
            }
            case ActivationFunction::Abs:
            {
                output = input < 0 ? -input : input;
                break;
            }
            case ActivationFunction::Sqrt:
            {
                output = sqrtf(input);
                break;
            }
            case ActivationFunction::Square:
            {
                output = input * input;
                break;
            }
            case ActivationFunction::TanH:
            {
                // Scaled hyperbolic tangent: a * tanh(b * x).
                output = a * tanhf(b * input);
                break;
            }
            default:
            {
                BOOST_LOG_TRIVIAL(error) << "Unsupported activation function";
                return;
            }
        }

        out[i] = output;
    }
}
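
// Illustrative usage sketch. The TensorShape/TensorInfo constructors used
// below are assumptions about the public API, so adjust them as needed:
//
//   std::vector<float> input = { -2.0f, -0.5f, 0.0f, 3.0f };
//   std::vector<float> output(input.size());
//   TensorInfo info(TensorShape({ 4 }), DataType::Float32);
//   Activation(input.data(), output.data(), info,
//              ActivationFunction::ReLu, 0.0f, 0.0f);
//   // output is now { 0.0f, 0.0f, 0.0f, 3.0f }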

} //namespace armnn