Add type check when getting message by ID

When the kernel driver handles incoming rpmsg messages, it uses the ID
in the message header to find the corresponding mailbox message. The
mailbox messages are stored as a generic message struct that is later
cast to the specific message type.

There is currently no type information stored in the generic message
struct so only the ID is used to distinguish between the mailbox
messages. This means if an incorrect ID is received that matches a
mailbox message with a different type, the kernel driver will cast the
message struct to the wrong type.

Type information has now been added to the generic message struct and
will be checked when trying to find the corresponding mailbox message so
both the ID and type must be correct to find a matching message.

Change-Id: Ifdbceea6ec4ae7078f424a32ce1ff5474bd22fef
diff --git a/kernel/ethosu_cancel_inference.c b/kernel/ethosu_cancel_inference.c
index 6661522..4d7b544 100644
--- a/kernel/ethosu_cancel_inference.c
+++ b/kernel/ethosu_cancel_inference.c
@@ -159,11 +159,12 @@
 	struct ethosu_mailbox_msg *msg;
 	struct ethosu_cancel_inference *cancellation;
 
-	msg = ethosu_mailbox_find(mailbox, msg_id);
+	msg = ethosu_mailbox_find(mailbox, msg_id,
+				  ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ);
 	if (IS_ERR(msg)) {
 		dev_warn(dev,
-			 "Id for cancel inference msg not found. id=%ddev",
-			 msg_id);
+			 "Id for cancel inference msg not found. Id=0x%x: %ld",
+			 msg_id, PTR_ERR(msg));
 
 		return;
 	}
diff --git a/kernel/ethosu_capabilities.c b/kernel/ethosu_capabilities.c
index 8611edf..83bd8cf 100644
--- a/kernel/ethosu_capabilities.c
+++ b/kernel/ethosu_capabilities.c
@@ -68,11 +68,12 @@
 	struct ethosu_mailbox_msg *msg;
 	struct ethosu_capabilities *cap;
 
-	msg = ethosu_mailbox_find(mailbox, msg_id);
+	msg = ethosu_mailbox_find(mailbox, msg_id,
+				  ETHOSU_CORE_MSG_CAPABILITIES_REQ);
 	if (IS_ERR(msg)) {
 		dev_warn(dev,
-			 "Id for capabilities msg not found. id=%d\n",
-			 msg_id);
+			 "Id for capabilities msg not found. Id=0x%0x: %ld\n",
+			 msg_id, PTR_ERR(msg));
 
 		return;
 	}
diff --git a/kernel/ethosu_inference.c b/kernel/ethosu_inference.c
index 6befd3a..dd0b7b9 100644
--- a/kernel/ethosu_inference.c
+++ b/kernel/ethosu_inference.c
@@ -429,11 +429,12 @@
 	int ret;
 	int i;
 
-	msg = ethosu_mailbox_find(mailbox, msg_id);
+	msg = ethosu_mailbox_find(mailbox, msg_id,
+				  ETHOSU_CORE_MSG_INFERENCE_REQ);
 	if (IS_ERR(msg)) {
 		dev_warn(dev,
-			 "Id for inference msg not found. Id=%d\n",
-			 msg_id);
+			 "Id for inference msg not found. Id=0x%x: %ld\n",
+			 msg_id, PTR_ERR(msg));
 
 		return;
 	}
diff --git a/kernel/ethosu_mailbox.c b/kernel/ethosu_mailbox.c
index 4c64f17..3e7284b 100644
--- a/kernel/ethosu_mailbox.c
+++ b/kernel/ethosu_mailbox.c
@@ -79,12 +79,16 @@
 }
 
 struct ethosu_mailbox_msg *ethosu_mailbox_find(struct ethosu_mailbox *mbox,
-					       int msg_id)
+					       int msg_id,
+					       uint32_t msg_type)
 {
 	struct ethosu_mailbox_msg *ptr = (struct ethosu_mailbox_msg *)idr_find(
 		&mbox->msg_idr, msg_id);
 
 	if (ptr == NULL)
+		return ERR_PTR(-ENOENT);
+
+	if (ptr->type != msg_type)
 		return ERR_PTR(-EINVAL);
 
 	return ptr;
@@ -147,6 +151,8 @@
 		}
 	};
 
+	msg->type = rpmsg.header.type;
+
 	return rpmsg_send(mbox->ept, &rpmsg, sizeof(rpmsg.header));
 }
 
@@ -172,6 +178,8 @@
 	struct ethosu_core_msg_inference_req *inf_req = &rpmsg.inf_req;
 	uint32_t i;
 
+	msg->type = rpmsg.header.type;
+
 	/* Verify that the uapi and core has the same number of pmus */
 	if (pmu_event_config_count != ETHOSU_CORE_PMU_MAX) {
 		dev_err(mbox->dev, "PMU count misconfigured.");
@@ -218,6 +226,8 @@
 	};
 	struct ethosu_core_msg_network_info_req *info_req = &rpmsg.net_info_req;
 
+	msg->type = rpmsg.header.type;
+
 	if (network != NULL) {
 		info_req->network.type = ETHOSU_CORE_NETWORK_BUFFER;
 		ethosu_core_set_size(network, &info_req->network.buffer);
@@ -246,6 +256,8 @@
 		}
 	};
 
+	msg->type = rpmsg.header.type;
+
 	return rpmsg_send(mbox->ept, &rpmsg,
 			  sizeof(rpmsg.header) + sizeof(rpmsg.cancel_req));
 }
diff --git a/kernel/ethosu_mailbox.h b/kernel/ethosu_mailbox.h
index edf922b..c192b54 100644
--- a/kernel/ethosu_mailbox.h
+++ b/kernel/ethosu_mailbox.h
@@ -51,8 +51,9 @@
 };
 
 struct ethosu_mailbox_msg {
-	int  id;
-	void (*fail)(struct ethosu_mailbox_msg *msg);
+	int      id;
+	uint32_t type;
+	void     (*fail)(struct ethosu_mailbox_msg *msg);
 };
 
 /****************************************************************************
@@ -93,7 +94,8 @@
  * Return: a valid pointer on success, otherwise an error ptr.
  */
 struct ethosu_mailbox_msg *ethosu_mailbox_find(struct ethosu_mailbox *mbox,
-					       int msg_id);
+					       int msg_id,
+					       uint32_t msg_type);
 
 /**
  * ethosu_mailbox_fail() - Fail mailbox messages
diff --git a/kernel/ethosu_network_info.c b/kernel/ethosu_network_info.c
index 0e205db..5bfa150 100644
--- a/kernel/ethosu_network_info.c
+++ b/kernel/ethosu_network_info.c
@@ -127,11 +127,12 @@
 	struct ethosu_network_info *info;
 	uint32_t i;
 
-	msg = ethosu_mailbox_find(mailbox, msg_id);
+	msg = ethosu_mailbox_find(mailbox, msg_id,
+				  ETHOSU_CORE_MSG_NETWORK_INFO_REQ);
 	if (IS_ERR(msg)) {
 		dev_warn(dev,
-			 "Id for network info msg not found. msg.id=0x%x\n",
-			 msg_id);
+			 "Id for network info msg not found. Id=0x%x: %ld\n",
+			 msg_id, PTR_ERR(msg));
 
 		return;
 	}