Add check of operator API to precommit

Attempt to avoid API getting out of sync.

Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Change-Id: Ic7b72c3f906e4a38cb26159bb67e9b1c4e22ca96
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4a6c9c7..2c72a26 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -37,4 +37,13 @@
         language: system
         entry: clang-format
         types: ["c++"]
-        args: ["-i"]
\ No newline at end of file
+        args: ["-i"]
+
+-   repo: local
+    hooks:
+    -   id: check-operator-api
+        name: check-operator-api
+        language: system
+        entry: python3 scripts/operator_api/generate_api.py
+        pass_filenames: false
+        always_run: true
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index c5c762d..fd33466 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -4,6 +4,7 @@
 import copy
 import os
 import subprocess
+from pathlib import Path
 from xml.dom import minidom
 
 from jinja2 import Environment
@@ -12,6 +13,10 @@
 # Note: main script designed to be run from the scripts/operator_api/ directory
 
 
+def getBasePath():
+    return Path(__file__).resolve().parent.parent.parent
+
+
 def getTosaArgTypes(tosaXml):
     """
     Returns a list of the TOSA argument types from tosa.xml.
@@ -326,7 +331,11 @@
     The values are the arguments required by each Serialization library operator.
     """
     serialLibAtts = {}
-    with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
+    base_path = getBasePath()
+    attr_def = (
+        base_path / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
+    )
+    with open(attr_def) as file:
         preamble = True
         inAtt = False
         opName = ""
@@ -368,15 +377,15 @@
     clangFormat(outfile)
 
 
-def generate(environment, dataTypes, operators):
+def generate(environment, dataTypes, operators, base_path):
     # Generate include/operators.h
     template = environment.get_template("operators_h.j2")
-    outfile = os.path.join("..", "..", "reference_model", "include", "operators.h")
+    outfile = base_path / "reference_model/include/operators.h"
     renderTemplate(environment, dataTypes, operators, template, outfile)
 
     # Generate src/operators.cc
     template = environment.get_template("operators_cc.j2")
-    outfile = os.path.join("..", "..", "reference_model", "src", "operators.cc")
+    outfile = base_path / "reference_model/src/operators.cc"
     renderTemplate(environment, dataTypes, operators, template, outfile)
 
 
@@ -392,7 +401,8 @@
         for name in allSerialLibAtts.keys()
     ]
     serAtts = sorted(serAtts, key=len, reverse=True)
-    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
+    base_path = getBasePath()
+    tosaXml = minidom.parse(base_path / "thirdparty/specification/tosa.xml")
     opsXml = tosaXml.getElementsByTagName("operator")
     opNames = [
         op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml
@@ -407,8 +417,11 @@
 
 
 if __name__ == "__main__":
-    environment = Environment(loader=FileSystemLoader("templates/"))
-    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
+    base_path = getBasePath()
+    environment = Environment(
+        loader=FileSystemLoader(Path(__file__).resolve().parent / "templates")
+    )
+    tosaXml = minidom.parse(str(base_path / "thirdparty/specification/tosa.xml"))
     dataTypes = getTosaDataTypes(tosaXml)
     operators = getOperators(tosaXml)
-    generate(environment, dataTypes, operators)
+    generate(environment, dataTypes, operators, base_path)