blob: cfa8ce7e91d3bc4a21134140b362e3b75c0a4e9d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <armnn/ArmNN.hpp>
#include <TypeUtils.hpp>
namespace armnn
{
/// Abstract positioning interface shared by every iterator in this header.
///
/// It only declares how to move the current position. Reading and writing
/// at that position are added by the derived interfaces.
class BaseIterator
{
public:
    BaseIterator() = default;
    virtual ~BaseIterator() = default;

    /// Step the current position forward once.
    virtual BaseIterator& operator++() = 0;

    /// Step the current position forward by @p increment.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Step the current position backward by @p increment.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;
};
class Decoder : public BaseIterator
{
public:
Decoder() : BaseIterator() {}
virtual ~Decoder() {}
virtual float Get() const = 0;
};
class Encoder : public BaseIterator
{
public:
Encoder() : BaseIterator() {}
virtual ~Encoder() {}
virtual void Set(const float& right) = 0;
};
class ComparisonEncoder : public BaseIterator
{
public:
ComparisonEncoder() : BaseIterator() {}
virtual ~ComparisonEncoder() {}
virtual void Set(bool right) = 0;
};
/// Raw-pointer implementation of the stepping operations declared by Base.
///
/// @tparam T    Element type of the underlying buffer. It may be
///              const-qualified, as the decoders in this file use it.
/// @tparam Base Interface whose pure-virtual stepping operators are
///              implemented here.
///
/// The pointer is never freed by this class, so the caller must keep the
/// buffer alive. Stepping performs no bounds checking.
template<typename T, typename Base>
class TypedIterator : public Base
{
public:
    TypedIterator(T* data)
        : m_Iterator{data}
    {}

    /// Advance the pointer by one element.
    TypedIterator& operator++() override
    {
        m_Iterator += 1;
        return *this;
    }

    /// Advance the pointer by @p increment elements.
    TypedIterator& operator+=(const unsigned int increment) override
    {
        m_Iterator = m_Iterator + increment;
        return *this;
    }

    /// Move the pointer back by @p increment elements.
    TypedIterator& operator-=(const unsigned int increment) override
    {
        m_Iterator = m_Iterator - increment;
        return *this;
    }

    /// Current element. It is public and read directly by the derived
    /// encoders and decoders.
    T* m_Iterator;
};
/// Decoder over a quantized uint8 buffer. Each element is passed through
/// armnn::Dequantize with the scale and offset supplied at construction.
class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder>
{
public:
    QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
        : TypedIterator(data)
        , m_Scale(scale)
        , m_Offset(offset)
    {}

    /// Dequantized value of the current element.
    float Get() const override
    {
        const uint8_t quantized = *m_Iterator;
        return armnn::Dequantize(quantized, m_Scale, m_Offset);
    }

private:
    const float m_Scale;
    const int32_t m_Offset;
};
/// Decoder over a float buffer. Each element is returned unchanged.
class FloatDecoder : public TypedIterator<const float, Decoder>
{
public:
    FloatDecoder(const float* data)
        : TypedIterator(data)
    {}

    /// Current element, read directly from the buffer.
    float Get() const override
    {
        return m_Iterator[0];
    }
};
/// Encoder over a float buffer. Each value is stored unchanged.
class FloatEncoder : public TypedIterator<float, Encoder>
{
public:
    FloatEncoder(float* data)
        : TypedIterator(data)
    {}

    /// Overwrite the current element with @p right.
    void Set(const float& right) override
    {
        m_Iterator[0] = right;
    }
};
/// Encoder over a quantized uint8 buffer. Each float is passed through
/// armnn::Quantize<uint8_t> with the scale and offset supplied at
/// construction.
class QASymm8Encoder : public TypedIterator<uint8_t, Encoder>
{
public:
    QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
        : TypedIterator(data)
        , m_Scale(scale)
        , m_Offset(offset)
    {}

    /// Quantize @p right and store it at the current position.
    void Set(const float& right) override
    {
        const uint8_t quantized = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
        m_Iterator[0] = quantized;
    }

private:
    const float m_Scale;
    const int32_t m_Offset;
};
/// ComparisonEncoder over a uint8 buffer. A bool is stored as 1 (true)
/// or 0 (false).
class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder>
{
public:
    BooleanEncoder(uint8_t* data)
        : TypedIterator(data)
    {}

    /// Store @p right at the current position as 0 or 1.
    void Set(bool right) override
    {
        m_Iterator[0] = static_cast<uint8_t>(right);
    }
};
} //namespace armnn