blob: c799817baf109dd433b3ada87c773aa6277561a2 [file] [log] [blame]
/*
SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
*/
#include "tosa_checker.h"
#include <optional>
#include <sstream>
#include <string>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
PYBIND11_MODULE(_tosa_checker_wrapper, m) {
  /**
   * tosa_checker::TOSAChecker
   *
   * Exposed to Python as `TOSAChecker`. Each method binding below is a thin
   * lambda that forwards to the corresponding C++ member function; pybind11
   * converts the returned C++ values to Python objects (standard containers
   * are handled by the converters in pybind11/stl.h).
   */
  pybind11::class_<tosa_checker::TOSAChecker> tosa_checker_class(m,
                                                                 "TOSAChecker");

  // Python: TOSAChecker(model_path: str)
  tosa_checker_class.def(pybind11::init<const std::string&>(),
                         pybind11::arg("model_path"));

  tosa_checker_class.def(
      "is_tosa_compatible",
      [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); },
      "Check if a model is compatible with the TOSA specification");

  // The methods below use a leading underscore (private by Python
  // convention). NOTE(review): presumably they are wrapped by a public Python
  // layer; confirm. Each one forwards `elide_large_elements_attrs`
  // (default: false) unchanged to the C++ method.
  tosa_checker_class.def(
      "_get_tosa_compatibility_for_ops",
      [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get all the operators of the model with a TOSA compatibility flag for "
      "each operator");

  tosa_checker_class.def(
      "_get_used_tosa_ops",
      [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetUsedTOSAOps(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get the TOSA operators used by the model after its TOSA legalization");

  tosa_checker_class.def(
      "_get_mlir_model_representation",
      [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetMLIRModelRepresentation(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get the MLIR representation of the model");

  tosa_checker_class.def(
      "_get_mlir_tosa_model_representation",
      [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get the MLIR representation of the TOSA legalized model");

  /**
   * tosa_checker::TOSAChecker::Operator
   *
   * Registered as the nested Python class `TOSAChecker._Operator`. All data
   * members are exposed read-only; Python code cannot mutate them.
   */
  pybind11::class_<tosa_checker::TOSAChecker::Operator>(tosa_checker_class,
                                                        "_Operator")
      .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name)
      .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location)
      .def_readonly("attributes",
                    &tosa_checker::TOSAChecker::Operator::attributes)
      .def_readonly("is_tosa_compatible",
                    &tosa_checker::TOSAChecker::Operator::is_tosa_compatible)
      .def_readonly("mlir_representation",
                    &tosa_checker::TOSAChecker::Operator::mlir_representation)
      // repr() reuses the C++ stream-insertion operator for Operator
      // (presumably declared in tosa_checker.h), so Python and C++ print the
      // same text. A write-only ostringstream is all that is needed here.
      .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) {
        std::ostringstream stream;
        stream << o;
        return stream.str();
      });
}