blob: cc26f68f578b2cd00496845aef7fda1cd97269cc [file] [log] [blame]
surmeh0149b9e102018-05-17 14:11:25 +01001//
Mike Kellye2d611e2021-10-14 12:35:58 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
surmeh0149b9e102018-05-17 14:11:25 +01004//
Nikhil Raj77605822018-09-03 11:25:56 +01005
6#pragma once
7
surmeh0149b9e102018-05-17 14:11:25 +01008#include "DriverTestHelpers.hpp"
Nikhil Raj77605822018-09-03 11:25:56 +01009
surmeh0149b9e102018-05-17 14:11:25 +010010#include <log/log.h>
11
telsoa01ce3e84a2018-08-31 09:31:35 +010012#include <OperationsUtils.h>
surmeh0149b9e102018-05-17 14:11:25 +010013
telsoa01ce3e84a2018-08-31 09:31:35 +010014using namespace android::hardware;
surmeh0149b9e102018-05-17 14:11:25 +010015using namespace driverTestHelpers;
telsoa01ce3e84a2018-08-31 09:31:35 +010016using namespace armnn_driver;
surmeh0149b9e102018-05-17 14:11:25 +010017
Sadik Armagan188675f2021-02-12 17:16:42 +000018using RequestArgument = V1_0::RequestArgument;
19
Nikhil Raj77605822018-09-03 11:25:56 +010020namespace driverTestHelpers
surmeh0149b9e102018-05-17 14:11:25 +010021{
Kevin Mayedc5ffa2019-05-22 12:02:53 +010022#define ARMNN_ANDROID_FP16_TEST(result, fp16Expectation, fp32Expectation, fp16Enabled) \
23 if (fp16Enabled) \
24 { \
Mike Kellye2d611e2021-10-14 12:35:58 +010025 DOCTEST_CHECK_MESSAGE((result == fp16Expectation || result == fp32Expectation), result << \
Kevin Mayedc5ffa2019-05-22 12:02:53 +010026 " does not match either " << fp16Expectation << "[fp16] or " << fp32Expectation << "[fp32]"); \
27 } else \
28 { \
Mike Kellye2d611e2021-10-14 12:35:58 +010029 DOCTEST_CHECK(result == fp32Expectation); \
Kevin Mayedc5ffa2019-05-22 12:02:53 +010030 }
surmeh0149b9e102018-05-17 14:11:25 +010031
Nikhil Raj77605822018-09-03 11:25:56 +010032void SetModelFp16Flag(V1_0::Model& model, bool fp16Enabled);
33
Nikhil Raj77605822018-09-03 11:25:56 +010034void SetModelFp16Flag(V1_1::Model& model, bool fp16Enabled);
Nikhil Raj77605822018-09-03 11:25:56 +010035
36template<typename HalPolicy>
37void PaddingTestImpl(android::nn::PaddingScheme paddingScheme, bool fp16Enabled = false)
surmeh0149b9e102018-05-17 14:11:25 +010038{
Nikhil Raj77605822018-09-03 11:25:56 +010039 using HalModel = typename HalPolicy::Model;
40 using HalOperationType = typename HalPolicy::OperationType;
41
Kevin Mayedc5ffa2019-05-22 12:02:53 +010042 armnn::Compute computeDevice = armnn::Compute::GpuAcc;
43
44#ifndef ARMCOMPUTECL_ENABLED
45 computeDevice = armnn::Compute::CpuRef;
46#endif
47
48 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(computeDevice, fp16Enabled));
Nikhil Raj77605822018-09-03 11:25:56 +010049 HalModel model = {};
surmeh0149b9e102018-05-17 14:11:25 +010050
51 uint32_t outSize = paddingScheme == android::nn::kPaddingSame ? 2 : 1;
52
53 // add operands
Nikhil Raj77605822018-09-03 11:25:56 +010054 float weightValue[] = {1.f, -1.f, 0.f, 1.f};
Sadik Armagan9150bff2021-05-26 15:40:53 +010055 float biasValue[] = {0.f};
surmeh0149b9e102018-05-17 14:11:25 +010056
Sadik Armagan9150bff2021-05-26 15:40:53 +010057 AddInputOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 2, 3, 1});
58 AddTensorOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 2, 2, 1}, weightValue);
59 AddTensorOperand<HalPolicy>(model, hidl_vec < uint32_t > {1}, biasValue);
60 AddIntOperand<HalPolicy>(model, (int32_t) paddingScheme); // padding
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010061 AddIntOperand<HalPolicy>(model, 2); // stride x
62 AddIntOperand<HalPolicy>(model, 2); // stride y
63 AddIntOperand<HalPolicy>(model, 0); // no activation
Sadik Armagan9150bff2021-05-26 15:40:53 +010064 AddOutputOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 1, outSize, 1});
surmeh0149b9e102018-05-17 14:11:25 +010065
66 // make the convolution operation
67 model.operations.resize(1);
Nikhil Raj77605822018-09-03 11:25:56 +010068 model.operations[0].type = HalOperationType::CONV_2D;
Sadik Armagan9150bff2021-05-26 15:40:53 +010069 model.operations[0].inputs = hidl_vec < uint32_t > {0, 1, 2, 3, 4, 5, 6};
70 model.operations[0].outputs = hidl_vec < uint32_t > {7};
surmeh0149b9e102018-05-17 14:11:25 +010071
72 // make the prepared model
Nikhil Raj77605822018-09-03 11:25:56 +010073 SetModelFp16Flag(model, fp16Enabled);
Sadik Armagane6e54a82019-05-08 10:18:05 +010074 android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
surmeh0149b9e102018-05-17 14:11:25 +010075
76 // construct the request
Sadik Armagan188675f2021-02-12 17:16:42 +000077 V1_0::DataLocation inloc = {};
Sadik Armagan9150bff2021-05-26 15:40:53 +010078 inloc.poolIndex = 0;
79 inloc.offset = 0;
80 inloc.length = 6 * sizeof(float);
81 RequestArgument input = {};
82 input.location = inloc;
83 input.dimensions = hidl_vec < uint32_t > {};
surmeh0149b9e102018-05-17 14:11:25 +010084
Sadik Armagan188675f2021-02-12 17:16:42 +000085 V1_0::DataLocation outloc = {};
Sadik Armagan9150bff2021-05-26 15:40:53 +010086 outloc.poolIndex = 1;
87 outloc.offset = 0;
88 outloc.length = outSize * sizeof(float);
89 RequestArgument output = {};
90 output.location = outloc;
91 output.dimensions = hidl_vec < uint32_t > {};
surmeh0149b9e102018-05-17 14:11:25 +010092
Kevin Mayec1e5b82020-02-26 17:00:39 +000093 V1_0::Request request = {};
Sadik Armagan9150bff2021-05-26 15:40:53 +010094 request.inputs = hidl_vec < RequestArgument > {input};
95 request.outputs = hidl_vec < RequestArgument > {output};
surmeh0149b9e102018-05-17 14:11:25 +010096
surmeh0149b9e102018-05-17 14:11:25 +010097 // set the input data (matching source test)
Nikhil Raj77605822018-09-03 11:25:56 +010098 float indata[] = {1024.25f, 1.f, 0.f, 3.f, -1, -1024.25f};
surmeh0149b9e102018-05-17 14:11:25 +010099 AddPoolAndSetData(6, request, indata);
100
101 // add memory for the output
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100102 android::sp<IMemory> outMemory = AddPoolAndGetData<float>(outSize, request);
Nikhil Raj77605822018-09-03 11:25:56 +0100103 float* outdata = reinterpret_cast<float*>(static_cast<void*>(outMemory->getPointer()));
surmeh0149b9e102018-05-17 14:11:25 +0100104
105 // run the execution
Sadik Armagand4636872020-04-27 10:15:41 +0100106 if (preparedModel.get() != nullptr)
107 {
108 Execute(preparedModel, request);
109 }
surmeh0149b9e102018-05-17 14:11:25 +0100110
111 // check the result
Nikhil Raj77605822018-09-03 11:25:56 +0100112 switch (paddingScheme)
surmeh0149b9e102018-05-17 14:11:25 +0100113 {
Sadik Armagan9150bff2021-05-26 15:40:53 +0100114 case android::nn::kPaddingValid:
115 ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
116 break;
117 case android::nn::kPaddingSame:
118 ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
Mike Kellye2d611e2021-10-14 12:35:58 +0100119 DOCTEST_CHECK(outdata[1] == 0.f);
Sadik Armagan9150bff2021-05-26 15:40:53 +0100120 break;
121 default:
Mike Kellye2d611e2021-10-14 12:35:58 +0100122 DOCTEST_CHECK(false);
Sadik Armagan9150bff2021-05-26 15:40:53 +0100123 break;
surmeh0149b9e102018-05-17 14:11:25 +0100124 }
125}
126
Nikhil Raj77605822018-09-03 11:25:56 +0100127} // namespace driverTestHelpers