blob: 814a0ddd13400d189f51fb06424a503852a7a6ba [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "Activation.hpp"
7
telsoa014fcda012018-03-09 14:13:49 +00008#include <cmath>
9
10namespace armnn
11{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010012float Activation(float in,
13 ActivationFunction function,
14 float a,
15 float b)
16{
17 float output;
18
19 // Compute the result of the activation function.
20 switch (function)
21 {
22 case ActivationFunction::Linear:
23 {
24 output = a * in + b;
25 break;
26 }
27 case ActivationFunction::Sigmoid:
28 {
29 output = 1.f / (1.f + expf(-in));
30 break;
31 }
32 case ActivationFunction::ReLu:
33 {
34 output = std::max(0.f, in);
35 break;
36 }
37 case ActivationFunction::BoundedReLu:
38 {
39 output = std::min(a, std::max(b, in));
40 break;
41 }
42 case ActivationFunction::SoftReLu:
43 {
44 output = logf(1.0f + expf(in));
45 break;
46 }
47 case ActivationFunction::LeakyReLu:
48 {
49 output = in > 0.0f ? in : (in * a);
50 break;
51 }
52 case ActivationFunction::Abs:
53 {
54 output = in < 0 ? -in : in;
55 break;
56 }
57 case ActivationFunction::Sqrt:
58 {
59 output = sqrtf(in);
60 break;
61 }
62 case ActivationFunction::Square:
63 {
64 output = in * in;
65 break;
66 }
67 case ActivationFunction::TanH:
68 {
69 output = a * tanhf(b * in);
70 break;
71 }
72 default:
73 {
74 throw InvalidArgumentException("Unsupported activation function");
75 }
76 }
77
78 return output;
79}
80
81
82void Activation(Decoder<float>& in,
83 Encoder<float>& out,
84 const TensorInfo& tensorInfo,
85 ActivationFunction function,
86 float a,
87 float b)
88{
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010089 unsigned int numElements = tensorInfo.GetNumElements();
90
91 for (unsigned int i = 0; i < numElements; i++)
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010092 {
93 out.Set(Activation(in.Get(), function, a, b));
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010094 ++in;
95 ++out;
96 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010097 in -= numElements;
98 out -= numElements;
telsoa014fcda012018-03-09 14:13:49 +000099}
100
101} //namespace armnn