blob: 180f57e210500082eeba0ed096a7128230d9a3fc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Nikhil Raj77605822018-09-03 11:25:56 +01005
6#pragma once
7
surmeh0149b9e102018-05-17 14:11:25 +01008#include "DriverTestHelpers.hpp"
Nikhil Raj77605822018-09-03 11:25:56 +01009
surmeh0149b9e102018-05-17 14:11:25 +010010#include <boost/test/unit_test.hpp>
11#include <log/log.h>
12
telsoa01ce3e84a2018-08-31 09:31:35 +010013#include <OperationsUtils.h>
surmeh0149b9e102018-05-17 14:11:25 +010014
15BOOST_AUTO_TEST_SUITE(Convolution2DTests)
16
telsoa01ce3e84a2018-08-31 09:31:35 +010017using namespace android::hardware;
surmeh0149b9e102018-05-17 14:11:25 +010018using namespace driverTestHelpers;
telsoa01ce3e84a2018-08-31 09:31:35 +010019using namespace armnn_driver;
surmeh0149b9e102018-05-17 14:11:25 +010020
Nikhil Raj77605822018-09-03 11:25:56 +010021namespace driverTestHelpers
surmeh0149b9e102018-05-17 14:11:25 +010022{
Kevin Mayedc5ffa2019-05-22 12:02:53 +010023#define ARMNN_ANDROID_FP16_TEST(result, fp16Expectation, fp32Expectation, fp16Enabled) \
24 if (fp16Enabled) \
25 { \
26 BOOST_TEST((result == fp16Expectation || result == fp32Expectation), result << \
27 " does not match either " << fp16Expectation << "[fp16] or " << fp32Expectation << "[fp32]"); \
28 } else \
29 { \
30 BOOST_TEST(result == fp32Expectation); \
31 }
surmeh0149b9e102018-05-17 14:11:25 +010032
Nikhil Raj77605822018-09-03 11:25:56 +010033void SetModelFp16Flag(V1_0::Model& model, bool fp16Enabled);
34
Matteo Martincigha5f9e762019-06-17 13:26:34 +010035#if defined(ARMNN_ANDROID_NN_V1_1) || defined(ARMNN_ANDROID_NN_V1_2)
Nikhil Raj77605822018-09-03 11:25:56 +010036void SetModelFp16Flag(V1_1::Model& model, bool fp16Enabled);
37#endif
38
39template<typename HalPolicy>
40void PaddingTestImpl(android::nn::PaddingScheme paddingScheme, bool fp16Enabled = false)
surmeh0149b9e102018-05-17 14:11:25 +010041{
Nikhil Raj77605822018-09-03 11:25:56 +010042 using HalModel = typename HalPolicy::Model;
43 using HalOperationType = typename HalPolicy::OperationType;
44
Kevin Mayedc5ffa2019-05-22 12:02:53 +010045 armnn::Compute computeDevice = armnn::Compute::GpuAcc;
46
47#ifndef ARMCOMPUTECL_ENABLED
48 computeDevice = armnn::Compute::CpuRef;
49#endif
50
51 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(computeDevice, fp16Enabled));
Nikhil Raj77605822018-09-03 11:25:56 +010052 HalModel model = {};
surmeh0149b9e102018-05-17 14:11:25 +010053
54 uint32_t outSize = paddingScheme == android::nn::kPaddingSame ? 2 : 1;
55
56 // add operands
Nikhil Raj77605822018-09-03 11:25:56 +010057 float weightValue[] = {1.f, -1.f, 0.f, 1.f};
58 float biasValue[] = {0.f};
surmeh0149b9e102018-05-17 14:11:25 +010059
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010060 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 2, 3, 1});
61 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 2, 2, 1}, weightValue);
62 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1}, biasValue);
63 AddIntOperand<HalPolicy>(model, (int32_t)paddingScheme); // padding
64 AddIntOperand<HalPolicy>(model, 2); // stride x
65 AddIntOperand<HalPolicy>(model, 2); // stride y
66 AddIntOperand<HalPolicy>(model, 0); // no activation
67 AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, outSize, 1});
surmeh0149b9e102018-05-17 14:11:25 +010068
69 // make the convolution operation
70 model.operations.resize(1);
Nikhil Raj77605822018-09-03 11:25:56 +010071 model.operations[0].type = HalOperationType::CONV_2D;
surmeh0149b9e102018-05-17 14:11:25 +010072 model.operations[0].inputs = hidl_vec<uint32_t>{0, 1, 2, 3, 4, 5, 6};
73 model.operations[0].outputs = hidl_vec<uint32_t>{7};
74
75 // make the prepared model
Nikhil Raj77605822018-09-03 11:25:56 +010076 SetModelFp16Flag(model, fp16Enabled);
Sadik Armagane6e54a82019-05-08 10:18:05 +010077 android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
surmeh0149b9e102018-05-17 14:11:25 +010078
79 // construct the request
80 DataLocation inloc = {};
81 inloc.poolIndex = 0;
82 inloc.offset = 0;
83 inloc.length = 6 * sizeof(float);
84 RequestArgument input = {};
85 input.location = inloc;
86 input.dimensions = hidl_vec<uint32_t>{};
87
88 DataLocation outloc = {};
89 outloc.poolIndex = 1;
90 outloc.offset = 0;
91 outloc.length = outSize * sizeof(float);
92 RequestArgument output = {};
93 output.location = outloc;
94 output.dimensions = hidl_vec<uint32_t>{};
95
96 Request request = {};
97 request.inputs = hidl_vec<RequestArgument>{input};
98 request.outputs = hidl_vec<RequestArgument>{output};
99
surmeh0149b9e102018-05-17 14:11:25 +0100100 // set the input data (matching source test)
Nikhil Raj77605822018-09-03 11:25:56 +0100101 float indata[] = {1024.25f, 1.f, 0.f, 3.f, -1, -1024.25f};
surmeh0149b9e102018-05-17 14:11:25 +0100102 AddPoolAndSetData(6, request, indata);
103
104 // add memory for the output
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100105 android::sp<IMemory> outMemory = AddPoolAndGetData<float>(outSize, request);
Nikhil Raj77605822018-09-03 11:25:56 +0100106 float* outdata = reinterpret_cast<float*>(static_cast<void*>(outMemory->getPointer()));
surmeh0149b9e102018-05-17 14:11:25 +0100107
108 // run the execution
109 Execute(preparedModel, request);
110
111 // check the result
Nikhil Raj77605822018-09-03 11:25:56 +0100112 switch (paddingScheme)
surmeh0149b9e102018-05-17 14:11:25 +0100113 {
Nikhil Raj77605822018-09-03 11:25:56 +0100114 case android::nn::kPaddingValid:
Kevin Mayedc5ffa2019-05-22 12:02:53 +0100115 ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
Nikhil Raj77605822018-09-03 11:25:56 +0100116 break;
117 case android::nn::kPaddingSame:
Kevin Mayedc5ffa2019-05-22 12:02:53 +0100118 ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
119 BOOST_TEST(outdata[1] == 0.f);
Nikhil Raj77605822018-09-03 11:25:56 +0100120 break;
121 default:
surmeh0149b9e102018-05-17 14:11:25 +0100122 BOOST_TEST(false);
Nikhil Raj77605822018-09-03 11:25:56 +0100123 break;
surmeh0149b9e102018-05-17 14:11:25 +0100124 }
125}
126
Nikhil Raj77605822018-09-03 11:25:56 +0100127} // namespace driverTestHelpers
surmeh0149b9e102018-05-17 14:11:25 +0100128
129BOOST_AUTO_TEST_SUITE_END()