blob: e8febeed45db76653c37ccaf14df2377024239e5 [file] [log] [blame]
Richard Burton00553462021-11-10 16:27:14 +00001/*
Richard Burtonf32a86a2022-11-15 11:46:11 +00002 * SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
Richard Burton00553462021-11-10 16:27:14 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "RNNoiseModel.hpp"
alexander31ae9f02022-02-10 16:15:54 +000018#include "log_macros.h"
Richard Burton00553462021-11-10 16:27:14 +000019
20const tflite::MicroOpResolver& arm::app::RNNoiseModel::GetOpResolver()
21{
22 return this->m_opResolver;
23}
24
25bool arm::app::RNNoiseModel::EnlistOperations()
26{
27 this->m_opResolver.AddUnpack();
28 this->m_opResolver.AddFullyConnected();
29 this->m_opResolver.AddSplit();
30 this->m_opResolver.AddSplitV();
31 this->m_opResolver.AddAdd();
32 this->m_opResolver.AddLogistic();
33 this->m_opResolver.AddMul();
34 this->m_opResolver.AddSub();
35 this->m_opResolver.AddTanh();
36 this->m_opResolver.AddPack();
37 this->m_opResolver.AddReshape();
38 this->m_opResolver.AddQuantize();
39 this->m_opResolver.AddConcatenation();
40 this->m_opResolver.AddRelu();
41
Richard Burton00553462021-11-10 16:27:14 +000042 if (kTfLiteOk == this->m_opResolver.AddEthosU()) {
43 info("Added %s support to op resolver\n",
44 tflite::GetString_ETHOSU());
45 } else {
46 printf_err("Failed to add Arm NPU support to op resolver.");
47 return false;
48 }
Richard Burton00553462021-11-10 16:27:14 +000049 return true;
50}
51
Richard Burton00553462021-11-10 16:27:14 +000052bool arm::app::RNNoiseModel::RunInference()
53{
54 return Model::RunInference();
55}
56
57void arm::app::RNNoiseModel::ResetGruState()
58{
59 for (auto& stateMapping: this->m_gruStateMap) {
60 TfLiteTensor* inputGruStateTensor = this->GetInputTensor(stateMapping.second);
61 auto* inputGruState = tflite::GetTensorData<int8_t>(inputGruStateTensor);
62 /* Initial value of states is 0, but this is affected by quantization zero point. */
63 auto quantParams = arm::app::GetTensorQuantParams(inputGruStateTensor);
64 memset(inputGruState, quantParams.offset, inputGruStateTensor->bytes);
65 }
66}
67
68bool arm::app::RNNoiseModel::CopyGruStates()
69{
70 std::vector<std::pair<size_t, std::vector<int8_t>>> tempOutGruStates;
71 /* Saving output states before copying them to input states to avoid output states modification in the tensor.
72 * tflu shares input and output tensors memory, thus writing to input tensor can change output tensor values. */
73 for (auto& stateMapping: this->m_gruStateMap) {
74 TfLiteTensor* outputGruStateTensor = this->GetOutputTensor(stateMapping.first);
75 std::vector<int8_t> tempOutGruState(outputGruStateTensor->bytes);
76 auto* outGruState = tflite::GetTensorData<int8_t>(outputGruStateTensor);
77 memcpy(tempOutGruState.data(), outGruState, outputGruStateTensor->bytes);
78 /* Index of the input tensor and the data to copy. */
79 tempOutGruStates.emplace_back(stateMapping.second, std::move(tempOutGruState));
80 }
81 /* Updating input GRU states with saved GRU output states. */
82 for (auto& stateMapping: tempOutGruStates) {
83 auto outputGruStateTensorData = stateMapping.second;
84 TfLiteTensor* inputGruStateTensor = this->GetInputTensor(stateMapping.first);
85 if (outputGruStateTensorData.size() != inputGruStateTensor->bytes) {
86 printf_err("Unexpected number of bytes for GRU state mapping. Input = %zuz, output = %zuz.\n",
87 inputGruStateTensor->bytes,
88 outputGruStateTensorData.size());
89 return false;
90 }
91 auto* inputGruState = tflite::GetTensorData<int8_t>(inputGruStateTensor);
92 auto* outGruState = outputGruStateTensorData.data();
93 memcpy(inputGruState, outGruState, inputGruStateTensor->bytes);
94 }
95 return true;
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010096}