blob: a0432846f5ec05aa2fef24932709c8885ba3a9c2 [file] [log] [blame]
Tracy Narine10403ec2023-11-28 11:55:08 +00001//
2// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include <armnn/Exceptions.hpp>

#include <cmath>
#include <cstdint>
#include <memory>
7
8#pragma once
9
10inline void CreateRescaleTosaOperator(const std::string& inputName,
11 const std::string& outputName,
12 DType output_type,
13 const std::vector<int32_t>& shape,
14 int32_t scale_multiplier,
15 int32_t scale_shift,
16 int32_t input_zp,
17 int32_t output_zp,
18 bool double_round,
19 bool scale32,
20 TosaSerializationOperator** op,
21 TosaSerializationTensor** tensor)
22{
23 if (!op)
24 {
25 throw armnn::Exception("CreateRescaleTosaOperator: nullptr op");
26 }
27
28 std::vector<int32_t> multipliers{scale_multiplier};
29 std::vector<int32_t> shifts{scale_shift};
30 TosaRescaleAttribute attribute(input_zp,
31 output_zp,
32 multipliers,
33 shifts,
34 scale32,
35 double_round,
36 false);
37
38 // op
39 *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
40 if (!(*op))
41 {
42 throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
43 }
44 if (tensor != nullptr)
45 {
46 // tensor
47 *tensor = new TosaSerializationTensor(outputName, shape, output_type, {});
48 if (! (*tensor))
49 {
50 throw armnn::Exception("CreateRescaleTosaOperator: failed to created tensor");
51 }
52 }
53}
54
55inline void CreateRescaleTosaOperator(const std::string& inputName,
56 const std::string& outputName,
57 DType output_type,
58 const std::vector<int32_t>& shape,
59 double scale,
60 int32_t input_zp,
61 int32_t output_zp,
62 bool double_round,
63 bool scale32,
64 TosaSerializationOperator** op,
65 TosaSerializationTensor** tensor)
66{
67 // The code that follows is based on the behaviour specified in
68 // https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
69
70 auto GetScaleParams = [](double scale, double& m, int32_t& n)
71 {
72 m = 0;
73 n = 0;
74
75 double lastErr = 1e06;
76
77 const int32_t numExponents = 62;
78 const double start = 1.0;
79 const double end = 2.0;
80
81 // Slow iterative approach but running in Reference only
82 for (int32_t i = 0; i < numExponents; ++i)
83 {
84 double exp = 1.0 / (1 << i);
85 double currentM = scale / exp; // Find current m given value = currentM * exp
86 if ((currentM >= start) && (currentM < end))
87 {
88 double value = currentM * exp;
89 double err = std::abs(scale - value);
90 if (err < lastErr)
91 {
92 // Take the m, n that minimize the error
93 n = i;
94 m = currentM;
95 lastErr = err;
96 }
97 }
98 }
99 };
100
101 auto GetMultiplierShiftByScale = [GetScaleParams](bool scale32, double scale, int32_t& multiplier, int32_t& shift)
102 {
103 double m = 0;
104 int32_t n = 0;
105
106 GetScaleParams(scale, m, n);
107
108 multiplier = (scale32) ? (1 << 30) * static_cast<int32_t>(m) : (1 << 14) * static_cast<int32_t>(m);
109 shift = (scale32) ? (30 + n) : (14 + n);
110 };
111
112 int32_t multiplier;
113 int32_t shift;
114 GetMultiplierShiftByScale(scale32, scale, multiplier, shift);
115 CreateRescaleTosaOperator(inputName, outputName, output_type, shape, multiplier, shift,
116 input_zp, output_zp, double_round, scale32, op, tensor);
117}
118
119inline void CreateFromInt32RescaleTosaOperator(const std::string& inputName,
120 const std::string& outputName,
121 DType output_type,
122 const std::vector<int32_t>& shape,
123 double output_scale,
124 int32_t output_zp,
125 TosaSerializationOperator** op,
126 TosaSerializationTensor** tensor)
127{
128 CreateRescaleTosaOperator(inputName, outputName, output_type, shape,
129 output_scale, 0, output_zp, true, true, op, tensor);
130}