Fix sign extension of LOGICAL_LEFT_SHIFT/LOGICAL_RIGHT_SHIFT results

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I04261178694c004409aef2ff5c84c32b04729433
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 3abf961..c697db0 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -288,27 +288,32 @@
 template <int Rank, DType Dtype>
 int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
 {
-    int32_t num_bits = 0;
     switch (Dtype)
     {
         case DType_INT8:
-            num_bits = 8;
+            this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+                REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
+                (int32_t)b);
+                return static_cast<int8_t>(static_cast<uint32_t>(a) << b);  // unsigned shift (no UB), then sign-extend
+            };
             break;
         case DType_INT16:
-            num_bits = 16;
+            this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+                REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
+                (int32_t)b);
+                return static_cast<int16_t>(static_cast<uint32_t>(a) << b);  // unsigned shift (no UB), then sign-extend
+            };
             break;
         case DType_INT32:
-            num_bits = 32;
+            this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+                REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
+                (int32_t)b);
+                return static_cast<int32_t>(static_cast<uint32_t>(a) << b);  // unsigned shift avoids signed-overflow UB
+            };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
     }
-    this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
-        uint32_t mask = ONES_MASK(num_bits);
-        REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
-        (int32_t)b);
-        return (a << b) & mask;
-    };

     return 0;
 }
@@ -316,29 +321,33 @@
 template <int Rank, DType Dtype>
 int OpLogicalRightShift<Rank, Dtype>::register_fcn()
 {
-    int32_t num_bits = 0;
     switch (Dtype)
     {
         case DType_INT8:
-            num_bits = 8;
+            this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+                REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
+                (int32_t)b);
+                return static_cast<int8_t>(static_cast<uint8_t>(a) >> b);  // zero-fill shift, then sign-extend
+            };
             break;
         case DType_INT16:
-            num_bits = 16;
+            this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+                REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
+                (int32_t)b);
+                return static_cast<int16_t>(static_cast<uint16_t>(a) >> b);  // zero-fill shift, then sign-extend
+            };
             break;
         case DType_INT32:
-            num_bits = 32;
+            this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+                REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
+                (int32_t)b);
+                return static_cast<int32_t>(static_cast<uint32_t>(a) >> b);  // zero-fill (logical) shift
+            };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
     }

-    this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
-        uint32_t mask = ONES_MASK(num_bits) >> b;
-        REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
-        (int32_t)b);
-        return (a >> b) & mask;
-    };
-
     return 0;
 }