blob: cf09fdefb856afefb2d219d6dfb3a4b68f7a4869 [file] [log] [blame]
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +00001# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""Contains helper functions that can be used across the example apps."""
5
6import os
7import errno
8from pathlib import Path
9
10import numpy as np
11
12
def dict_labels(labels_file_path: str, include_rgb: bool = False) -> dict:
    """Creates a dictionary of labels from the input labels file.

    Each line of the file becomes one entry, keyed by its zero-based line
    index, with the trailing newline removed. Blank lines are kept as empty
    strings so that the indices stay aligned with the file's line order.

    Args:
        labels_file_path: Path to file containing labels to map model outputs.
            The file is decoded as UTF-8.
        include_rgb: Adds randomly generated RGB values to the values of the
            dictionary. Used for plotting bounding boxes of different colours.
            When set, each value is a ``(label, (r, g, b))`` tuple whose
            channels are floats in the range [0, 255).

    Returns:
        Dictionary with classification indices for keys and labels for values.

    Raises:
        FileNotFoundError:
            Provided `labels_file_path` does not exist or is not a regular file.
    """
    labels_file = Path(labels_file_path)
    if not labels_file.is_file():
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path
        )

    labels = {}
    # Decode explicitly as UTF-8 so non-ASCII label names are read the same
    # way on every platform instead of depending on the locale's default.
    with labels_file.open("r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            label = line.strip("\n")
            if include_rgb:
                labels[idx] = label, tuple(np.random.random(size=3) * 255)
            else:
                labels[idx] = label
    return labels