blob: 40f93ce420a27b033311eab9738bf63ef3ca45d1 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ParserFlatbuffersSerializeFixture.hpp"
#include <armnnDeserializer/IDeserializer.hpp>
#include <doctest/doctest.h>
#include <string>
TEST_SUITE("Deserializer_BatchMatMul")
{
/// Deserializer test fixture for a minimal network: two Input layers feeding one
/// BatchMatMul layer, whose result goes to a single Output layer.
///
/// The network is written as flatbuffers JSON into m_JsonString and then Setup()
/// from ParserFlatbuffersSerializeFixture is called (presumably it parses the JSON
/// and builds the network under test - see that fixture for details).
///
/// Layer indices / binding ids used in the JSON:
///   - 0: "InputXLayer"      (input binding id 0)
///   - 1: "InputYLayer"      (input binding id 1)
///   - 2: "BatchMatMulLayer" (input slot 0 <- layer 0, input slot 1 <- layer 1)
///   - 3: "OutputLayer"      (output binding id 0, fed by layer 2)
struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture
{
    /// @param inputXShape  JSON array text spliced verbatim into the "dimensions" of
    ///                     input X, e.g. "[ 1, 2, 2, 1 ]".
    /// @param inputYShape  Same as inputXShape, for input Y.
    /// @param outputShape  Same, for the BatchMatMul output; also reused for the
    ///                     Output layer's tensorInfo.
    /// @param dataType     Data type enum name spliced verbatim into every tensorInfo,
    ///                     e.g. "Float32". All four tensors share this one type.
    /// @param transposeX   Written to the descriptor's transposeX field (default false).
    /// @param transposeY   Written to the descriptor's transposeY field (default false).
    /// @param adjointX     Written to the descriptor's adjointX field (default false).
    /// @param adjointY     Written to the descriptor's adjointY field (default false).
    /// @param dataLayoutX  Layout enum name written verbatim to dataLayoutX (default "NHWC").
    /// @param dataLayoutY  Layout enum name written verbatim to dataLayoutY (default "NHWC").
    ///
    /// The defaults reproduce the descriptor this fixture previously hard-coded, so
    /// 4-argument callers get exactly the same JSON as before. Callers that change the
    /// descriptor are responsible for passing an outputShape consistent with it.
    explicit BatchMatMulFixture(const std::string& inputXShape,
                                const std::string& inputYShape,
                                const std::string& outputShape,
                                const std::string& dataType,
                                bool transposeX = false,
                                bool transposeY = false,
                                bool adjointX = false,
                                bool adjointY = false,
                                const std::string& dataLayoutX = "NHWC",
                                const std::string& dataLayoutY = "NHWC")
    {
        // The descriptor's bool fields are written as bare true/false in the JSON.
        const auto toJsonBool = [](bool value) { return value ? "true" : "false"; };

        m_JsonString = R"(
{
inputIds:[
0,
1
],
outputIds:[
3
],
layers:[
{
layer_type:"InputLayer",
layer:{
base:{
layerBindingId:0,
base:{
index:0,
layerName:"InputXLayer",
layerType:"Input",
inputSlots:[
{
index:0,
connection:{
sourceLayerIndex:0,
outputSlotIndex:0
},
}
],
outputSlots:[
{
index:0,
tensorInfo:{
dimensions:)" + inputXShape + R"(,
dataType:)" + dataType + R"(
},
}
],
},
}
},
},
{
layer_type:"InputLayer",
layer:{
base:{
layerBindingId:1,
base:{
index:1,
layerName:"InputYLayer",
layerType:"Input",
inputSlots:[
{
index:0,
connection:{
sourceLayerIndex:0,
outputSlotIndex:0
},
}
],
outputSlots:[
{
index:0,
tensorInfo:{
dimensions:)" + inputYShape + R"(,
dataType:)" + dataType + R"(
},
}
],
},
}
},
},
{
layer_type:"BatchMatMulLayer",
layer:{
base:{
index:2,
layerName:"BatchMatMulLayer",
layerType:"BatchMatMul",
inputSlots:[
{
index:0,
connection:{
sourceLayerIndex:0,
outputSlotIndex:0
},
},
{
index:1,
connection:{
sourceLayerIndex:1,
outputSlotIndex:0
},
}
],
outputSlots:[
{
index:0,
tensorInfo:{
dimensions:)" + outputShape + R"(,
dataType:)" + dataType + R"(
},
}
],
},
descriptor:{
transposeX:)" + toJsonBool(transposeX) + R"(,
transposeY:)" + toJsonBool(transposeY) + R"(,
adjointX:)" + toJsonBool(adjointX) + R"(,
adjointY:)" + toJsonBool(adjointY) + R"(,
dataLayoutX:)" + dataLayoutX + R"(,
dataLayoutY:)" + dataLayoutY + R"(
}
},
},
{
layer_type:"OutputLayer",
layer:{
base:{
layerBindingId:0,
base:{
index:3,
layerName:"OutputLayer",
layerType:"Output",
inputSlots:[
{
index:0,
connection:{
sourceLayerIndex:2,
outputSlotIndex:0
},
}
],
outputSlots:[
{
index:0,
tensorInfo:{
dimensions:)" + outputShape + R"(,
dataType:)" + dataType + R"(
},
}
],
}
}
},
}
]
}
)";
        Setup();
    }
};
/// BatchMatMul network with two Float32 inputs and a Float32 output, all of shape
/// [1, 2, 2, 1], using BatchMatMulFixture's default (non-transposed) descriptor.
struct SimpleBatchMatMulFixture : public BatchMatMulFixture
{
    SimpleBatchMatMulFixture()
        : BatchMatMulFixture(/* inputXShape = */ "[ 1, 2, 2, 1 ]",
                             /* inputYShape = */ "[ 1, 2, 2, 1 ]",
                             /* outputShape = */ "[ 1, 2, 2, 1 ]",
                             /* dataType    = */ "Float32")
    {
    }
};
TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest")
{
    // Each [1, 2, 2, 1] tensor holds one 2x2 matrix, filled row-major:
    //   X = [[1, 2],    Y = [[5, 6],
    //        [3, 4]]         [7, 8]]
    // X * Y = [[1*5 + 2*7, 1*6 + 2*8],
    //          [3*5 + 4*7, 3*6 + 4*8]] = [[19, 22], [43, 50]]
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f } },
        { "InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f } }
    };

    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f } }
    };

    RunTest<4, armnn::DataType::Float32>(0, inputData, expectedOutputData);
}
}