Extend reference model API with eager operator execution entrypoints

- Adds a script to generate operators.h and operators.cc
- Adds jinja2 templates for generating operators.h and operators.cc
- Adds unit tests for a subset of the operators generated
- Includes the TOSA specification as a submodule
- Adds supporting C++ and header files

Signed-off-by: Grant Watson <grant.watson@arm.com>
Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c
diff --git a/scripts/operator_api/README.md b/scripts/operator_api/README.md
new file mode 100644
index 0000000..381d90c
--- /dev/null
+++ b/scripts/operator_api/README.md
@@ -0,0 +1,19 @@
+# Generate eager operator execution entrypoints
+
+## Introduction
+
+The generate_api.py script will generate an extended reference model API with eager operator execution entrypoints.
+The following files will be generated: reference_model/include/operators.h and reference_model/src/operators.cc
+
+## Requirements
+
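+* Python 3.6 or later
+* Jinja2 (install with ```pip install Jinja2```)
+* clang-format available on the PATH (the generated files are formatted with `clang-format -i`)
+* The thirdparty/specification and thirdparty/serialization_lib submodules checked out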
+
+## Running from the command line
+
+The script must be run from the scripts/operator_api directory, because it locates its inputs and outputs using relative paths:
+
+```bash
+python generate_api.py
+```
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
new file mode 100644
index 0000000..1f89f74
--- /dev/null
+++ b/scripts/operator_api/generate_api.py
@@ -0,0 +1,349 @@
+"""Generate extended reference model API with eager operator execution entrypoints"""
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import copy
+import os
+import subprocess
+from xml.dom import minidom
+
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
+
+
+def getTosaArgTypes(tosaXml):
+    """
+    Returns a list of the TOSA argument types from tosa.xml.
+    """
+    argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"}
+    argTypesXml = tosaXml.getElementsByTagName("type")
+    for argTypeXml in argTypesXml:
+        argTypes.add(argTypeXml.getAttribute("name"))
+    argTypes.remove("TABLE_SIZE")
+    return argTypes
+
+
+def getTosaDataTypes(tosaXml):
+    """
+    Returns a list of the TOSA data types from tosa.xml.
+    """
+    argTypes = getTosaArgTypes(tosaXml)
+    dataTypes = set()
+    dataTypesXml = tosaXml.getElementsByTagName("typesupport")
+    for dataTypeXml in dataTypesXml:
+        for argType in argTypes:
+            dataType = dataTypeXml.getAttribute(argType)
+            if dataType != "":
+                dataTypes.add(f"tosa_datatype_{dataType}")
+    return sorted(dataTypes)
+
+
+def getSerializeOpType(tosaOpName):
+    """
+    Returns the Serialization library operator that matches the TOSA operator specified.
+    """
+    map = {
+        "avg_pool2d": "Pool",
+        "conv2d": "Conv",
+        "conv3d": "Conv",
+        "depthwise_conv2d": "Conv",
+        "fully_connected": "FullyConnected",
+        "matmul": "MatMul",
+        "max_pool2d": "Pool",
+        "transpose_conv2d": "Conv",
+        "clamp": "Clamp",
+        "arithmetic_right_shift": "ArithmeticRightShift",
+        "mul": "Mul",
+        "table": "Table",
+        "negate": "Negate",
+        "pad": "Pad",
+        "reshape": "Reshape",
+        "slice": "Slice",
+        "tile": "Tile",
+        "transpose": "Transpose",
+        "resize": "Resize",
+        "rescale": "Rescale",
+        "cond_if": "CondIf",
+        "while_loop": "WhileLoop",
+    }
+    if tosaOpName not in map.keys():
+        return "None"
+    else:
+        return map[tosaOpName]
+
+
+def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
+    """
+    Returns the arguments required by the Serialization library for the TOSA operator specified.
+    Generates code to initialize Serialization arguments. If a matching TOSA argument exists,
+    that value is used for initialization, otherwise a default value e.g. 0 is used.
+    """
+    serOpType = getSerializeOpType(tosaOpName)
+    if serOpType not in allSerializeArgs.keys():
+        return {}
+    else:
+        serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
+        tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
+        serTosaTypeMap = {"ResizeMode": "tosa_mode"}
+        for arg in serOpArgs:
+            argName = arg["name"]
+            init = ""
+            # Translate TOSA data types to Serialization data types for initialization
+            if arg["dType"] in serTosaTypeMap.keys():
+                init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})"
+            # Initialize Serialization arguments to their matching function parameter
+            elif argName in tosaArgsDict:
+                if arg["SV"] == "V":
+                    shape = tosaArgsDict[argName]["shape"]
+                    if shape == "[]":
+                        init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)"
+                    else:
+                        init = f"(&client_{argName}[0], &client_{argName}{shape})"
+                else:
+                    init = f" = client_{argName}"
+            else:
+                # Initialize Serialization arguments with no matching fuction parameter
+                if arg["SV"] == "V":
+                    init = ""
+                else:
+                    if arg["dType"] == "DType":
+                        arg["dType"] = "tosa::DType"
+                        init = " = tosa::DType::DType_FP32"
+                    else:
+                        init = " = 0"
+            arg["init"] = init
+        return serOpArgs
+
+
+def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
+    """
+    Replace TOSA argument data types with their matching Serialization argument data types.
+    Delete TOSA arguments where the type couldn't be determined.
+    Add Serialization arguments that have no matching TOSA argument.
+    """
+    tosaArgTypes = getTosaArgTypes(tosaXml)
+    serArgsDict = {arg["name"]: arg for arg in serializeArgs}
+    tosaArgsNames = [arg["name"] for arg in tosaArgs]
+    delTosaArgs = []
+    # Replace TOSA argument data types with their matching Serialization argument data types.
+    for tosaArg in tosaArgs:
+        if tosaArg["type"] in tosaArgTypes:
+            if tosaArg["name"] in serArgsDict:
+                tosaArg["type"] = serArgsDict[tosaArg["name"]]["dType"]
+            else:
+                # Delete TOSA argument whose data type can't be determined
+                delTosaArgs.append(tosaArgsNames.index(tosaArg["name"]))
+                # Delete corresponding length argument if one exists
+                lenArgName = f"{tosaArg['name']}_len"
+                if lenArgName in tosaArgsNames:
+                    delTosaArgs.append(tosaArgsNames.index(lenArgName))
+    # Delete TOSA arguments where the type couldn't be determined
+    for index in sorted(delTosaArgs, key=int, reverse=True):
+        del tosaArgs[index]
+    # Add Serialization arguments that have no matching TOSA argument
+    tosaArgNames = [arg["name"] for arg in tosaArgs]
+    for serArg in serializeArgs:
+        if (serArg["name"] not in tosaArgNames) and (
+            not serArg["dType"] == "tosa::DType"
+        ):
+            serArgName = serArg["name"]
+            if serArg["SV"] == "V":
+                # For vector data types, insert a matching length argument
+                tosaArgs.insert(
+                    len(tosaArgs) - 1,
+                    {
+                        "name": f"{serArgName}_len",
+                        "type": "int32_t",
+                        "shape": "",
+                        "category": "",
+                    },
+                )
+                init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
+                shape = "[]"
+            else:
+                init = f" = client_{serArg['name']}"
+                shape = ""
+            serArg["init"] = init
+            # Insert new argument
+            tosaArgs.insert(
+                len(tosaArgs) - 1,
+                {
+                    "name": serArgName,
+                    "type": serArg["dType"],
+                    "shape": shape,
+                    "category": "",
+                },
+            )
+
+
+def getOperators(tosaXml):
+    """
+    Return a list of TOSA operators as defined by tosa.xml.
+    """
+    operators = []
+    ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
+    opsXml = tosaXml.getElementsByTagName("operator")
+    allSerializeArgs = getSerializeArgs()
+    for opXml in opsXml:
+        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
+        if opName not in ignoreOps:
+            operator = {"name": opName}
+            operator["serializeAttType"] = getSerializeOpType(opName)
+            tosaArgs = getTosaArgs(opXml)
+            serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
+            updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
+            operator["arguments"] = tosaArgs
+            operator["serializeArgs"] = serializeArgs
+            operator["inputs"] = [
+                arg["name"] for arg in tosaArgs if arg["category"] == "input"
+            ]
+            operator["outputs"] = [
+                arg["name"] for arg in tosaArgs if arg["category"] == "output"
+            ]
+            operators.append(operator)
+    return operators
+
+
+def getTosaArgs(opXml):
+    """
+    Return the arguments required for the TOSA operator specified.
+    """
+    arguments = []
+    argsXml = opXml.getElementsByTagName("argument")
+    tosaTensorTypes = getTosaArgTypes(tosaXml)
+    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
+    for xmlArg in argsXml:
+        argName = xmlArg.getAttribute("name").lower()
+        argType = xmlArg.getAttribute("type")
+        argShape = xmlArg.getAttribute("shape")
+        argCategory = xmlArg.getAttribute("category")
+        # Update argument type
+        if argType[-1:] == "*":
+            argType = argType[:-1]
+        if argCategory in ["input", "output"] and argType in tosaTensorTypes:
+            argType = "tosa_tensor_t"
+            argShape = ""
+        if argType in tosaTypeMap:
+            argType = tosaTypeMap[argType]
+        # Add a length argument for arrays with unknown compile-time size
+        if argShape != "" and argShape[0] == "[" and not argShape[1:-1].isnumeric():
+            argShape = "[]"
+            arguments.append(
+                {
+                    "name": f"{argName}_len",
+                    "type": "int32_t",
+                    "shape": "",
+                    "category": "",
+                }
+            )
+        elif argShape == "" or not argShape[0] == "[":
+            argShape = ""
+        # Append argument
+        arguments.append(
+            {
+                "name": argName,
+                "type": argType,
+                "shape": argShape,
+                "category": argCategory,
+            }
+        )
+    return arguments
+
+
+def clangFormat(filename):
+    cmd = ["clang-format", "-i", filename]
+    with open(os.devnull, "w") as devnull:
+        subprocess.check_call(cmd, stdout=devnull)
+
+
+def getSerializeArgs():
+    """
+    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
+    The values are the arguments required by each Serialization library operator.
+    """
+    serializeArgs = {}
+    with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
+        preamble = True
+        inAtt = False
+        opName = ""
+        args = []
+        for line in file:
+            if preamble and not line[: len("DEF_ATTRIBUTE(")] == "DEF_ATTRIBUTE(":
+                continue
+            else:
+                preamble = False
+            line = line.lstrip().rstrip()
+            if not inAtt and "DEF_ATTRIBUTE(" in line:
+                opName = line[len("DEF_ATTRIBUTE(") : line.find(",")]
+                inAtt = True
+            elif inAtt:
+                vals = line.split(",")
+                argName = vals[2].lstrip().strip()
+                if ")" in argName:
+                    argName = argName[:-1]
+                arg = {
+                    "name": argName,
+                    "dType": vals[0].lstrip().strip(),
+                    "SV": vals[1].lstrip().strip(),
+                }
+                args.append(arg)
+                if ")" in line:
+                    serializeArgs[opName] = args
+                    opName = ""
+                    args = []
+                    inAtt = False
+    return serializeArgs
+
+
+def renderTemplate(environment, dataTypes, operators, template, outfile):
+    content = template.render(dataTypes=dataTypes, operators=operators)
+    with open(outfile, mode="w", encoding="utf-8") as output:
+        output.write(content)
+        print(f"Created {outfile}")
+
+    clangFormat(outfile)
+
+
+def generate(environment, dataTypes, operators):
+    # Generate include/operators.h
+    template = environment.get_template("operators_h.j2")
+    outfile = os.path.join("..", "..", "reference_model", "include", "operators.h")
+    renderTemplate(environment, dataTypes, operators, template, outfile)
+
+    # Generate src/operators.cc
+    template = environment.get_template("operators_cc.j2")
+    outfile = os.path.join("..", "..", "reference_model", "src", "operators.cc")
+    renderTemplate(environment, dataTypes, operators, template, outfile)
+
+
+def getSerializeOpTypeMap():
+    """
+    Utility function for generating the map used in getSerializeOpType()
+    """
+    import re
+
+    allSerializeArgs = getSerializeArgs()
+    serArgs = [
+        re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
+        for name in allSerializeArgs.keys()
+    ]
+    serArgs = sorted(serArgs, key=len, reverse=True)
+    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
+    opsXml = tosaXml.getElementsByTagName("operator")
+    opNames = [
+        op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml
+    ]
+    map = {}
+    for opName in opNames:
+        for serArg in serArgs:
+            if serArg in opName:
+                components = serArg.split("_")
+                map[opName] = "".join(x.title() for x in components)
+    return map
+
+
+if __name__ == "__main__":
+    environment = Environment(loader=FileSystemLoader("templates/"))
+    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
+    dataTypes = getTosaDataTypes(tosaXml)
+    operators = getOperators(tosaXml)
+    generate(environment, dataTypes, operators)
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
new file mode 100644
index 0000000..6b0ed6e
--- /dev/null
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -0,0 +1,176 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//         http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+#include "operators.h"
+#include "model_runner_impl.h"
+#include "ops/op_factory.h"
+
+// Return tosa_status_error from the enclosing function if `status` is non-zero.
+#define TOSA_RETURN_ON_ERROR(status)                                                                                   \
+    do                                                                                                                 \
+    {                                                                                                                  \
+        if ((status) != 0)                                                                                             \
+        {                                                                                                              \
+            return tosa_status_error;                                                                                  \
+        }                                                                                                              \
+    } while (false)
+
+// Return the GraphStatus, cast to tosa_status_t, if it is not TOSA_VALID (the two enums are
+// expected to share values; see the note on tosa_status_t in operators.h).
+// `status` is evaluated exactly once: it is a call such as runner.run(), which must not be
+// executed a second time just to compute the returned code.
+#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status)                                                                      \
+    do                                                                                                                 \
+    {                                                                                                                  \
+        const auto tosa_graph_status = (status);                                                                       \
+        if (tosa_graph_status != GraphStatus::TOSA_VALID)                                                              \
+        {                                                                                                              \
+            auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status);                        \
+            return static_cast<tosa_status_t>(ustatus);                                                                \
+        }                                                                                                              \
+    } while (false)
+
+namespace {
+
+// Map a client data type to the Serialization library DType.
+// Only fp16 and fp32 are translated; every other value maps to DType_UNKNOWN.
+// NOTE(review): integer client types currently become DType_UNKNOWN - confirm intended.
+tosa::DType translate_client_datatype(tosa_datatype_t type)
+{
+    switch (type)
+    {
+        case tosa_datatype_fp16_t:
+            return tosa::DType::DType_FP16;
+        case tosa_datatype_fp32_t:
+            return tosa::DType::DType_FP32;
+        default:
+            return tosa::DType::DType_UNKNOWN;
+    }
+}
+
+// Build a Serialization library tensor (name, shape, data type) for a client tensor.
+// The shape is copied from the first num_dims entries of tensor.shape. The result is
+// allocated with new; presumably the TosaSerializationBasicBlock it is handed to
+// releases it - confirm.
+tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
+{
+    std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
+    return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
+}
+
+// Map a client mode to the Serialization library ResizeMode. Only nearest and
+// bilinear have ResizeMode equivalents; every other value (including min and max)
+// maps to ResizeMode_UNKNOWN rather than silently falling through to BILINEAR.
+tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode)
+{
+    switch (mode)
+    {
+        case tosa_mode_nearest:
+            return tosa::ResizeMode_NEAREST;
+        case tosa_mode_bilinear:
+            return tosa::ResizeMode_BILINEAR;
+        default:
+            return tosa::ResizeMode_UNKNOWN;
+    }
+}
+
+}    // namespace
+
+extern "C"
+{
+    {# For every operator, emit a C entry point tosa_run_<name>() that wraps the
+       client tensors and arguments in a single-operator TOSA basic block, runs it
+       with the reference model and reads each output into client_<output>.data.
+       Any failure is returned to the caller as a tosa_status_t. #}
+    {% for operator in operators: %}
+    tosa_status_t tosa_run_{{ operator.name }} (
+        {%- for arg in operator.arguments: -%}
+            {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+            {% if loop.index < operator.arguments|length %},{% endif %}
+        {%- endfor -%}
+    )
+    {
+        {# Each Serialization argument is declared with the initializer that
+           generate_api.py stored in its "init" entry, then passed to the attribute
+           constructor in the order read from attribute.def. #}
+        // Create operator attributes
+        {% for arg in operator.serializeArgs: %}
+            {%- if arg.SV == "V": -%}
+                const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
+            {%- else: -%}
+                const {{arg.dType}} {{arg.name}}{{arg.init}};
+            {%- endif -%}
+        {%- endfor -%}
+
+        Tosa{{operator.serializeAttType}}Attribute attr
+        {%- if operator.serializeArgs|length > 0 -%}
+        (
+            {%- for arg in operator.serializeArgs: -%}
+                {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+            {%- endfor -%}
+        )
+        {%- endif -%};
+
+        // Create tensors
+        {% for input in operator.inputs: -%}
+            tosa::TosaSerializationTensor* {{input}}  = translate_client_tensor(client_{{input}}, "{{input}}");
+        {%- endfor -%}
+        {% for output in operator.outputs: %}
+            tosa::TosaSerializationTensor* {{output}}  = translate_client_tensor(client_{{output}}, "{{output}}");
+        {%- endfor %}
+
+        // Create operator
+        auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
+                                                      {%- if operator.serializeAttType != "None" -%}
+                                                        tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
+                                                      {%- else -%}
+                                                        tosa::Attribute::Attribute_NONE
+                                                      {%- endif -%},
+                                                      &attr, {
+                                                                {%- for input in operator.inputs: -%}
+                                                                    {{input}}->GetName()
+                                                                    {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+                                                                {%- endfor -%}
+                                                             },
+                                                             {
+                                                                {%- for output in operator.outputs: -%}
+                                                                    {{output}}->GetName()
+                                                                    {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+                                                                {%- endfor -%}
+                                                             });
+
+        {# The block lists its tensors inputs first, then outputs.
+           NOTE(review): op and the tensors are allocated with new and handed to the
+           basic block; confirm that TosaSerializationBasicBlock takes ownership and
+           releases them. #}
+        // Create a tosa single-op basic block
+        tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
+                                                 {
+                                                    {%- for input in operator.inputs: -%}
+                                                        {{input}},
+                                                    {%- endfor -%}
+                                                    {%- for output in operator.outputs: -%}
+                                                        {{output}}
+                                                        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+                                                    {%- endfor -%}
+                                                 },
+                                                 {
+                                                    {%- for input in operator.inputs: -%}
+                                                        {{input}}->GetName()
+                                                        {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+                                                    {%- endfor -%}
+                                                 },
+                                                 {
+                                                    {%- for output in operator.outputs: -%}
+                                                        {{output}}->GetName()
+                                                        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+                                                    {%- endfor -%}
+                                                 });
+
+        // Setup model
+        TosaReference::ModelRunnerImpl runner;
+        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+        {% for input in operator.inputs: -%}
+            TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
+        {%- endfor %}
+
+        // Execute
+        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+        // Extract outputs
+        {% for output in operator.outputs: -%}
+            TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
+        {%- endfor %}
+
+        return tosa_status_valid;
+    }
+    {% endfor %}
+
+}    // extern "C"
\ No newline at end of file
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
new file mode 100644
index 0000000..803b76a
--- /dev/null
+++ b/scripts/operator_api/templates/operators_h.j2
@@ -0,0 +1,74 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//         http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+#ifndef OPERATORS_H_
+#define OPERATORS_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#ifndef __cplusplus
+// C code (before C23) needs <stdbool.h> to use bool, which generate_api.py emits
+// for TOSA bool_t arguments.
+#include <stdbool.h>
+#endif /* __cplusplus */
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+    // The enums and the struct are typedef'd so that their bare names, as used in this
+    // header and in the tosa_run_* prototypes, are valid in C as well as in C++.
+
+    // Note status needs to be aligned with graph_status
+    typedef enum tosa_status_t
+    {
+        tosa_status_valid         = 0,
+        tosa_status_unpredictable = 1,
+        tosa_status_error         = 2
+    } tosa_status_t;
+
+    typedef enum tosa_mode_t
+    {
+        tosa_mode_unknown  = 0,
+        tosa_mode_nearest  = 1,
+        tosa_mode_bilinear = 2,
+        tosa_mode_min      = 3,
+        tosa_mode_max      = 4
+    } tosa_mode_t;
+
+    typedef enum tosa_datatype_t
+    {
+        {% for dataType in dataTypes: -%}
+            {{dataType}} = {{loop.index-1}},
+        {% endfor -%}
+    } tosa_datatype_t;
+
+    // Client view of a tensor: shape points to num_dims dimension sizes. data and size
+    // are handed to the reference model together (size presumably in bytes - confirm).
+    typedef struct tosa_tensor_t
+    {
+        int32_t* shape;
+        int32_t num_dims;
+        tosa_datatype_t data_type;
+        uint8_t* data;
+        size_t size;
+    } tosa_tensor_t;
+
+    {% for operator in operators: %}
+        tosa_status_t tosa_run_{{ operator.name }} (
+            {%- for arg in operator.arguments: -%}
+                {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+                {% if loop.index < operator.arguments|length %},{% endif %}
+            {%- endfor -%});
+    {% endfor %}
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#endif // OPERATORS_H_
\ No newline at end of file