COMPMID-3239: Fix projection and peephole in NEQLSTMLayer
- Peephole and projection have been fixed to work correctly.
- A small internal kernel that copies data between tensors is added
to cover the case where num_units and output_size are different.
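The kernel itself is not part of this excerpt. The following is a minimal sketch, assuming row-major int8 buffers and assuming the kernel copies only the overlapping part of each row when the source and destination widths (num_units vs. output_size) differ. `copy_rows` is a hypothetical name, not the ComputeLibrary API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper (not the ComputeLibrary API): copies a row-major 2D
// tensor of shape [rows x src_width] into one of shape [rows x dst_width].
// Only the first min(src_width, dst_width) elements of each row are copied;
// any remaining destination elements are left untouched.
void copy_rows(const std::vector<std::int8_t> &src, std::size_t src_width,
               std::vector<std::int8_t> &dst, std::size_t dst_width,
               std::size_t rows)
{
    assert(src.size() >= rows * src_width);
    assert(dst.size() >= rows * dst_width);
    const std::size_t row_size = std::min(src_width, dst_width);
    if (row_size == 0)
    {
        return; // nothing to copy
    }
    for (std::size_t r = 0; r < rows; ++r)
    {
        // Row offsets differ between the two buffers because their widths differ.
        std::memcpy(dst.data() + r * dst_width, src.data() + r * src_width,
                    row_size * sizeof(std::int8_t));
    }
}
```

Copying row by row (rather than one flat memcpy) is what makes differing widths safe: a single contiguous copy would misalign every row after the first.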
The changes below are strictly outside this patch's scope, but they
help this patch work (directly or indirectly) or make NEQLSTM more
complete.
- Handling of layer normalization is added to InfoHelpers.
- The QSYMM8 data type is added to the helper function that prints
tensors.
- NE/CLLSTMLayer::validate() now uses the correct value for shape
validation: layer normalization weights are checked against num_cells
instead of num_batches.
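The corrected shape check in the CLLSTMLayer hunk can be sketched in isolation. This is a minimal illustration, assuming a simplified shape type; `ShapeInfo` and `validate_layer_norm_weights` are hypothetical names that only mirror the logic of the ARM_COMPUTE_RETURN_ERROR_ON checks, not the ComputeLibrary API.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor's shape (not ComputeLibrary's TensorInfo).
struct ShapeInfo
{
    std::vector<std::size_t> dims; // dims[0] is the innermost dimension
    std::size_t num_dimensions() const { return dims.size(); }
    std::size_t dimension(std::size_t i) const { return i < dims.size() ? dims[i] : 1; }
};

// Mirrors the corrected check: each layer normalization weight vector must be
// 1D and hold one value per cell (num_cells), not one value per batch.
// Returns an empty string on success, otherwise an error message.
std::string validate_layer_norm_weights(const ShapeInfo *weights, std::size_t num_cells)
{
    if (weights == nullptr)
    {
        return "layer norm weights are null";
    }
    if (weights->num_dimensions() > 1)
    {
        return "layer norm weights must be 1D";
    }
    if (weights->dimension(0) != num_cells)
    {
        return "layer norm weights must have num_cells elements";
    }
    return "";
}
```

With the old check, a weight vector sized to the batch count would pass and a correctly sized one would fail whenever num_batches != num_cells.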
Change-Id: I40b4e71dfdbe8432caa2fe4a9af60a725362cc33
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3157
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 32ff813..56f22e2 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -444,7 +444,7 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_batches);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
}
@@ -453,9 +453,9 @@
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_batches);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_batches);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_batches);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
}
// Check peephole optimization