blob: f0e473a0b870eaa59badeb274c67a00d24d3ff00 [file] [log] [blame]
/*
SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
*/
#ifndef TOSA_CHECKER_H_
#define TOSA_CHECKER_H_
#include <cstdint>
#include <map>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/OwningOpRef.h"
namespace tosa_checker {

/// Loads a model from `model_path` and answers queries about its TOSA
/// compatibility and its MLIR / TOSA-dialect representations.
///
/// NOTE(review): the behavior described below is inferred from the
/// declarations in this header; the definitions live in another file —
/// confirm against them.
class TOSAChecker {
 public:
  /// Per-operation report returned by the query methods below.
  struct Operator {
    // All string/map arguments are sink parameters: taken by value and moved
    // into the members, so callers passing temporaries pay no extra copy.
    Operator(std::string name, std::string location,
             std::map<std::string, std::string> attributes,
             bool is_tosa_compatible, std::string mlir_representation)
        : name(std::move(name)),
          location(std::move(location)),
          attributes(std::move(attributes)),
          is_tosa_compatible(is_tosa_compatible),
          mlir_representation(std::move(mlir_representation)) {}

    std::string name;      // Operation name.
    std::string location;  // Source location of the operation, as text.
    std::map<std::string, std::string> attributes;  // Attribute name -> text.
    bool is_tosa_compatible;          // Whether the op was judged compatible.
    std::string mlir_representation;  // Textual MLIR form of the operation.
  };

  /// Builds the checker for the model file at `model_path`.
  ///
  /// `explicit` so a bare std::string is never silently converted into a
  /// TOSAChecker (e.g. when passed to a function expecting one).
  /// presumably parses the file via TFLiteFileToMLIR — TODO confirm.
  explicit TOSAChecker(const std::string& model_path);

  /// Whether the whole model is considered TOSA-compatible.
  bool IsTOSACompatible();

  /// One Operator entry per model op, each flagged with its compatibility.
  /// `elide_large_attrs` presumably shortens attributes larger than
  /// ELIDE_LARGE_ATTRS_LIMIT in the returned text — TODO confirm.
  std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs);

  /// The TOSA operations used by the legalized model.
  std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs);

  /// Textual MLIR of the model as loaded.
  std::string GetMLIRModelRepresentation(bool elide_large_attrs);

  /// Textual MLIR of the model after legalization to TOSA.
  std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs);

 private:
  static bool IsTOSACompatibleOp(mlir::Operation& op);

  template <typename T>
  static std::string GetMLIRRepresentation(T&& op);

  template <typename T>
  static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs);

  static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model);

  static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible,
                             bool elide_large_attrs);

  static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR(
      const std::string& model_path, mlir::MLIRContext* context);

  static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module);

  static std::map<std::string, std::string> GetAttributes(
      mlir::Operation& op, bool elide_large_attrs);

  // Size threshold associated with the `elide_large_attrs` flag.
  static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16;

  // Declaration order matters: members are destroyed in reverse order, so
  // both modules are released before the context they presumably live in
  // (TFLiteFileToMLIR takes an MLIRContext*). Keep m_context first.
  mlir::MLIRContext m_context;
  mlir::OwningOpRef<mlir::ModuleOp> m_model;
  mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model;
};

}  // namespace tosa_checker
/// Writes a textual form of `op` to `out` (format is defined by the
/// implementation elsewhere); presumably returns `out` for chaining.
///
/// NOTE(review): this overload lives in the global namespace, not in
/// `tosa_checker`, so argument-dependent lookup does not find it. A call made
/// from inside a namespace that declares its own operator<< may need
/// `::operator<<(out, op)`. Relocating it requires moving the definition too.
std::ostream &operator<<(std::ostream &out,
                         const tosa_checker::TOSAChecker::Operator &op);
#endif