| # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. |
| # SPDX-License-Identifier: Apache-2.0 |
| """Sequential trainer.""" |
| # pylint: disable=too-many-arguments |
| # pylint: disable=too-many-locals |
| # pylint: disable=too-many-statements |
| from __future__ import annotations |
| |
| import logging |
| import math |
| import os |
| import tempfile |
| from collections import defaultdict |
| from contextlib import contextmanager |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
| from typing import Callable |
| from typing import cast |
| from typing import Generator as GeneratorType |
| from typing import get_args |
| from typing import Literal |
| |
| import numpy as np |
| import tensorflow as tf |
| from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 |
| from numpy.random import Generator |
| |
| from mlia.nn.rewrite.core.extract import extract |
| from mlia.nn.rewrite.core.extract import ExtractPaths |
| from mlia.nn.rewrite.core.graph_edit.diff import diff_stats |
| from mlia.nn.rewrite.core.graph_edit.join import join_models |
| from mlia.nn.rewrite.core.graph_edit.record import record_model |
| from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count |
| from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read |
| from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel |
| from mlia.nn.tensorflow.config import TFLiteModel |
| from mlia.nn.tensorflow.tflite_convert import convert_to_tflite |
| from mlia.nn.tensorflow.tflite_graph import load_fb |
| from mlia.nn.tensorflow.tflite_graph import save_fb |
| from mlia.utils.logging import log_action |
| |
| |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
| tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) |
| logger = logging.getLogger(__name__) |
| |
| AUGMENTATION_PRESETS = { |
| "none": (None, None), |
| "gaussian": (None, 1.0), |
| "mixup": (1.0, None), |
| "mixout": (1.6, None), |
| "mix_gaussian_large": (2.0, 1.0), |
| "mix_gaussian_small": (1.6, 0.3), |
| } |
| |
| LearningRateSchedule = Literal["cosine", "late", "constant"] |
| LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule) |
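# get_args() extracts the allowed values, i.e. ("cosine", "late", "constant").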
| |
| |
| @dataclass |
| class TrainingParameters: |
| """Define default parameters for the training.""" |
| |
| augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"] |
| batch_size: int = 32 |
| steps: int = 48000 |
| learning_rate: float = 1e-3 |
| learning_rate_schedule: LearningRateSchedule = "cosine" |
| num_procs: int = 1 |
| num_threads: int = 0 |
| show_progress: bool = True |
| checkpoint_at: list | None = None |
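
    # Example (values are illustrative): a shorter run that also saves two
    # intermediate checkpoints:
    #   TrainingParameters(steps=12000, checkpoint_at=[4000, 8000])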
| |
| |
| def train( # pylint: disable=too-many-arguments |
| source_model: str, |
| unmodified_model: Any, |
| output_model: str, |
| input_tfrec: str, |
| rewrite: Callable, |
| is_qat: bool, |
| input_tensors: list, |
| output_tensors: list, |
| train_params: TrainingParameters = TrainingParameters(), |
| ) -> Any: |
| """Extract and train a model, and return the results.""" |
| if unmodified_model: |
| unmodified_model_dir = ( |
| tempfile.TemporaryDirectory() # pylint: disable=consider-using-with |
| ) |
| unmodified_model_dir_path = unmodified_model_dir.name |
| extract( |
| unmodified_model_dir_path, |
| source_model, |
| input_tfrec, |
| input_tensors, |
| output_tensors, |
| dequantize_output=True, |
| ) |
| else: |
| unmodified_model_dir = None |
| unmodified_model_dir_path = None |
| |
| results = [] |
| with tempfile.TemporaryDirectory() as train_dir: |
| extract( |
| train_dir, |
| source_model, |
| input_tfrec, |
| input_tensors, |
| output_tensors, |
| num_procs=train_params.num_procs, |
| num_threads=train_params.num_threads, |
| dequantize_output=True, |
| ) |
| |
| tflite_filenames = train_in_dir( |
| train_dir=train_dir, |
| baseline_dir=unmodified_model_dir_path, |
| output_filename=Path(train_dir, "new.tflite"), |
| rewrite=rewrite, |
| is_qat=is_qat, |
| train_params=train_params, |
| ) |
| |
| for i, filename in enumerate(tflite_filenames): |
| results.append( |
| eval_in_dir( |
| train_dir, |
| filename, |
| train_params.num_procs, |
| train_params.num_threads, |
| ) |
| ) |
| |
| if output_model: |
| if i + 1 < len(tflite_filenames): |
| # Append the same _@STEPS.tflite postfix used by intermediate |
| # checkpoints for all but the last output |
                    postfix = "_@" + filename.split("_@")[-1]
| output_filename = output_model.split(".tflite")[0] + postfix |
| else: |
| output_filename = output_model |
| join_in_dir(train_dir, filename, output_filename) |
| |
        # Measure the output difference between the original and optimized
        # models over the graph section after the rewritten subgraph
| optimized_end_path = Path(train_dir, "optimized_end.tfrec") |
| end_path = Path(train_dir, "end.tfrec") |
| |
| record_model( |
| str(input_tfrec), |
| output_filename, |
| optimized_end_path, |
| num_procs=train_params.num_procs, |
| num_threads=train_params.num_threads, |
| ) |
        mae, nrmse = diff_stats(str(end_path), str(optimized_end_path))
| |
| if unmodified_model_dir: |
| cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() |
| |
| return results, [ |
| mae, |
| nrmse, |
| ] |
| |
| |
| def eval_in_dir( |
| target_dir: str, |
| new_part: str, |
| num_procs: int = 1, |
| num_threads: int = 0, |
| ) -> tuple: |
| """Evaluate a model in a given directory.""" |
| model_input_path = Path(target_dir, "input_orig.tfrec") |
| model_output_path = Path(target_dir, "output_orig.tfrec") |
| model_input = ( |
| model_input_path |
| if model_input_path.exists() |
| else ExtractPaths.tfrec.input(target_dir, False) |
| ) |
| output = ( |
| model_output_path |
| if model_output_path.exists() |
| else ExtractPaths.tfrec.output(target_dir, False) |
| ) |
| |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| predict = Path(tmp_dir, "predict.tfrec") |
| record_model( |
| str(model_input), |
| new_part, |
| str(predict), |
| num_procs=num_procs, |
| num_threads=num_threads, |
| ) |
| mae, nrmse = diff_stats(str(output), str(predict)) |
| |
| return mae, nrmse |
| |
| |
| def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: |
| """Join two models in a given directory.""" |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| new_end = Path(tmp_dir, "new_end.tflite") |
| join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end) |
| join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model) |
| |
| |
def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
    """Return the names of the model's single input and output tensors."""
| assert ( |
| len(model.input_tensors()) == 1 |
| ), f"Can only train replacements with a single input tensor right now, \ |
| found {model.input_tensors()}" |
| |
| assert ( |
| len(model.output_tensors()) == 1 |
| ), f"Can only train replacements with a single output tensor right now, \ |
| found {model.output_tensors()}" |
| |
| input_name = model.input_tensors()[0] |
| output_name = model.output_tensors()[0] |
| return (input_name, output_name) |
| |
| |
| def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None: |
| """Assert that teacher and replaced sub-graph are compatible.""" |
| assert len(teacher.shape_from_name) == len( |
| replace.shape_from_name |
| ), f"Baseline and train models must have the same number of inputs and outputs. \ |
| Teacher: {teacher.shape_from_name}\nTrain dir: {replace.shape_from_name}" |
| |
| assert all( |
| tn == rn and (ts[1:] == rs[1:]).all() |
| for (tn, ts), (rn, rs) in zip( |
| teacher.shape_from_name.items(), replace.shape_from_name.items() |
| ) |
| ), "Baseline and train models must have the same input and output shapes for the \ |
| subgraph being replaced. Teacher: {teacher.shape_from_name}\n \ |
| Train dir: {replace.shape_from_name}" |
| |
| |
| def set_up_data_pipeline( |
| teacher: TFLiteModel, |
| replace: TFLiteModel, |
| train_dir: str, |
| augmentations: tuple[float | None, float | None], |
| steps: int, |
| batch_size: int = 32, |
| ) -> tf.data.Dataset: |
| """Create a data pipeline for training of the replacement model.""" |
| _check_model_compatibility(teacher, replace) |
| |
| input_name, output_name = _get_io_tensors(teacher) |
| |
| model_is_quantized = replace.is_tensor_quantized(name=input_name) |
| |
| input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized) |
| total = numpytf_count(str(input_filename)) |
| dict_inputs = numpytf_read(str(input_filename)) |
| |
| inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0)) |
| |
| steps_per_epoch = math.ceil(total / batch_size) |
| epochs = int(math.ceil(steps / steps_per_epoch)) |
| logger.info( |
| "Training on %d items for %d steps (%d epochs with batch size %d)", |
| total, |
| epochs * steps_per_epoch, |
| epochs, |
| batch_size, |
| ) |
| |
| teacher_dir = Path(teacher.model_path).parent |
| if any(augmentations): |
        # Use the teacher inputs here: the augmentation stage passes the
        # augmented data through the teacher model to regenerate the outputs.
| teacher_outputs = numpytf_read( |
| str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized)) |
| ).map(lambda d: tf.squeeze(d[input_name], axis=0)) |
| else: |
| teacher_outputs = numpytf_read( |
| str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized)) |
| ).map(lambda d: tf.squeeze(d[output_name], axis=0)) |
| |
| dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) |
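    # Cache only if the data is iterated more than once; shuffle over the full
    # set, then repeat indefinitely so fit() can draw any number of batches.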
| if epochs > 1: |
| dataset = dataset.cache() |
| dataset = dataset.shuffle(total).repeat().batch(batch_size) |
| |
| if any(augmentations): |
| augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations) |
| |
| def get_augment_results( |
| train: Any, teach: Any # pylint: disable=redefined-outer-name |
| ) -> tuple: |
| """Return results of train and teach based on augmentations.""" |
| augmented_train = augment_train({input_name: train})[input_name] |
| # If augmentation of the input is enabled, we need to re-generate |
| # the reference output by running the augmented data through the |
| # teacher model. |
| if model_is_quantized: |
| # If the input model is quantized we have to additionally |
| # - quantize the augmented data before running it through the |
| # (quantized) teacher model |
| # - de-quantize the output for the training. |
| augmented_teach = teacher.dequantize_outputs( |
| teacher( |
| teacher.quantize_inputs(augment_teacher({input_name: teach})) |
| ) |
| )[output_name] |
| else: |
| augmented_teach = teacher(augment_teacher({input_name: teach}))[ |
| output_name |
| ] |
| return (augmented_train, augmented_teach) |
| |
        dataset = dataset.map(
            lambda train_sample, teach_sample: tf.py_function(
                get_augment_results,
                inp=[train_sample, teach_sample],
                Tout=[tf.float32, tf.float32],
            )
        )
| |
    # Restore the dataset's static shapes, which default to unknown and are
    # lost during augmentation with tf.py_function.
| shape_in, shape_out = ( |
| teacher.shape_from_name[name].tolist() for name in (input_name, output_name) |
| ) |
| for shape in (shape_in, shape_out): |
| shape[0] = None # set dynamic batch size |
| |
| def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]: |
| input_.set_shape(shape_in) |
| output.set_shape(shape_out) |
| return input_, output |
| |
| dataset = dataset.map(restore_shapes) |
| dataset = dataset.prefetch(tf.data.AUTOTUNE) |
| return dataset |
| |
| |
| def train_in_dir( |
| train_dir: str, |
| baseline_dir: Any, |
| output_filename: Path, |
| rewrite: Callable, |
| is_qat: bool, |
| train_params: TrainingParameters = TrainingParameters(), |
| ) -> list[str]: |
| """Train a replacement for replace.tflite using the input.tfrec \ |
| and output.tfrec in train_dir. |
| |
| If baseline_dir is provided, train the replacement to match baseline |
| outputs for train_dir inputs. Result saved as new.tflite in train_dir. |
| """ |
| teacher_dir = baseline_dir if baseline_dir else train_dir |
| teacher = ParallelTFLiteModel( |
| ExtractPaths.tflite.replace(teacher_dir), |
| train_params.num_procs, |
| train_params.num_threads, |
| batch_size=train_params.batch_size, |
| ) |
| replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) |
| |
| input_name, output_name = _get_io_tensors(teacher) |
| |
| model_is_quantized = replace.is_tensor_quantized(name=input_name) |
| |
| if model_is_quantized: |
| replace.check_datatypes(np.int8) |
| |
| dataset = set_up_data_pipeline( |
| teacher, |
| replace, |
| train_dir, |
| augmentations=train_params.augmentations, |
| steps=train_params.steps, |
| batch_size=train_params.batch_size, |
| ) |
| |
| input_shape = teacher.shape_from_name[input_name][1:] |
| |
| output_shape = teacher.shape_from_name[output_name][1:] |
| |
| optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) |
| loss_fn = keras.losses.MeanSquaredError() |
| |
| model = create_model( |
| rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized |
| ) |
| |
    model.summary(print_fn=logger.info)
| |
| steps_so_far = 0 |
| |
| def cosine_decay( |
| epoch_step: int, logs: Any # pylint: disable=unused-argument |
| ) -> None: |
| """Cosine decay from learning rate at start of the run to zero at the end.""" |
| current_step = epoch_step + steps_so_far |
| cd_learning_rate = ( |
| train_params.learning_rate |
| * (math.cos(math.pi * current_step / train_params.steps) + 1) |
| / 2.0 |
| ) |
| keras.backend.set_value(optimizer.learning_rate, cd_learning_rate) |
| |
| def late_decay( |
| epoch_step: int, logs: Any # pylint: disable=unused-argument |
| ) -> None: |
| """Constant until the last 20% of the run, then linear decay to zero.""" |
| current_step = epoch_step + steps_so_far |
| steps_remaining = train_params.steps - current_step |
| decay_length = train_params.steps // 5 |
| decay_fraction = min(steps_remaining, decay_length) / decay_length |
| ld_learning_rate = train_params.learning_rate * decay_fraction |
| keras.backend.set_value(optimizer.learning_rate, ld_learning_rate) |
| |
| assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, ( |
| f'Learning rate schedule "{train_params.learning_rate_schedule}" ' |
| f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}." |
| ) |
| if train_params.learning_rate_schedule == "cosine": |
| callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)] |
| elif train_params.learning_rate_schedule == "late": |
| callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=late_decay)] |
| elif train_params.learning_rate_schedule == "constant": |
| callbacks = [] |
| |
| callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] |
| output_filenames: list = [] |
| checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ |
| train_params.steps |
| ] |
| model, output_filenames = model_fit( |
| model, |
| train_params, |
| checkpoints, |
| optimizer, |
| dataset, |
| callbacks, |
| output_filename, |
| rewrite, |
| replace, |
| input_name, |
| output_name, |
| model_is_quantized, |
| output_filenames, |
| input_shape, |
| output_shape, |
| loss_fn, |
| post_process=True, |
| ) |
| |
| # Placeholder for now, will be parametrized later (MLIA-1114) |
| # rewrite.check_optimization( # type: ignore[attr-defined] |
| # model, number_of_clusters=32 |
| # ) |
| if model_is_quantized and is_qat: |
| model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] |
| checkpoints = ( |
| train_params.checkpoint_at if train_params.checkpoint_at else [] |
| ) + [train_params.steps] |
| output_filenames = [] |
| |
| if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined] |
| rewrite.training_callbacks() # type: ignore[attr-defined] |
| ).issubset(callbacks): |
| callbacks.pop(-1) |
| |
| optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) |
| model = model_compile(model, optimizer, loss_fn) |
| |
| model, output_filenames = model_fit( |
| model, |
| train_params, |
| checkpoints, |
| optimizer, |
| dataset, |
| callbacks, |
| output_filename, |
| rewrite, |
| replace, |
| input_name, |
| output_name, |
| model_is_quantized, |
| output_filenames, |
| input_shape, |
| output_shape, |
| loss_fn, |
| ) |
| # Placeholder for now, will be parametrized later (MLIA-1114) |
| # rewrite.check_optimization( # type: ignore[attr-defined] |
| # model, number_of_clusters=32 |
| # ) |
| |
| teacher.close() |
| return output_filenames |
| |
| |
| def model_compile( |
| model: keras.Model, |
| optimizer: keras.optimizers.Nadam, |
| loss_fn: keras.losses.Loss, |
| ) -> keras.Model: |
| """Compiles a tflite model.""" |
| model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) |
| return model |
| |
| |
| def create_model( # pylint: disable=too-many-arguments |
| rewrite: Callable, |
    input_shape: Any,
    output_shape: Any,
| optimizer: Callable, |
| loss_fn: Callable, |
| model_is_quantized: bool, |
    model_to_load_from: keras.Model | None = None,
| ) -> keras.Model: |
| """Create a model, optionally from another.""" |
| model = rewrite(input_shape, output_shape) |
| if model_is_quantized: |
| model = rewrite.quantize(model) # type: ignore[attr-defined] |
| model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) |
| if model_to_load_from: |
| model.set_weights(model_to_load_from.get_weights()) |
| return model |
| |
| |
| def model_fit( # pylint: disable=too-many-arguments |
| model: keras.Model, |
| train_params: TrainingParameters, |
| checkpoints: list, |
    optimizer: keras.optimizers.Nadam,
| dataset: tf.data.Dataset, |
| callbacks: list, |
| output_filename: Path, |
| rewrite: Callable, |
| replace: TFLiteModel, |
| input_name: str, |
| output_name: str, |
| model_is_quantized: bool, |
| output_filenames: list, |
    input_shape: Any,
    output_shape: Any,
| loss_fn: Callable, |
| post_process: bool = False, |
) -> tuple[keras.Model, list]:
    """Train a Keras model, saving TFLite checkpoints along the way."""
| steps_so_far = 0 |
| while steps_so_far < train_params.steps: |
| steps_to_train = checkpoints.pop(0) - steps_so_far |
| lr_start = optimizer.learning_rate.numpy() |
| model.fit( |
| dataset, |
| epochs=1, |
| steps_per_epoch=steps_to_train, |
| callbacks=callbacks, |
| verbose=train_params.show_progress, |
| ) |
| steps_so_far += steps_to_train |
| logger.info( |
| "lr decayed from %f to %f over %d steps", |
| lr_start, |
| optimizer.learning_rate.numpy(), |
| steps_to_train, |
| ) |
| |
| if steps_so_far < train_params.steps: |
| filename = Path(output_filename).stem |
| filename_dir = Path(output_filename).parent.as_posix() |
| ext = Path(output_filename).suffix |
            checkpoint_filename = (
                filename_dir + "/" + filename + f"_@{steps_so_far}" + ext
            )
            # Post-processing below strips the clustering/pruning wrapper
            # layers, so copy the model before saving to let training continue.
| if post_process: |
| model_to_save = create_model( |
| rewrite, |
| input_shape, |
| output_shape, |
| optimizer, |
| loss_fn, |
| model_is_quantized, |
| model_to_load_from=model, |
| ) |
| else: |
| model_to_save = model |
| else: |
| checkpoint_filename = str(output_filename) |
| model_to_save = model |
| |
| with log_action( |
| f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" |
| ): |
| if post_process: |
| model_to_save = rewrite.post_process( # type: ignore[attr-defined] |
| model_to_save |
| ) |
| save_as_tflite( |
| model_to_save, |
| checkpoint_filename, |
| input_name, |
| replace.shape_from_name[input_name], |
| output_name, |
| replace.shape_from_name[output_name], |
| model_is_quantized, |
| ) |
| output_filenames.append(checkpoint_filename) |
| |
| return model_to_save, output_filenames |
| |
| |
| def save_as_tflite( |
| keras_model: keras.Model, |
| filename: str, |
| input_name: str, |
| input_shape: list, |
| output_name: str, |
| output_shape: list, |
| model_is_quantized: bool = False, |
| ) -> None: |
| """Save Keras model as TFLite file.""" |
| |
| @contextmanager |
| def fixed_input(keras_model: keras.Model, tmp_shape: list) -> GeneratorType: |
| """Fix the input shape of the Keras model temporarily. |
| |
| This avoids artifacts during conversion to TensorFlow Lite. |
| """ |
| orig_shape = keras_model.input.shape |
| keras_model.input.set_shape(tf.TensorShape(tmp_shape)) |
| try: |
| yield keras_model |
| finally: |
| # Restore original shape to not interfere with further training |
| keras_model.input.set_shape(orig_shape) |
| |
| with fixed_input(keras_model, input_shape) as fixed_model: |
| convert_to_tflite(fixed_model, model_is_quantized, Path(filename)) |
| |
| # Now fix the shapes and names to match those we expect |
| flatbuffer = load_fb(filename) |
| i = flatbuffer.subgraphs[0].inputs[0] |
| flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32) |
| flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8") |
| output = flatbuffer.subgraphs[0].outputs[0] |
| flatbuffer.subgraphs[0].tensors[output].shape = np.array( |
| output_shape, dtype=np.int32 |
| ) |
| flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8") |
| save_fb(flatbuffer, filename) |
| |
| |
| def augment_fn_twins( |
| inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None] |
| ) -> Any: |
| """Return a pair of twinned augmentation functions with the same sequence \ |
| of random numbers.""" |
| seed = np.random.randint(2**32 - 1) |
| rng1 = np.random.default_rng(seed) |
| rng2 = np.random.default_rng(seed) |
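    # Two generators seeded identically produce the same random sequence, so
    # e.g. rng1.permutation(4) equals rng2.permutation(4) element-wise; this
    # keeps the mixup pairings and noise draws aligned between the twins.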
| return augment_fn(inputs, augmentations, rng1), augment_fn( |
| inputs, augmentations, rng2 |
| ) |
| |
| |
| def augment_fn( |
| inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator |
| ) -> Any: |
| """Augmentation module.""" |
| assert len(augmentations) == 2, ( |
| f"Unexpected number of augmentation parameters: {len(augmentations)} " |
| "(must be 2)" |
| ) |
| |
| mixup_strength, gaussian_strength = augmentations |
| |
| augments = [] |
| |
| if mixup_strength: |
| mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2) |
| |
| def mixup_augment(augment_dict: dict) -> dict: |
| return { |
| k: mixup(rng, v.numpy(), mixup_range) for k, v in augment_dict.items() |
| } |
| |
| augments.append(mixup_augment) |
| |
| if gaussian_strength: |
| values = defaultdict(list) |
| for numpy_dict in inputs.as_numpy_iterator(): |
| for key, value in numpy_dict.items(): |
| values[key].append(value) |
| noise_scale = { |
| k: np.std(v, axis=0).astype(np.float32) for k, v in values.items() |
| } |
| |
| def gaussian_strength_augment(augment_dict: dict) -> dict: |
| return { |
| k: v |
| + rng.standard_normal(v.shape).astype(np.float32) |
| * gaussian_strength |
| * noise_scale[k] |
| for k, v in augment_dict.items() |
| } |
| |
| augments.append(gaussian_strength_augment) |
| |
| if len(augments) == 0: |
| return lambda x: x |
| if len(augments) == 1: |
| return augments[0] |
| if len(augments) == 2: |
| return lambda x: augments[1](augments[0](x)) |
| |
    raise RuntimeError(
        f"Unexpected number of augmentation functions (must be <=2): {len(augments)}"
    )
| |
| |
| def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any: |
| """Each tensor in the batch becomes a linear combination of it \ |
| and one other tensor.""" |
| batch_a = batch |
| batch_b = np.array(batch) |
| rng.shuffle(batch_b) # randomly pair up tensors in the batch |
| # random mixing coefficient for each pair |
| beta = rng.uniform( |
| low=beta_range[0], high=beta_range[1], size=batch.shape[0] |
| ).astype(np.float32) |
| return (batch_a.T * beta).T + ( |
| batch_b.T * (1.0 - beta) |
| ).T # return linear combinations |
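
# Illustrative check (hypothetical values): with batch=[[1., 1.], [3., 3.]],
# beta=[0.25, 0.5] and a shuffle that swaps the two rows, the result is
#   [[0.25 * 1 + 0.75 * 3, ...], [0.5 * 3 + 0.5 * 1, ...]]
#   = [[2.5, 2.5], [2.0, 2.0]]
# (beta broadcasts over the batch axis via the transpose trick).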