| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Import the modules needed to create a test model and run the TOSA Checker." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import tosa_checker as tc\n", |
| "import tensorflow as tf\n", |
| "import tempfile\n", |
| "import os" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Create a simple model that is compatible with the TOSA specification, then convert it to the TensorFlow Lite format." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# Keras functional model: 16 input features -> Dense(8) with ReLU activation.\n", |
| "inputs = tf.keras.layers.Input(shape=(16,))\n", |
| "outputs = tf.keras.layers.Dense(8, activation=\"relu\")(inputs)\n", |
| "model = tf.keras.models.Model(inputs=[inputs], outputs=outputs)\n", |
| "\n", |
| "# Convert the in-memory Keras model to a TFLite flatbuffer (bytes), which is written to disk below.\n", |
| "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", |
| "tflite_model = converter.convert()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Save the converted model to a temporary `.tflite` file. Note that the TOSA Checker currently accepts only models in this format." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# mkstemp creates the file and returns an already-open OS file descriptor;\n", |
| "# write through that descriptor so the `with` block closes it instead of leaking it.\n", |
| "# The file itself is kept because the checker below reads it by path.\n", |
| "fd, tflite_file = tempfile.mkstemp(suffix='.tflite')\n", |
| "with os.fdopen(fd, \"wb\") as f:\n", |
| "    f.write(tflite_model)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Use the TOSA Checker to check whether this model is compatible with the TOSA specification." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Is model TOSA compatible ? True\n" |
| ] |
| } |
| ], |
| "source": [ |
| "checker = tc.TOSAChecker(model_path=tflite_file)\n", |
| "is_compatible = checker.is_tosa_compatible()\n", |
| "print(f\"Is model TOSA compatible ? {is_compatible}\")" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3.8.0 ('tosa_checker': venv)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.8.0" |
| }, |
| "orig_nbformat": 4 |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 2 |
| } |