blob: a78fd725b446222fbbe2df482e110fd7796e24f0 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/Types.hpp>
9
10#include <cmath>
11#include <algorithm>
12
13namespace armnn
14{
15
// Quantization parameters returned by IQuantizationScheme::ComputeScheme.
// Note: despite the alias name, the schemes in this file populate
// .first with the scale (float) and .second with the offset (int).
using OffsetScalePair = std::pair<float, int>;
17
18struct IQuantizationScheme
19{
20 virtual OffsetScalePair ComputeScheme(double min, double max) const = 0;
21
22 virtual int NumBits() const = 0;
23
24 virtual DataType GetDataType() const = 0;
25
26 virtual ~IQuantizationScheme() {}
27};
28
Ryan OShea9add1202020-02-07 10:06:33 +000029struct QAsymmU8QuantizationScheme : IQuantizationScheme
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000030{
31 OffsetScalePair ComputeScheme(double min, double max) const override
32 {
Les Belle0ca8612019-05-17 16:17:12 +010033 if (min > max)
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000034 {
Les Belle0ca8612019-05-17 16:17:12 +010035 throw InvalidArgumentException("min > max will result in invalid quantization.");
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000036 }
37
38 double highest = (1 << NumBits()) - 1;
39
40 min = std::min(0.0, min); // min <= 0.0
41 max = std::max(0.0, max); // max >= 0.0
42
Tee Jungaad2fe42019-11-13 07:17:46 +000043 // To avoid dividing by zero when quantizing a zero filled tensor
44 if (min == 0.0 && max == 0.0)
45 {
46 max = 1.0;
47 }
48
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000049 // Assumes quantization range [0-highest]
50 double scale = (max-min) / highest;
51 double offset = -min / scale;
52
53 // Clamp offset [0-highest]
54 offset = std::max(0.0, std::min(highest, offset));
55
56 return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)));
57 }
58
59 int NumBits() const override { return 8; }
60
Derek Lambertif90c56d2020-01-10 17:14:08 +000061 DataType GetDataType() const override { return DataType::QAsymmU8; }
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000062};
63
Ryan OShea9add1202020-02-07 10:06:33 +000064struct QAsymmS8QuantizationScheme : IQuantizationScheme
65{
66 OffsetScalePair ComputeScheme(double min, double max) const override
67 {
68 if (min > max)
69 {
70 throw InvalidArgumentException("min > max will result in invalid quantization.");
71 }
72
73 double highest = (1 << NumBits()) - 1;
74
75 min = std::min(0.0, min); // min <= 0.0
76 max = std::max(0.0, max); // max >= 0.0
77
78 // To avoid dividing by zero when quantizing a zero filled tensor
79 if (min == 0.0 && max == 0.0)
80 {
81 max = 1.0;
82 }
83
84 // Assumes quantization range [0-255]
85 double scale = (max-min) / highest ;
86 double offset = - min / scale;
87
88 //Clamp 0 to Highest
89 offset = std::max(0.0, std::min(highest, offset));
90
91 //-128 on offset to cast to signed range
92 return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)-128));
93 }
94
95 int NumBits() const override { return 8; }
96
97 DataType GetDataType() const override { return DataType::QAsymmS8; }
98};
99
Finn Williamsfd271062019-12-04 14:27:27 +0000100struct QSymmS8QuantizationScheme : IQuantizationScheme
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +0000101{
102 OffsetScalePair ComputeScheme(double min, double max) const override
103 {
Les Belle0ca8612019-05-17 16:17:12 +0100104 if (min > max)
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +0000105 {
Les Belle0ca8612019-05-17 16:17:12 +0100106 throw InvalidArgumentException("min > max will result in invalid quantization.");
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +0000107 }
108
Tee Jungaad2fe42019-11-13 07:17:46 +0000109 // To avoid dividing by zero when quantizing a zero filled tensor
110 if (min == 0.0 && max == 0.0)
111 {
112 max = 1.0;
113 }
114
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +0000115 double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
116
117 double extent = std::max(std::abs(min), std::abs(max));
118 double scale = extent / highest;
119
120 return std::make_pair(static_cast<float>(scale), 0);
121 }
122
Finn Williamsfd271062019-12-04 14:27:27 +0000123 int NumBits() const override { return 8; }
124
125 DataType GetDataType() const override { return DataType::QSymmS8; }
126};
127
128struct QSymm16QuantizationScheme : IQuantizationScheme
129{
130 OffsetScalePair ComputeScheme(double min, double max) const override
131 {
132 if (min > max)
133 {
134 throw InvalidArgumentException("min > max will result in invalid quantization.");
135 }
136
137 // To avoid dividing by zero when quantizing a zero filled tensor
138 if (min == 0.0 && max == 0.0)
139 {
140 max = 1.0;
141 }
142
143 double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
144
145 double extent = std::max(std::abs(min), std::abs(max));
146 double scale = extent / highest;
147
Finn Williamsfd271062019-12-04 14:27:27 +0000148 return std::make_pair(static_cast<float>(scale), 0);
149
150 }
151
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +0000152 int NumBits() const override { return 16; }
153
Derek Lambertif90c56d2020-01-10 17:14:08 +0000154 DataType GetDataType() const override { return DataType::QSymmS16; }
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +0000155};
156
157} // namespace armnn