blob: 69d97ccf645b47b736fd57d4eb414ba18bdebd89 [file] [log] [blame]
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
13
namespace asr
{
/**
 * @brief Class used to decode the output of the ASR inference into text.
 *
 * The flat output vector is split into consecutive rows of a fixed length. For each row, the
 * column index holding the largest value (argmax) is looked up in m_labels. The resulting
 * character sequence is then passed through FilterCharacters().
 */
class Decoder
{
public:
    /// Maps a column index within an output row to the label text for that index.
    std::map<int, std::string> m_labels;

    /**
     * @brief Constructor
     * @param[in] labels - map of labels to be used for decoding to text
     *                     (presumably stored in m_labels; the definition is not in this header).
     */
    Decoder(std::map<int, std::string>& labels);

    /**
     * @brief Function to decode the output into a text string.
     *
     * contextToProcess is read as consecutive rows of rowLength values. The index of the
     * largest value in each row (the first one on ties) selects a label from m_labels. The
     * first character of that label is appended to the intermediate text, which is then
     * passed to FilterCharacters().
     *
     * Values are compared in their native type T, so floating-point and wide integer outputs
     * are ranked correctly; there is no narrowing conversion before the comparison.
     *
     * A trailing partial row (when the size is not a multiple of rowLength) is ignored.
     *
     * @tparam T          element type of the output; must be comparable with operator<.
     * @param[in] contextToProcess - the output vector to decode.
     * @param[in] rowLength        - number of values per row (defaults to 29, the value this
     *                               decoder originally hard-coded).
     * @return the filtered decoded text.
     * @throws std::invalid_argument if rowLength is 0.
     * @throws std::out_of_range if a selected index has no entry in m_labels.
     */
    template<typename T>
    std::string DecodeOutput(const std::vector<T>& contextToProcess, std::size_t rowLength = 29)
    {
        if (rowLength == 0)
        {
            throw std::invalid_argument("Decoder::DecodeOutput: rowLength must be non-zero");
        }

        // Integer division deliberately drops an incomplete final row.
        const std::size_t rowCount = contextToProcess.size() / rowLength;
        const auto step = static_cast<std::ptrdiff_t>(rowLength);

        std::vector<char> unfilteredText;
        unfilteredText.reserve(rowCount);

        auto rowBegin = contextToProcess.begin();
        for (std::size_t row = 0; row < rowCount; ++row)
        {
            const auto rowEnd = std::next(rowBegin, step);
            // Argmax over the row in the native element type (no lossy cast).
            const auto maxIndex =
                static_cast<int>(std::distance(rowBegin, std::max_element(rowBegin, rowEnd)));
            // operator[] on std::string yields '\0' for an empty label rather than UB.
            unfilteredText.push_back(m_labels.at(maxIndex)[0]);
            rowBegin = rowEnd;
        }

        return FilterCharacters(unfilteredText);
    }

    /**
     * @brief Function to filter out unwanted characters.
     *
     * Defined out of line; which characters are removed is not visible from this header.
     * @param[in] unfiltered - the unfiltered output to be processed.
     * @return the filtered text.
     */
    std::string FilterCharacters(std::vector<char>& unfiltered);
};
} // namespace asr