Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2019 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include <reference/workloads/ArgMinMax.hpp> |
| 7 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 8 | #include <doctest/doctest.h> |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 9 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 10 | TEST_SUITE("RefArgMinMax") |
| 11 | { |
| 12 | TEST_CASE("ArgMinTest") |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 13 | { |
| 14 | const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32); |
Inki Dae | d4619e2 | 2020-09-10 15:33:54 +0900 | [diff] [blame] | 15 | const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64); |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 16 | |
| 17 | std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f}); |
Inki Dae | d4619e2 | 2020-09-10 15:33:54 +0900 | [diff] [blame] | 18 | std::vector<int64_t> outputValues(outputInfo.GetNumElements()); |
| 19 | std::vector<int64_t> expectedValues({ 0, 1, 0 }); |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 20 | |
| 21 | ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()), |
| 22 | outputValues.data(), |
| 23 | inputInfo, |
| 24 | outputInfo, |
| 25 | armnn::ArgMinMaxFunction::Min, |
| 26 | -2); |
| 27 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 28 | CHECK(std::equal(outputValues.begin(), outputValues.end(), expectedValues.begin(), expectedValues.end())); |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 29 | |
| 30 | } |
| 31 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 32 | TEST_CASE("ArgMaxTest") |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 33 | { |
| 34 | const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32); |
Inki Dae | d4619e2 | 2020-09-10 15:33:54 +0900 | [diff] [blame] | 35 | const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64); |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 36 | |
| 37 | std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f }); |
Inki Dae | d4619e2 | 2020-09-10 15:33:54 +0900 | [diff] [blame] | 38 | std::vector<int64_t> outputValues(outputInfo.GetNumElements()); |
| 39 | std::vector<int64_t> expectedValues({ 1, 0, 1 }); |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 40 | |
| 41 | ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()), |
| 42 | outputValues.data(), |
| 43 | inputInfo, |
| 44 | outputInfo, |
| 45 | armnn::ArgMinMaxFunction::Max, |
| 46 | -2); |
| 47 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 48 | CHECK(std::equal(outputValues.begin(), outputValues.end(), expectedValues.begin(), expectedValues.end())); |
Narumol Prangnawarat | 4dc64a6 | 2019-09-16 17:00:22 +0100 | [diff] [blame] | 49 | |
| 50 | } |
| 51 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 52 | } |