blob: 5bcd1b218060922546c33b38ec9ccc4ff251687b [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
Grant Watson64285a12022-11-16 15:32:39 +000019#include "array_proxy.h"
Tai Lya4d748b2023-03-28 22:06:56 +000020#include "dtype.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "model_common.h"
22#include "ops/template_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070023#include "tosa_serialization_handler.h"
24#include <Eigen/CXX11/Tensor>
25#include <list>
26#include <vector>
27
28using namespace tosa;
29
30namespace TosaReference
31{
32class GraphNode;
33
34class Tensor
35{
36public:
Tai Lya4d748b2023-03-28 22:06:56 +000037 Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070038
39 virtual ~Tensor();
40
41 int setIsSubgraphInput();
42 int setIsSubgraphOutput();
Jerry Ge9e94af82022-10-27 09:57:00 -070043 int setIsParentGraphOutput();
44
Tai Lycf84bc92023-09-07 20:49:09 +000045 bool getIsParentGraphOutput() const
Jerry Ge9c9c8da2023-07-19 23:08:16 +000046 {
Jerry Ge9e94af82022-10-27 09:57:00 -070047 return isParentGraphOutput;
48 }
Tai Lycf84bc92023-09-07 20:49:09 +000049 int setIsVariable();
Eric Kunzee5e26762020-10-13 16:11:07 -070050
Tai Lycf84bc92023-09-07 20:49:09 +000051 bool getIsSubgraphInput() const
Eric Kunzee5e26762020-10-13 16:11:07 -070052 {
53 return isSubgraphInput;
54 }
55
Tai Lycf84bc92023-09-07 20:49:09 +000056 bool getIsSubgraphOutput() const
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 return isSubgraphOutput;
59 }
60
Tai Lycf84bc92023-09-07 20:49:09 +000061 bool getIsVariable() const
62 {
63 return isVariable;
64 }
65
Eric Kunzee5e26762020-10-13 16:11:07 -070066 int setProducer(GraphNode* node);
67 int addConsumer(GraphNode* node);
68
69 int setIsValid()
70 {
71 isValid = 1;
72 return 0;
73 }
74
75 int clearIsValid()
76 {
77 isValid = 0;
78 return 0;
79 }
80
81 int getIsValid() const
82 {
83 return isValid;
84 }
85
Eric Kunzee5e26762020-10-13 16:11:07 -070086 GraphNode* getProducer()
87 {
88 return producer;
89 }
90
91 std::vector<GraphNode*>& getConsumers()
92 {
93 return consumers;
94 }
95
96 const std::string& getName() const
97 {
98 return tensorName;
99 }
100
101 const std::vector<int>& getShape() const
102 {
103 return shape;
104 }
105
Jerry Ge264f7fa2023-04-21 22:49:57 +0000106 void setDimSize(size_t dim, uint32_t new_size)
107 {
108 this->shape[dim] = new_size;
109 return;
110 }
111
Eric Kunzee5e26762020-10-13 16:11:07 -0700112 std::string getShapeAsString() const
113 {
114 std::string shape_str("[");
115 for (auto& dim : shape)
116 {
117 shape_str += (std::to_string(dim) + ", ");
118 }
119 shape_str.append("]");
120 return shape_str;
121 }
122
Eric Kunzee5e26762020-10-13 16:11:07 -0700123 const uint32_t getElementCount() const
124 {
125 uint32_t elements = 1;
126 for (size_t i = 0; i < shape.size(); i++)
127 elements *= shape[i];
128
129 return elements;
130 }
131
132 // Comparison of rank and type with other tensors
133 const int matchRank(const Tensor& ref) const
134 {
135 return (ref.shape.size() == shape.size()) ? 0 : 1;
136 }
137
138 const int matchType(const Tensor& ref) const
139 {
140 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
141 }
142
143 const int matchRankType(const Tensor& ref) const
144 {
145 return (matchType(ref) || matchRank(ref));
146 }
147
148 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
149 {
150 if (matchRankType(ref))
151 return 1;
152
153 for (size_t i = 0; i < shape.size(); i++)
154 {
155 if (shape[i] != ref.shape[i])
156 {
157 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000158 // For broadcasts, the order of *this and ref matters.
159 // *this should be the source tensor.
160 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
161 // this->shape must have size 1 if they don't match
162 (broadcastOk && (shape[i] != 1)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700163 {
164 return 1;
165 }
166 }
167 }
168
169 return 0;
170 }
171
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800172 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
173 {
174 if (matchRank(ref))
175 return 1;
176
177 for (size_t i = 0; i < shape.size(); i++)
178 {
179 if (shape[i] != ref.shape[i])
180 {
181 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000182 // For broadcasts, the order of *this and ref matters.
183 // *this should be the source tensor.
184 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
185 // this->shape must have size 1 if they don't match
186 (broadcastOk && (shape[i] != 1)))
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800187 {
188 return 1;
189 }
190 }
191 }
192
193 return 0;
194 }
195
Eric Kunzee5e26762020-10-13 16:11:07 -0700196 // Sometimes we might want to match several semi-compatible types,
197 // so just check rank and size here
198 const int matchRankSize(const Tensor& ref) const
199 {
200 if (matchRank(ref))
201 return 1;
202
203 for (size_t i = 0; i < shape.size(); i++)
204 {
205 if (shape[i] != ref.shape[i])
206 return 1;
207 }
208
209 return 0;
210 }
211
212 // Unary check to make sure rank matches
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000213 const int checkRequiredRank(const int minRank) const
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000215 return (shape.size() >= (size_t)minRank) ? 0 : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 }
217
218 const int checkRequiredRank(const int minRank, const int maxRank) const
219 {
220 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
221 }
222
223 const int getRank() const
224 {
225 return shape.size();
226 }
227
Tai Lya4d748b2023-03-28 22:06:56 +0000228 const TOSA_REF_TYPE getDtype() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 {
230 return tensorDtype;
231 }
232
Tai Lya4d748b2023-03-28 22:06:56 +0000233 const DType getSerializationDtype() const
234 {
235 return serializationDtype;
236 }
237
Eric Kunzee5e26762020-10-13 16:11:07 -0700238 virtual int dumpTensor(FILE* out) const = 0;
239 virtual int dumpTensorParams(FILE* out) const;
240 virtual int dumpTensorParams(std::ostream& out) const;
241
Tai Lya4d748b2023-03-28 22:06:56 +0000242 virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700243 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
Georgios Pinitase9059772023-12-06 18:52:30 +0000244 virtual int setTensorValueInt16(const size_t bufLen, const int16_t* vals) = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700245 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
246 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
247 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
Tai Lya4d748b2023-03-28 22:06:56 +0000248 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700249 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
Georgios Pinitase9059772023-12-06 18:52:30 +0000250 virtual int getTensorValueInt16(const size_t bufLen, int16_t* ibuf) const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700251 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
252 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
253 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
254
255 virtual int readFromNpyFile(const char* filename);
256 virtual int writeToNpyFile(const char* filename) const;
257 virtual int copyValueFrom(Tensor* tensor) = 0;
258
Tai Lya4d748b2023-03-28 22:06:56 +0000259 virtual int readfromVector(const ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000260 virtual int readfromVector(const ArrayProxy<float> vals);
261 virtual int readfromVector(const ArrayProxy<half_float::half> vals);
Georgios Pinitase9059772023-12-06 18:52:30 +0000262 virtual int readfromVector(const ArrayProxy<int16_t> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000263 virtual int readfromVector(const ArrayProxy<int32_t> vals);
264 virtual int readfromVector(const ArrayProxy<int64_t> vals);
265 virtual int readfromVector(const ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100266
Tai Lya4d748b2023-03-28 22:06:56 +0000267 virtual int writeToVector(ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000268 virtual int writeToVector(ArrayProxy<float> vals);
269 virtual int writeToVector(ArrayProxy<half_float::half> vals);
Georgios Pinitase9059772023-12-06 18:52:30 +0000270 virtual int writeToVector(ArrayProxy<int16_t> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000271 virtual int writeToVector(ArrayProxy<int32_t> vals);
272 virtual int writeToVector(ArrayProxy<int64_t> vals);
273 virtual int writeToVector(ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100274
Eric Kunzee5e26762020-10-13 16:11:07 -0700275 const char* bool_to_str(bool in) const
276 {
277 static const char* true_str = "true";
278 static const char* false_str = "false";
279 return in ? true_str : false_str;
280 }
281
Tai Lycf84bc92023-09-07 20:49:09 +0000282 virtual int allocate() = 0;
283 virtual int deallocate() = 0;
284 virtual bool is_allocated() const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
286protected:
Tai Lya4d748b2023-03-28 22:06:56 +0000287 const std::string tensorName;
288 const DType serializationDtype;
Jerry Ge264f7fa2023-04-21 22:49:57 +0000289 std::vector<int> shape;
Tai Lya4d748b2023-03-28 22:06:56 +0000290 const TOSA_REF_TYPE tensorDtype;
Tai Lycf84bc92023-09-07 20:49:09 +0000291 bool isValid;
292 bool isSubgraphInput;
293 bool isSubgraphOutput;
294 bool isVariable;
Eric Kunzee5e26762020-10-13 16:11:07 -0700295 bool isAllocated;
296
Jerry Ge9e94af82022-10-27 09:57:00 -0700297 bool isParentGraphOutput;
298
Eric Kunzee5e26762020-10-13 16:11:07 -0700299 GraphNode* producer;
300 std::vector<GraphNode*> consumers;
301
302 // Note: the Eigen::Tensor is not declared in Tensor
303 // Instead, the TensorTemplate class keeps the templated tensor
304 // declaration so that the graph manipulation tools are isolated
305 // from the templated tensor type.
306 //
307 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
308 // so that they can operate on the right types.
309};
310
311template <class T>
312class TensorTemplate : public Tensor
313{
314public:
Tai Lya4d748b2023-03-28 22:06:56 +0000315 TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
316 : Tensor(tensorName_, dtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700317 {
318 tensor = nullptr;
319 }
320
321 virtual ~TensorTemplate()
322 {
323 deallocate();
324 }
325
326 virtual int allocate()
327 {
328 tensor = new T();
329 if (tensor)
330 return 0;
331 else
332 return 1;
333 }
334
335 virtual int deallocate()
336 {
337 if (tensor)
338 {
Eric Kunze9a367552023-07-11 13:27:36 -0700339 DEBUG_INFO(GT, "Deallocating tensor %s", tensorName.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700340 delete tensor;
341 }
342 tensor = nullptr;
343 return 0;
344 }
345
Tai Lycf84bc92023-09-07 20:49:09 +0000346 virtual bool is_allocated() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700347 {
348 if (tensor)
349 {
350 return true;
351 }
352 return false;
353 }
354
355 T& getTensor()
356 {
357 return *tensor;
358 }
359
360 virtual int dumpTensor(FILE* out) const;
361
Tai Lya4d748b2023-03-28 22:06:56 +0000362 virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700363 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
Georgios Pinitase9059772023-12-06 18:52:30 +0000364 virtual int setTensorValueInt16(const size_t bufLen, const int16_t* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700365 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
366 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
367 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
Tai Lya4d748b2023-03-28 22:06:56 +0000368
369 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700370 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
Georgios Pinitase9059772023-12-06 18:52:30 +0000371 virtual int getTensorValueInt16(const size_t bufLen, int16_t* ibuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700372 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
373 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
374 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
375
376 virtual int copyValueFrom(Tensor* tensor);
377
378protected:
379 T* tensor;
380};
381
382// allocate() template specializations to allocate the different tensor sizes
383// Let the compiler know here before the factory uses them, but define them in the .cc file.
384template <>
385int Tensor0<float>::allocate();
386template <>
387int Tensor1<float>::allocate();
388template <>
389int Tensor2<float>::allocate();
390template <>
391int Tensor3<float>::allocate();
392template <>
393int Tensor4<float>::allocate();
394template <>
395int Tensor5<float>::allocate();
396template <>
397int Tensor6<float>::allocate();
398
399template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000400int Tensor0<double>::allocate();
401template <>
402int Tensor1<double>::allocate();
403template <>
404int Tensor2<double>::allocate();
405template <>
406int Tensor3<double>::allocate();
407template <>
408int Tensor4<double>::allocate();
409template <>
410int Tensor5<double>::allocate();
411template <>
412int Tensor6<double>::allocate();
413
414template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700415int Tensor0<int32_t>::allocate();
416template <>
417int Tensor1<int32_t>::allocate();
418template <>
419int Tensor2<int32_t>::allocate();
420template <>
421int Tensor3<int32_t>::allocate();
422template <>
423int Tensor4<int32_t>::allocate();
424template <>
425int Tensor5<int32_t>::allocate();
426template <>
427int Tensor6<int32_t>::allocate();
428
429template <>
430int Tensor0<int64_t>::allocate();
431template <>
432int Tensor1<int64_t>::allocate();
433template <>
434int Tensor2<int64_t>::allocate();
435template <>
436int Tensor3<int64_t>::allocate();
437template <>
438int Tensor4<int64_t>::allocate();
439template <>
440int Tensor5<int64_t>::allocate();
441template <>
442int Tensor6<int64_t>::allocate();
443
444template <>
445int Tensor0<bool>::allocate();
446template <>
447int Tensor1<bool>::allocate();
448template <>
449int Tensor2<bool>::allocate();
450template <>
451int Tensor3<bool>::allocate();
452template <>
453int Tensor4<bool>::allocate();
454template <>
455int Tensor5<bool>::allocate();
456template <>
457int Tensor6<bool>::allocate();
458
459template <>
460int Tensor0<float>::copyValueFrom(Tensor* src);
461template <>
462int Tensor1<float>::copyValueFrom(Tensor* src);
463template <>
464int Tensor2<float>::copyValueFrom(Tensor* src);
465template <>
466int Tensor3<float>::copyValueFrom(Tensor* src);
467template <>
468int Tensor4<float>::copyValueFrom(Tensor* src);
469template <>
470int Tensor5<float>::copyValueFrom(Tensor* src);
471template <>
472int Tensor6<float>::copyValueFrom(Tensor* src);
473
474template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000475int Tensor0<double>::copyValueFrom(Tensor* src);
476template <>
477int Tensor1<double>::copyValueFrom(Tensor* src);
478template <>
479int Tensor2<double>::copyValueFrom(Tensor* src);
480template <>
481int Tensor3<double>::copyValueFrom(Tensor* src);
482template <>
483int Tensor4<double>::copyValueFrom(Tensor* src);
484template <>
485int Tensor5<double>::copyValueFrom(Tensor* src);
486template <>
487int Tensor6<double>::copyValueFrom(Tensor* src);
488
489template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700490int Tensor0<int32_t>::copyValueFrom(Tensor* src);
491template <>
492int Tensor1<int32_t>::copyValueFrom(Tensor* src);
493template <>
494int Tensor2<int32_t>::copyValueFrom(Tensor* src);
495template <>
496int Tensor3<int32_t>::copyValueFrom(Tensor* src);
497template <>
498int Tensor4<int32_t>::copyValueFrom(Tensor* src);
499template <>
500int Tensor5<int32_t>::copyValueFrom(Tensor* src);
501template <>
502int Tensor6<int32_t>::copyValueFrom(Tensor* src);
503
504template <>
505int Tensor0<int64_t>::copyValueFrom(Tensor* src);
506template <>
507int Tensor1<int64_t>::copyValueFrom(Tensor* src);
508template <>
509int Tensor2<int64_t>::copyValueFrom(Tensor* src);
510template <>
511int Tensor3<int64_t>::copyValueFrom(Tensor* src);
512template <>
513int Tensor4<int64_t>::copyValueFrom(Tensor* src);
514template <>
515int Tensor5<int64_t>::copyValueFrom(Tensor* src);
516template <>
517int Tensor6<int64_t>::copyValueFrom(Tensor* src);
518
519template <>
520int Tensor0<bool>::copyValueFrom(Tensor* src);
521template <>
522int Tensor1<bool>::copyValueFrom(Tensor* src);
523template <>
524int Tensor2<bool>::copyValueFrom(Tensor* src);
525template <>
526int Tensor3<bool>::copyValueFrom(Tensor* src);
527template <>
528int Tensor4<bool>::copyValueFrom(Tensor* src);
529template <>
530int Tensor5<bool>::copyValueFrom(Tensor* src);
531template <>
532int Tensor6<bool>::copyValueFrom(Tensor* src);
533
534template <>
Georgios Pinitase9059772023-12-06 18:52:30 +0000535int Tensor0<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
536template <>
537int Tensor1<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
538template <>
539int Tensor2<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
540template <>
541int Tensor3<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
542template <>
543int Tensor4<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
544template <>
545int Tensor5<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
546template <>
547int Tensor6<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
548
549template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700550int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
551template <>
552int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
553template <>
554int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
555template <>
556int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
557template <>
558int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
559template <>
560int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
561template <>
562int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
563
564template <>
Georgios Pinitase9059772023-12-06 18:52:30 +0000565int Tensor0<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
566template <>
567int Tensor1<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
568template <>
569int Tensor2<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
570template <>
571int Tensor3<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
572template <>
573int Tensor4<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
574template <>
575int Tensor5<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
576template <>
577int Tensor6<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
578
579template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700580int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
581template <>
582int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
583template <>
584int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
585template <>
586int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
587template <>
588int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
589template <>
590int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
591template <>
592int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
593
594template <>
595int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
596template <>
597int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
598template <>
599int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
600template <>
601int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
602template <>
603int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
604template <>
605int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
606template <>
607int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
608
609template <>
610int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
611template <>
612int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
613template <>
614int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
615template <>
616int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
617template <>
618int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
619template <>
620int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
621template <>
622int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
623
624template <>
625int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
626template <>
627int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
628template <>
629int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
630template <>
631int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
632template <>
633int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
634template <>
635int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
636template <>
637int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
638
639template <>
640int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
641template <>
642int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
643template <>
644int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
645template <>
646int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
647template <>
648int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
649template <>
650int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
651template <>
652int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
653
654template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000655int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
656template <>
657int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
658template <>
659int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
660template <>
661int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
662template <>
663int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
664template <>
665int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
666template <>
667int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
668
669template <>
670int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
671template <>
672int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
673template <>
674int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
675template <>
676int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
677template <>
678int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
679template <>
680int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
681template <>
682int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
683
684template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700685int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
686template <>
687int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
688template <>
689int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
690template <>
691int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
692template <>
693int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
694template <>
695int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
696template <>
697int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
698
699template <>
700int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
701template <>
702int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
703template <>
704int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
705template <>
706int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
707template <>
708int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
709template <>
710int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
711template <>
712int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
713
Eric Kunzee5e26762020-10-13 16:11:07 -0700714template <>
715int Tensor0<float>::dumpTensor(FILE* out) const;
716template <>
717int Tensor1<float>::dumpTensor(FILE* out) const;
718template <>
719int Tensor2<float>::dumpTensor(FILE* out) const;
720template <>
721int Tensor3<float>::dumpTensor(FILE* out) const;
722template <>
723int Tensor4<float>::dumpTensor(FILE* out) const;
724template <>
725int Tensor5<float>::dumpTensor(FILE* out) const;
726template <>
727int Tensor6<float>::dumpTensor(FILE* out) const;
728template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000729int Tensor0<double>::dumpTensor(FILE* out) const;
730template <>
731int Tensor1<double>::dumpTensor(FILE* out) const;
732template <>
733int Tensor2<double>::dumpTensor(FILE* out) const;
734template <>
735int Tensor3<double>::dumpTensor(FILE* out) const;
736template <>
737int Tensor4<double>::dumpTensor(FILE* out) const;
738template <>
739int Tensor5<float>::dumpTensor(FILE* out) const;
740template <>
741int Tensor6<double>::dumpTensor(FILE* out) const;
742template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700743int Tensor0<int32_t>::dumpTensor(FILE* out) const;
744template <>
745int Tensor1<int32_t>::dumpTensor(FILE* out) const;
746template <>
747int Tensor2<int32_t>::dumpTensor(FILE* out) const;
748template <>
749int Tensor3<int32_t>::dumpTensor(FILE* out) const;
750template <>
751int Tensor4<int32_t>::dumpTensor(FILE* out) const;
752template <>
753int Tensor5<int32_t>::dumpTensor(FILE* out) const;
754template <>
755int Tensor6<int32_t>::dumpTensor(FILE* out) const;
756template <>
757int Tensor0<int64_t>::dumpTensor(FILE* out) const;
758template <>
759int Tensor1<int64_t>::dumpTensor(FILE* out) const;
760template <>
761int Tensor2<int64_t>::dumpTensor(FILE* out) const;
762template <>
763int Tensor3<int64_t>::dumpTensor(FILE* out) const;
764template <>
765int Tensor4<int64_t>::dumpTensor(FILE* out) const;
766template <>
767int Tensor5<int64_t>::dumpTensor(FILE* out) const;
768template <>
769int Tensor6<int64_t>::dumpTensor(FILE* out) const;
770template <>
771int Tensor0<bool>::dumpTensor(FILE* out) const;
772template <>
773int Tensor1<bool>::dumpTensor(FILE* out) const;
774template <>
775int Tensor2<bool>::dumpTensor(FILE* out) const;
776template <>
777int Tensor3<bool>::dumpTensor(FILE* out) const;
778template <>
779int Tensor4<bool>::dumpTensor(FILE* out) const;
780template <>
781int Tensor5<bool>::dumpTensor(FILE* out) const;
782template <>
783int Tensor6<bool>::dumpTensor(FILE* out) const;
784
785class TensorFactory
786{
787public:
Tai Lya4d748b2023-03-28 22:06:56 +0000788 static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700789 {
Tai Lya4d748b2023-03-28 22:06:56 +0000790 TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700791 switch (tensorDtype_)
792 {
Tai Lya4d748b2023-03-28 22:06:56 +0000793 case TOSA_REF_TYPE_FP32:
794 case TOSA_REF_TYPE_FP16:
795 case TOSA_REF_TYPE_BF16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700796 switch (rank)
797 {
798 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000799 return new Tensor0<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700800 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000801 return new Tensor1<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700802 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000803 return new Tensor2<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700804 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000805 return new Tensor3<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000807 return new Tensor4<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700808 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000809 return new Tensor5<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700810 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000811 return new Tensor6<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700812 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700813 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000814 case TOSA_REF_TYPE_INT32:
815 case TOSA_REF_TYPE_UINT8:
816 case TOSA_REF_TYPE_INT4:
817 case TOSA_REF_TYPE_INT8:
818 case TOSA_REF_TYPE_INT16:
819 case TOSA_REF_TYPE_UINT16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700820 switch (rank)
821 {
822 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000823 return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700824 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000825 return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700826 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000827 return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700828 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000829 return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700830 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000831 return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700832 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000833 return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700834 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000835 return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700836 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700837 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000838 case TOSA_REF_TYPE_INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700839 switch (rank)
840 {
841 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000842 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700843 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000844 return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700845 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000846 return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700847 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000848 return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700849 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000850 return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700851 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000852 return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700853 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000854 return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700855 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700856 break;
Won Jeona21b2e82023-08-10 10:33:01 +0000857 case TOSA_REF_TYPE_SHAPE:
Tai Ly0913dba2023-08-22 22:50:18 +0000858 assert(rank == 0);
Won Jeona21b2e82023-08-10 10:33:01 +0000859 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
Tai Lya4d748b2023-03-28 22:06:56 +0000860 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700861 switch (rank)
862 {
863 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000864 return new Tensor0<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700865 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000866 return new Tensor1<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700867 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000868 return new Tensor2<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700869 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000870 return new Tensor3<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700871 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000872 return new Tensor4<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700873 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000874 return new Tensor5<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700875 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000876 return new Tensor6<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700877 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700878 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000879 case TOSA_REF_TYPE_FP64:
880 switch (rank)
881 {
882 case 0:
883 return new Tensor0<double>(tensorName_, dtype_, shape_);
884 case 1:
885 return new Tensor1<double>(tensorName_, dtype_, shape_);
886 case 2:
887 return new Tensor2<double>(tensorName_, dtype_, shape_);
888 case 3:
889 return new Tensor3<double>(tensorName_, dtype_, shape_);
890 case 4:
891 return new Tensor4<double>(tensorName_, dtype_, shape_);
892 case 5:
893 return new Tensor5<double>(tensorName_, dtype_, shape_);
894 case 6:
895 return new Tensor6<double>(tensorName_, dtype_, shape_);
896 }
897 break;
898 case TOSA_REF_TYPE_UNKNOWN:
899 assert(0); // tensorDtype_ is uninitialized
Kevin Cheng989cb052021-04-28 16:29:44 -0700900 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700901 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700902 return nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -0700903 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700904};
905}; // namespace TosaReference
906
907#endif