blob: 8fe1f26382f2768d1247998c1721e8e4943d47d2 [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the modules needed to create a test model and run the TOSA Checker."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tosa_checker as tc\n",
"import tensorflow as tf\n",
"import tempfile\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a simple Keras model that is compatible with the TOSA specification, then convert it to the TensorFlow Lite format."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n"
]
}
],
"source": [
 "# Name the Keras input tensor `model_input` rather than `input`: assigning to\n",
 "# `input` would shadow Python's built-in `input()` for the rest of the session.\n",
 "model_input = tf.keras.layers.Input(shape=(16,))\n",
 "x = tf.keras.layers.Dense(8, activation=\"relu\")(model_input)\n",
 "model = tf.keras.models.Model(inputs=[model_input], outputs=x)\n",
 "\n",
 "# Convert the Keras model to a serialized TensorFlow Lite model\n",
 "# (written to disk in binary mode by the next code cell).\n",
 "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
 "tflite_model = converter.convert()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Save the converted model to a temporary `.tflite` file. Note that the TOSA Checker currently only accepts models in this format."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
 "# `mkstemp` creates the file and returns an already-open OS-level file\n",
 "# descriptor together with the path. Write through that descriptor via\n",
 "# `os.fdopen` so it is closed afterwards instead of being leaked.\n",
 "fd, tflite_file = tempfile.mkstemp(suffix=\".tflite\")\n",
 "with os.fdopen(fd, \"wb\") as f:\n",
 "    f.write(tflite_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the TOSA Checker to check whether this model is TOSA compatible."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Is model TOSA compatible ? True\n"
]
}
],
"source": [
 "# Load the saved .tflite file into the checker and ask whether the model\n",
 "# is TOSA compatible.\n",
 "checker = tc.TOSAChecker(model_path=tflite_file)\n",
 "is_compatible = checker.is_tosa_compatible()\n",
 "print(f\"Is model TOSA compatible ? {is_compatible}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.0 ('tosa_checker': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}