blob: f162df0b552f4c1705fd0a5df497d51d1b7fe544 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <armnn/Optional.hpp>
7#include <armnn/Types.hpp>
8
9#include <gpuFsa/GpuFsaLayerSupport.hpp>
10
11#include <doctest/doctest.h>
12
13#include <iostream>
14
15using namespace armnn;
16
17TEST_SUITE("GpuFsaLayerSupport")
18{
19
20TEST_CASE("IsLayerSupportedGpuFsaConv2d")
21{
22 TensorInfo inputInfo ({ 1, 5, 5, 1 }, DataType::Float32);
23 TensorInfo outputInfo({ 1, 3, 3, 1 }, DataType::Float32);
24 TensorInfo weightsInfo({ 1, 3, 3, 1 }, DataType::Float32, 0.0f, 0, true);
25 TensorInfo biasesInfo ({ 1 }, DataType::Float32, 0.0f, 0, true);
26
27 Convolution2dDescriptor desc;
28 desc.m_BiasEnabled = true;
29 desc.m_DataLayout = DataLayout::NHWC;
30
31 GpuFsaLayerSupport supportChecker;
32 std::string reasonIfNotSupported;
33 auto supported = supportChecker.IsLayerSupported(LayerType::Convolution2d,
34 {inputInfo, outputInfo, weightsInfo, biasesInfo},
35 desc,
36 EmptyOptional(),
37 EmptyOptional(),
38 reasonIfNotSupported);
39 CHECK(supported);
40}
41
42TEST_CASE("IsLayerSupportedGpuFsaConv2dUnsupported")
43{
44 TensorInfo inputInfo ({ 1, 5, 5, 1 }, DataType::Float32);
45 TensorInfo outputInfo({ 1, 3, 3, 1 }, DataType::Float32);
46 TensorInfo weightsInfo({ 1, 3, 3, 1 }, DataType::Float32, 0.0f, 0, true);
47
48 // NCHW is unsupported.
49 Convolution2dDescriptor desc;
50 desc.m_DataLayout = DataLayout::NCHW;
51
52 GpuFsaLayerSupport supportChecker;
53 std::string reasonIfNotSupported;
54 auto supported = supportChecker.IsLayerSupported(LayerType::Convolution2d,
55 {inputInfo, outputInfo, weightsInfo, TensorInfo()},
56 desc,
57 EmptyOptional(),
58 EmptyOptional(),
59 reasonIfNotSupported);
60 CHECK(!supported);
61 REQUIRE(reasonIfNotSupported.find("NCHW not supported by this kernel") != std::string::npos);
62}
63
64}