blob: c6a7571a92e2e9136491a2e6a7a6f66d390973bd [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matteo Martincighada9cb22018-10-17 14:45:07 +01006#pragma once
7
telsoa014fcda012018-03-09 14:13:49 +00008#include <armnn/Tensor.hpp>
9
Matteo Martincighe011d202019-11-28 11:35:47 +000010#include <armnnUtils/DataLayoutIndexed.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
12namespace armnn
13{
14
telsoa01c577f2c2018-08-31 09:22:23 +010015// Utility class providing access to raw tensor memory based on indices along each dimension.
telsoa014fcda012018-03-09 14:13:49 +000016template <typename DataType>
17class TensorBufferArrayView
18{
19public:
Matteo Martincigh21350152018-11-28 16:22:22 +000020 TensorBufferArrayView(const TensorShape& shape, DataType* data,
21 armnnUtils::DataLayoutIndexed dataLayout = DataLayout::NCHW)
telsoa014fcda012018-03-09 14:13:49 +000022 : m_Shape(shape)
23 , m_Data(data)
James Conroy59540822018-10-11 12:39:05 +010024 , m_DataLayout(dataLayout)
telsoa014fcda012018-03-09 14:13:49 +000025 {
Colm Donelanb4ef1632024-02-01 15:00:43 +000026 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(m_Shape.GetNumDimensions() == 4,
27 "Only $d tensors are supported by TensorBufferArrayView.");
telsoa014fcda012018-03-09 14:13:49 +000028 }
29
30 DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
31 {
Matteo Martincighee423ce2019-06-05 09:02:41 +010032 return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)];
telsoa014fcda012018-03-09 14:13:49 +000033 }
34
35private:
Matteo Martincigh21350152018-11-28 16:22:22 +000036 const TensorShape m_Shape;
37 DataType* m_Data;
38 armnnUtils::DataLayoutIndexed m_DataLayout;
telsoa014fcda012018-03-09 14:13:49 +000039};
40
41} //namespace armnn