blob: 2ce94b830752d0cc3abe9be9b910b18b837df9a7 [file] [log] [blame]
Alex Tawsedaba3cf2023-09-29 15:55:38 +01001# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00002# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""
16Utility functions related to data and models that are common to all the model conditioning examples.
17"""
18import tensorflow as tf
19import numpy as np
20
21
def get_data():
    """Downloads and returns the pre-processed MNIST data and labels for training and testing.

    The images get an explicit trailing channel axis, because convolution
    layers expect 4-D input. They are then scaled by 1/255 into ``float32``.
    Labels are returned exactly as loaded by Keras.

    Returns:
        Tuple of: (train data, train labels, test data, test labels)
    """

    # To save time we use the MNIST dataset for this example.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Convolution operations require data to have 4 dimensions.
    # Cast to float32 *before* dividing by 255: dividing the integer images by
    # a Python float first would materialise a float64 copy of the whole
    # dataset (double the memory) only to narrow it back to float32 afterwards.
    x_train = x_train[..., np.newaxis].astype(np.float32) / 255.0
    x_test = x_test[..., np.newaxis].astype(np.float32) / 255.0

    return x_train, y_train, x_test, y_test
38
39
def create_model():
    """Create and returns a simple Keras model for training MNIST.

    We will use a simple convolutional neural network for this example,
    but the model optimization methods employed should be compatible with a
    wide variety of CNN architectures such as Mobilenet and Inception etc.

    Architecture (input shape 28x28x1):
        - three 3x3 Conv2D layers with 32 filters, 'same' padding and ReLU,
          with 2D max-pooling after the second and third convolutions;
        - Flatten followed by a 10-unit Dense layer with softmax output.

    Returns:
        Uncompiled Keras model.
    """

    keras_model = tf.keras.models.Sequential([
        # Declare the input explicitly. Passing `input_shape=` to the first
        # layer is deprecated in Keras 3 (it triggers a warning), whereas an
        # `Input` object works in both tf.keras 2.x and Keras 3.
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
        tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)
    ])

    return keras_model