blob: 1a4dd7aac347859fdb34da3af37854401c73ce85 [file] [log] [blame]
//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Exceptions.hpp>

#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
9
10inline void CreateRescaleTosaOperator(const std::string& inputName,
11 const std::string& outputName,
12 DType output_type,
13 const std::vector<int32_t>& shape,
14 int32_t scale_multiplier,
15 int32_t scale_shift,
16 int32_t input_zp,
17 int32_t output_zp,
18 bool double_round,
19 bool scale32,
20 TosaSerializationOperator** op,
21 TosaSerializationTensor** tensor)
22{
23 if (!op)
24 {
25 throw armnn::Exception("CreateRescaleTosaOperator: nullptr op");
26 }
27
28 std::vector<int32_t> multipliers{scale_multiplier};
29 std::vector<int32_t> shifts{scale_shift};
30 TosaRescaleAttribute attribute(input_zp,
31 output_zp,
32 multipliers,
33 shifts,
34 scale32,
35 double_round,
Teresa Charlin571a4f72024-03-26 11:18:42 +000036 false, // per_channel
37 false, // input_unsigned
38 false); // output_unsigned
Tracy Narine10403ec2023-11-28 11:55:08 +000039
40 // op
41 *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
42 if (!(*op))
43 {
44 throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
45 }
46 if (tensor != nullptr)
47 {
48 // tensor
49 *tensor = new TosaSerializationTensor(outputName, shape, output_type, {});
50 if (! (*tensor))
51 {
52 throw armnn::Exception("CreateRescaleTosaOperator: failed to created tensor");
53 }
54 }
55}
56
57inline void CreateRescaleTosaOperator(const std::string& inputName,
58 const std::string& outputName,
59 DType output_type,
60 const std::vector<int32_t>& shape,
61 double scale,
62 int32_t input_zp,
63 int32_t output_zp,
64 bool double_round,
65 bool scale32,
66 TosaSerializationOperator** op,
67 TosaSerializationTensor** tensor)
68{
69 // The code that follows is based on the behaviour specified in
70 // https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
71
72 auto GetScaleParams = [](double scale, double& m, int32_t& n)
73 {
74 m = 0;
75 n = 0;
76
77 double lastErr = 1e06;
78
79 const int32_t numExponents = 62;
80 const double start = 1.0;
81 const double end = 2.0;
82
83 // Slow iterative approach but running in Reference only
84 for (int32_t i = 0; i < numExponents; ++i)
85 {
86 double exp = 1.0 / (1 << i);
87 double currentM = scale / exp; // Find current m given value = currentM * exp
88 if ((currentM >= start) && (currentM < end))
89 {
90 double value = currentM * exp;
91 double err = std::abs(scale - value);
92 if (err < lastErr)
93 {
94 // Take the m, n that minimize the error
95 n = i;
96 m = currentM;
97 lastErr = err;
98 }
99 }
100 }
101 };
102
103 auto GetMultiplierShiftByScale = [GetScaleParams](bool scale32, double scale, int32_t& multiplier, int32_t& shift)
104 {
105 double m = 0;
106 int32_t n = 0;
107
108 GetScaleParams(scale, m, n);
109
110 multiplier = (scale32) ? (1 << 30) * static_cast<int32_t>(m) : (1 << 14) * static_cast<int32_t>(m);
111 shift = (scale32) ? (30 + n) : (14 + n);
112 };
113
114 int32_t multiplier;
115 int32_t shift;
116 GetMultiplierShiftByScale(scale32, scale, multiplier, shift);
117 CreateRescaleTosaOperator(inputName, outputName, output_type, shape, multiplier, shift,
118 input_zp, output_zp, double_round, scale32, op, tensor);
119}
120
121inline void CreateFromInt32RescaleTosaOperator(const std::string& inputName,
122 const std::string& outputName,
123 DType output_type,
124 const std::vector<int32_t>& shape,
125 double output_scale,
126 int32_t output_zp,
127 TosaSerializationOperator** op,
128 TosaSerializationTensor** tensor)
129{
130 CreateRescaleTosaOperator(inputName, outputName, output_type, shape,
131 output_scale, 0, output_zp, true, true, op, tensor);
132}