Fix type validation in the network create UAPI

Currently, the network create UAPI will assume that any network type
that isn't a buffer is an index. This means that the Linux kernel NPU
driver will accept any network type value and the user won't get any
feedback that they have specified an incorrect type.

To resolve this, the Linux kernel NPU driver will now return -EINVAL if
an unknown network type is given and a test has been added to validate
this behavior.

Change-Id: Ib7d9f5d5451897787981aae61a4e0a6650a73e05
Signed-off-by: Mikael Olsson <mikael.olsson@arm.com>
diff --git a/kernel/ethosu_network.c b/kernel/ethosu_network.c
index 94354ed..f7871de 100644
--- a/kernel/ethosu_network.c
+++ b/kernel/ethosu_network.c
@@ -15,7 +15,6 @@
  * You should have received a copy of the GNU General Public License
  * along with this program; if not, you can access it online at
  * http://www.gnu.org/licenses/gpl-2.0.html.
- *
  */
 
 /****************************************************************************
@@ -163,14 +162,21 @@
 	net->buf = NULL;
 	kref_init(&net->kref);
 
-	if (uapi->type == ETHOSU_UAPI_NETWORK_BUFFER) {
+	switch (uapi->type) {
+	case ETHOSU_UAPI_NETWORK_BUFFER:
 		net->buf = ethosu_buffer_get_from_fd(uapi->fd);
 		if (IS_ERR(net->buf)) {
 			ret = PTR_ERR(net->buf);
 			goto free_net;
 		}
-	} else {
+
+		break;
+	case ETHOSU_UAPI_NETWORK_INDEX:
 		net->index = uapi->index;
+		break;
+	default:
+		ret = -EINVAL;
+		goto free_net;
 	}
 
 	ret = anon_inode_getfd("ethosu-network", &ethosu_network_fops, net,
diff --git a/tests/run_inference_test.cpp b/tests/run_inference_test.cpp
index 480e26f..6075d7a 100644
--- a/tests/run_inference_test.cpp
+++ b/tests/run_inference_test.cpp
@@ -105,6 +105,20 @@
     } catch (std::exception &e) { throw TestFailureException("NetworkInfo unparsable buffer test: ", e.what()); }
 }
 
+void testNetworkInvalidType(const Device &device) {
+    const std::string expected_error =
+        std::string("IOCTL cmd=") + std::to_string(ETHOSU_IOCTL_NETWORK_CREATE) + " failed: " + std::strerror(EINVAL);
+    struct ethosu_uapi_network_create net_req = {};
+    net_req.type                              = ETHOSU_UAPI_NETWORK_INDEX + 1;
+    try {
+        int r = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, &net_req);
+        FAIL();
+    } catch (Exception &e) {
+        // The call is expected to throw
+        TEST_ASSERT(expected_error.compare(e.what()) == 0);
+    } catch (std::exception &e) { throw TestFailureException("NetworkCreate invalid type test: ", e.what()); }
+}
+
 void testRunInferenceBuffer(const Device &device) {
     try {
         auto networkBuffer = std::make_shared<Buffer>(device, sizeof(networkModelData));
@@ -154,6 +168,7 @@
         testPing(device);
         testDriverVersion(device);
         testCapabilties(device);
+        testNetworkInvalidType(device);
         testNetworkInfoNotExistentIndex(device);
         testNetworkInfoBuffer(device);
         testNetworkInfoUnparsableBuffer(device);