blob: d28174d6a6e8c8533358bc84e710720628c9c59e [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include "TensorCopyUtils.hpp"
8#include "WorkloadTestUtils.hpp"
9
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
11
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <test/TensorHelpers.hpp>
arovir0143095f32018-10-09 18:04:24 +010013
arovir0143095f32018-10-09 18:04:24 +010014#include <boost/multi_array.hpp>
15
telsoa014fcda012018-03-09 14:13:49 +000016struct ActivationFixture
17{
18 ActivationFixture()
19 {
20 auto boostArrayExtents = boost::extents
Matthew Sloyan171214c2020-09-09 09:07:37 +010021 [armnn::numeric_cast<boost::multi_array_types::extent_gen::index>(batchSize)]
22 [armnn::numeric_cast<boost::multi_array_types::extent_gen::index>(channels)]
23 [armnn::numeric_cast<boost::multi_array_types::extent_gen::index>(height)]
24 [armnn::numeric_cast<boost::multi_array_types::extent_gen::index>(width)];
telsoa014fcda012018-03-09 14:13:49 +000025 output.resize(boostArrayExtents);
26 outputExpected.resize(boostArrayExtents);
27 input.resize(boostArrayExtents);
28
29 unsigned int inputShape[] = { batchSize, channels, height, width };
30 unsigned int outputShape[] = { batchSize, channels, height, width };
31
32 inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32);
33 outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32);
34
35 input = MakeRandomTensor<float, 4>(inputTensorInfo, 21453);
36 }
37
38 unsigned int width = 17;
39 unsigned int height = 29;
40 unsigned int channels = 2;
41 unsigned int batchSize = 5;
42
43 boost::multi_array<float, 4> output;
44 boost::multi_array<float, 4> outputExpected;
45 boost::multi_array<float, 4> input;
46
47 armnn::TensorInfo inputTensorInfo;
48 armnn::TensorInfo outputTensorInfo;
49
telsoa01c577f2c2018-08-31 09:22:23 +010050 // Parameters used by some of the activation functions.
telsoa014fcda012018-03-09 14:13:49 +000051 float a = 0.234f;
52 float b = -12.345f;
53};
54
55
56struct PositiveActivationFixture : public ActivationFixture
57{
58 PositiveActivationFixture()
59 {
60 input = MakeRandomTensor<float, 4>(inputTensorInfo, 2342423, 0.0f, 1.0f);
61 }
62};