MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser
* Added support for float operations with int8 weights to TFLite Parser
* Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected,
TransposeConv and UnidirectionalSequenceLstm
* Renamed subgraphIndex to subgraph to fix name-shadowing warning.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index fc000bf..108b878 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -15,7 +15,10 @@
const std::string& filterShape,
const std::string& filterData,
const std::string biasShape = "",
- const std::string biasData = "")
+ const std::string biasData = "",
+ const std::string dataType = "UINT8",
+ const std::string weightsDataType = "UINT8",
+ const std::string biasDataType = "INT32")
{
std::string inputTensors = "[ 0, 2 ]";
std::string biasTensor = "";
@@ -26,7 +29,7 @@
biasTensor = R"(
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -47,7 +50,7 @@
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -59,7 +62,7 @@
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -71,7 +74,7 @@
},
{
"shape": )" + filterShape + R"(,
- "type": "UINT8",
+ "type": )" + weightsDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -353,4 +356,27 @@
{{"output", { 20 }}});
}
+struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture
+{
+ FullyConnectedWeightsBiasFloat()
+ : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
+ "[ 1, 1 ]", // outputShape
+ "[ 1, 4 ]", // filterShape
+ "[ 2, 3, 4, 5 ]", // filterData
+ "[ 1 ]", // biasShape
+                              "[ 10, 0, 0, 0 ]",      // biasData
+ "FLOAT32", // input and output dataType
+ "INT8", // weights dataType
+ "FLOAT32") // bias dataType
+ {}
+};
+
+TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat")
+{
+ RunTest<2, armnn::DataType::Float32>(
+ 0,
+ { 10, 20, 30, 40 },
+ { 400 });
+}
+
}