blob: 6bfdf57b36d3894f1208569d61e9f559a6a02e08 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "RawTensor.h"
25
26#include "Utils.h"
27
28#include "arm_compute/core/Utils.h"
29
30#include <algorithm>
31#include <array>
32#include <functional>
33#include <stdexcept>
34#include <utility>
35
36namespace arm_compute
37{
38namespace test
39{
40RawTensor::RawTensor(TensorShape shape, Format format, int fixed_point_position)
41 : _buffer(nullptr),
42 _shape(shape),
43 _format(format),
44 _fixed_point_position(fixed_point_position)
45{
46 _buffer = ::arm_compute::test::cpp14::make_unique<BufferType[]>(size());
47}
48
49RawTensor::RawTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position)
50 : _buffer(nullptr),
51 _shape(shape),
52 _data_type(data_type),
53 _num_channels(num_channels),
54 _fixed_point_position(fixed_point_position)
55{
56 _buffer = ::arm_compute::test::cpp14::make_unique<BufferType[]>(size());
57}
58
59RawTensor::RawTensor(const RawTensor &tensor)
60 : _buffer(nullptr),
61 _shape(tensor.shape()),
62 _format(tensor.format()),
63 _fixed_point_position(tensor.fixed_point_position())
64{
65 _buffer = ::arm_compute::test::cpp14::make_unique<BufferType[]>(tensor.size());
66 std::copy(tensor.data(), tensor.data() + size(), _buffer.get());
67}
68
69RawTensor &RawTensor::operator=(RawTensor tensor)
70{
71 swap(*this, tensor);
72
73 return *this;
74}
75
76RawTensor::BufferType &RawTensor::operator[](size_t offset)
77{
78 return _buffer[offset];
79}
80
81const RawTensor::BufferType &RawTensor::operator[](size_t offset) const
82{
83 return _buffer[offset];
84}
85
86TensorShape RawTensor::shape() const
87{
88 return _shape;
89}
90
91size_t RawTensor::element_size() const
92{
93 return num_channels() * element_size_from_data_type(data_type());
94}
95
96int RawTensor::fixed_point_position() const
97{
98 return _fixed_point_position;
99}
100
101size_t RawTensor::size() const
102{
103 const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>());
104 return size * element_size();
105}
106
107Format RawTensor::format() const
108{
109 return _format;
110}
111
112DataType RawTensor::data_type() const
113{
114 if(_format != Format::UNKNOWN)
115 {
116 return data_type_from_format(_format);
117 }
118 else
119 {
120 return _data_type;
121 }
122}
123
124int RawTensor::num_channels() const
125{
126 switch(_format)
127 {
128 case Format::U8:
129 case Format::S16:
130 case Format::U16:
131 case Format::S32:
132 case Format::U32:
133 return 1;
134 case Format::RGB888:
135 return 3;
136 case Format::UNKNOWN:
137 return _num_channels;
138 default:
139 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
140 }
141}
142
143int RawTensor::num_elements() const
144{
145 return _shape.total_size();
146}
147
148const RawTensor::BufferType *RawTensor::data() const
149{
150 return _buffer.get();
151}
152
153RawTensor::BufferType *RawTensor::data()
154{
155 return _buffer.get();
156}
157
158const RawTensor::BufferType *RawTensor::operator()(const Coordinates &coord) const
159{
160 return _buffer.get() + coord2index(_shape, coord) * element_size();
161}
162
163RawTensor::BufferType *RawTensor::operator()(const Coordinates &coord)
164{
165 return _buffer.get() + coord2index(_shape, coord) * element_size();
166}
167
168void swap(RawTensor &tensor1, RawTensor &tensor2)
169{
170 // Use unqualified call to swap to enable ADL. But make std::swap available
171 // as backup.
172 using std::swap;
173 swap(tensor1._shape, tensor2._shape);
174 swap(tensor1._format, tensor2._format);
175 swap(tensor1._data_type, tensor2._data_type);
176 swap(tensor1._num_channels, tensor2._num_channels);
177 swap(tensor1._buffer, tensor2._buffer);
178}
179} // namespace test
180} // namespace arm_compute