blob: f59a5e1b2824544eabb90aa6f7de6d09381d1052 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
Grant Watson64285a12022-11-16 15:32:39 +000019#include "array_proxy.h"
Tai Lya4d748b2023-03-28 22:06:56 +000020#include "dtype.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "model_common.h"
22#include "ops/template_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070023#include "tosa_serialization_handler.h"
24#include <Eigen/CXX11/Tensor>
25#include <list>
26#include <vector>
27
28using namespace tosa;
29
30namespace TosaReference
31{
32class GraphNode;
33
34class Tensor
35{
36public:
Tai Lya4d748b2023-03-28 22:06:56 +000037 Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070038
39 virtual ~Tensor();
40
41 int setIsSubgraphInput();
42 int setIsSubgraphOutput();
Jerry Ge9e94af82022-10-27 09:57:00 -070043 int setIsParentGraphOutput();
44
Jerry Ge9c9c8da2023-07-19 23:08:16 +000045 int getIsParentGraphOutput() const
46 {
Jerry Ge9e94af82022-10-27 09:57:00 -070047 return isParentGraphOutput;
48 }
Eric Kunzee5e26762020-10-13 16:11:07 -070049
50 int getIsSubgraphInput() const
51 {
52 return isSubgraphInput;
53 }
54
55 int getIsSubgraphOutput() const
56 {
57 return isSubgraphOutput;
58 }
59
60 int setProducer(GraphNode* node);
61 int addConsumer(GraphNode* node);
62
63 int setIsValid()
64 {
65 isValid = 1;
66 return 0;
67 }
68
69 int clearIsValid()
70 {
71 isValid = 0;
72 return 0;
73 }
74
75 int getIsValid() const
76 {
77 return isValid;
78 }
79
Eric Kunzee5e26762020-10-13 16:11:07 -070080 GraphNode* getProducer()
81 {
82 return producer;
83 }
84
85 std::vector<GraphNode*>& getConsumers()
86 {
87 return consumers;
88 }
89
90 const std::string& getName() const
91 {
92 return tensorName;
93 }
94
95 const std::vector<int>& getShape() const
96 {
97 return shape;
98 }
99
Jerry Ge264f7fa2023-04-21 22:49:57 +0000100 void setDimSize(size_t dim, uint32_t new_size)
101 {
102 this->shape[dim] = new_size;
103 return;
104 }
105
Eric Kunzee5e26762020-10-13 16:11:07 -0700106 std::string getShapeAsString() const
107 {
108 std::string shape_str("[");
109 for (auto& dim : shape)
110 {
111 shape_str += (std::to_string(dim) + ", ");
112 }
113 shape_str.append("]");
114 return shape_str;
115 }
116
Eric Kunzee5e26762020-10-13 16:11:07 -0700117 const uint32_t getElementCount() const
118 {
119 uint32_t elements = 1;
120 for (size_t i = 0; i < shape.size(); i++)
121 elements *= shape[i];
122
123 return elements;
124 }
125
126 // Comparison of rank and type with other tensors
127 const int matchRank(const Tensor& ref) const
128 {
129 return (ref.shape.size() == shape.size()) ? 0 : 1;
130 }
131
132 const int matchType(const Tensor& ref) const
133 {
134 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
135 }
136
137 const int matchRankType(const Tensor& ref) const
138 {
139 return (matchType(ref) || matchRank(ref));
140 }
141
142 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
143 {
144 if (matchRankType(ref))
145 return 1;
146
147 for (size_t i = 0; i < shape.size(); i++)
148 {
149 if (shape[i] != ref.shape[i])
150 {
151 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000152 // For broadcasts, the order of *this and ref matters.
153 // *this should be the source tensor.
154 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
155 // this->shape must have size 1 if they don't match
156 (broadcastOk && (shape[i] != 1)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700157 {
158 return 1;
159 }
160 }
161 }
162
163 return 0;
164 }
165
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800166 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
167 {
168 if (matchRank(ref))
169 return 1;
170
171 for (size_t i = 0; i < shape.size(); i++)
172 {
173 if (shape[i] != ref.shape[i])
174 {
175 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000176 // For broadcasts, the order of *this and ref matters.
177 // *this should be the source tensor.
178 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
179 // this->shape must have size 1 if they don't match
180 (broadcastOk && (shape[i] != 1)))
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800181 {
182 return 1;
183 }
184 }
185 }
186
187 return 0;
188 }
189
Eric Kunzee5e26762020-10-13 16:11:07 -0700190 // Sometimes we might want to match several semi-compatible types,
191 // so just check rank and size here
192 const int matchRankSize(const Tensor& ref) const
193 {
194 if (matchRank(ref))
195 return 1;
196
197 for (size_t i = 0; i < shape.size(); i++)
198 {
199 if (shape[i] != ref.shape[i])
200 return 1;
201 }
202
203 return 0;
204 }
205
206 // Unary check to make sure rank matches
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000207 const int checkRequiredRank(const int minRank) const
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000209 return (shape.size() >= (size_t)minRank) ? 0 : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700210 }
211
212 const int checkRequiredRank(const int minRank, const int maxRank) const
213 {
214 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
215 }
216
217 const int getRank() const
218 {
219 return shape.size();
220 }
221
Tai Lya4d748b2023-03-28 22:06:56 +0000222 const TOSA_REF_TYPE getDtype() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 {
224 return tensorDtype;
225 }
226
Tai Lya4d748b2023-03-28 22:06:56 +0000227 const DType getSerializationDtype() const
228 {
229 return serializationDtype;
230 }
231
Eric Kunzee5e26762020-10-13 16:11:07 -0700232 virtual int dumpTensor(FILE* out) const = 0;
233 virtual int dumpTensorParams(FILE* out) const;
234 virtual int dumpTensorParams(std::ostream& out) const;
235
Tai Lya4d748b2023-03-28 22:06:56 +0000236 virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700237 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
238 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
239 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
240 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
Tai Lya4d748b2023-03-28 22:06:56 +0000241 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700242 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
243 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
244 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
245 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
246
247 virtual int readFromNpyFile(const char* filename);
248 virtual int writeToNpyFile(const char* filename) const;
249 virtual int copyValueFrom(Tensor* tensor) = 0;
250
Tai Lya4d748b2023-03-28 22:06:56 +0000251 virtual int readfromVector(const ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000252 virtual int readfromVector(const ArrayProxy<float> vals);
253 virtual int readfromVector(const ArrayProxy<half_float::half> vals);
254 virtual int readfromVector(const ArrayProxy<int32_t> vals);
255 virtual int readfromVector(const ArrayProxy<int64_t> vals);
256 virtual int readfromVector(const ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100257
Tai Lya4d748b2023-03-28 22:06:56 +0000258 virtual int writeToVector(ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000259 virtual int writeToVector(ArrayProxy<float> vals);
260 virtual int writeToVector(ArrayProxy<half_float::half> vals);
261 virtual int writeToVector(ArrayProxy<int32_t> vals);
262 virtual int writeToVector(ArrayProxy<int64_t> vals);
263 virtual int writeToVector(ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100264
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 const char* bool_to_str(bool in) const
266 {
267 static const char* true_str = "true";
268 static const char* false_str = "false";
269 return in ? true_str : false_str;
270 }
271
272 virtual int allocate() = 0;
273 virtual int deallocate() = 0;
274 virtual bool is_allocated() = 0;
275
276protected:
Tai Lya4d748b2023-03-28 22:06:56 +0000277 const std::string tensorName;
278 const DType serializationDtype;
Jerry Ge264f7fa2023-04-21 22:49:57 +0000279 std::vector<int> shape;
Tai Lya4d748b2023-03-28 22:06:56 +0000280 const TOSA_REF_TYPE tensorDtype;
Eric Kunzee5e26762020-10-13 16:11:07 -0700281 int isValid;
Eric Kunzee5e26762020-10-13 16:11:07 -0700282 int isSubgraphInput;
283 int isSubgraphOutput;
284 bool isAllocated;
285
Jerry Ge9e94af82022-10-27 09:57:00 -0700286 bool isParentGraphOutput;
287
Eric Kunzee5e26762020-10-13 16:11:07 -0700288 GraphNode* producer;
289 std::vector<GraphNode*> consumers;
290
291 // Note: the Eigen::Tensor is not declared in Tensor
292 // Instead, the TensorTemplate class keeps the templated tensor
293 // declaration so that the graph manipulation tools are isolated
294 // from the templated tensor type.
295 //
296 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
297 // so that they can operate on the right types.
298};
299
300template <class T>
301class TensorTemplate : public Tensor
302{
303public:
Tai Lya4d748b2023-03-28 22:06:56 +0000304 TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
305 : Tensor(tensorName_, dtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700306 {
307 tensor = nullptr;
308 }
309
310 virtual ~TensorTemplate()
311 {
312 deallocate();
313 }
314
315 virtual int allocate()
316 {
317 tensor = new T();
318 if (tensor)
319 return 0;
320 else
321 return 1;
322 }
323
324 virtual int deallocate()
325 {
326 if (tensor)
327 {
Eric Kunze9a367552023-07-11 13:27:36 -0700328 DEBUG_INFO(GT, "Deallocating tensor %s", tensorName.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700329 delete tensor;
330 }
331 tensor = nullptr;
332 return 0;
333 }
334
335 virtual bool is_allocated()
336 {
337 if (tensor)
338 {
339 return true;
340 }
341 return false;
342 }
343
344 T& getTensor()
345 {
346 return *tensor;
347 }
348
349 virtual int dumpTensor(FILE* out) const;
350
Tai Lya4d748b2023-03-28 22:06:56 +0000351 virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700352 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
353 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
354 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
355 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
Tai Lya4d748b2023-03-28 22:06:56 +0000356
357 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700358 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
359 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
360 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
361 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
362
363 virtual int copyValueFrom(Tensor* tensor);
364
365protected:
366 T* tensor;
367};
368
369// allocate() template specializations to allocate the different tensor sizes
370// Let the compiler know here before the factory uses them, but define them in the .cc file.
371template <>
372int Tensor0<float>::allocate();
373template <>
374int Tensor1<float>::allocate();
375template <>
376int Tensor2<float>::allocate();
377template <>
378int Tensor3<float>::allocate();
379template <>
380int Tensor4<float>::allocate();
381template <>
382int Tensor5<float>::allocate();
383template <>
384int Tensor6<float>::allocate();
385
386template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000387int Tensor0<double>::allocate();
388template <>
389int Tensor1<double>::allocate();
390template <>
391int Tensor2<double>::allocate();
392template <>
393int Tensor3<double>::allocate();
394template <>
395int Tensor4<double>::allocate();
396template <>
397int Tensor5<double>::allocate();
398template <>
399int Tensor6<double>::allocate();
400
401template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700402int Tensor0<int32_t>::allocate();
403template <>
404int Tensor1<int32_t>::allocate();
405template <>
406int Tensor2<int32_t>::allocate();
407template <>
408int Tensor3<int32_t>::allocate();
409template <>
410int Tensor4<int32_t>::allocate();
411template <>
412int Tensor5<int32_t>::allocate();
413template <>
414int Tensor6<int32_t>::allocate();
415
416template <>
417int Tensor0<int64_t>::allocate();
418template <>
419int Tensor1<int64_t>::allocate();
420template <>
421int Tensor2<int64_t>::allocate();
422template <>
423int Tensor3<int64_t>::allocate();
424template <>
425int Tensor4<int64_t>::allocate();
426template <>
427int Tensor5<int64_t>::allocate();
428template <>
429int Tensor6<int64_t>::allocate();
430
431template <>
432int Tensor0<bool>::allocate();
433template <>
434int Tensor1<bool>::allocate();
435template <>
436int Tensor2<bool>::allocate();
437template <>
438int Tensor3<bool>::allocate();
439template <>
440int Tensor4<bool>::allocate();
441template <>
442int Tensor5<bool>::allocate();
443template <>
444int Tensor6<bool>::allocate();
445
446template <>
447int Tensor0<float>::copyValueFrom(Tensor* src);
448template <>
449int Tensor1<float>::copyValueFrom(Tensor* src);
450template <>
451int Tensor2<float>::copyValueFrom(Tensor* src);
452template <>
453int Tensor3<float>::copyValueFrom(Tensor* src);
454template <>
455int Tensor4<float>::copyValueFrom(Tensor* src);
456template <>
457int Tensor5<float>::copyValueFrom(Tensor* src);
458template <>
459int Tensor6<float>::copyValueFrom(Tensor* src);
460
461template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000462int Tensor0<double>::copyValueFrom(Tensor* src);
463template <>
464int Tensor1<double>::copyValueFrom(Tensor* src);
465template <>
466int Tensor2<double>::copyValueFrom(Tensor* src);
467template <>
468int Tensor3<double>::copyValueFrom(Tensor* src);
469template <>
470int Tensor4<double>::copyValueFrom(Tensor* src);
471template <>
472int Tensor5<double>::copyValueFrom(Tensor* src);
473template <>
474int Tensor6<double>::copyValueFrom(Tensor* src);
475
476template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700477int Tensor0<int32_t>::copyValueFrom(Tensor* src);
478template <>
479int Tensor1<int32_t>::copyValueFrom(Tensor* src);
480template <>
481int Tensor2<int32_t>::copyValueFrom(Tensor* src);
482template <>
483int Tensor3<int32_t>::copyValueFrom(Tensor* src);
484template <>
485int Tensor4<int32_t>::copyValueFrom(Tensor* src);
486template <>
487int Tensor5<int32_t>::copyValueFrom(Tensor* src);
488template <>
489int Tensor6<int32_t>::copyValueFrom(Tensor* src);
490
491template <>
492int Tensor0<int64_t>::copyValueFrom(Tensor* src);
493template <>
494int Tensor1<int64_t>::copyValueFrom(Tensor* src);
495template <>
496int Tensor2<int64_t>::copyValueFrom(Tensor* src);
497template <>
498int Tensor3<int64_t>::copyValueFrom(Tensor* src);
499template <>
500int Tensor4<int64_t>::copyValueFrom(Tensor* src);
501template <>
502int Tensor5<int64_t>::copyValueFrom(Tensor* src);
503template <>
504int Tensor6<int64_t>::copyValueFrom(Tensor* src);
505
506template <>
507int Tensor0<bool>::copyValueFrom(Tensor* src);
508template <>
509int Tensor1<bool>::copyValueFrom(Tensor* src);
510template <>
511int Tensor2<bool>::copyValueFrom(Tensor* src);
512template <>
513int Tensor3<bool>::copyValueFrom(Tensor* src);
514template <>
515int Tensor4<bool>::copyValueFrom(Tensor* src);
516template <>
517int Tensor5<bool>::copyValueFrom(Tensor* src);
518template <>
519int Tensor6<bool>::copyValueFrom(Tensor* src);
520
521template <>
522int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
523template <>
524int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
525template <>
526int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
527template <>
528int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
529template <>
530int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
531template <>
532int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
533template <>
534int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
535
536template <>
537int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
538template <>
539int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
540template <>
541int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
542template <>
543int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
544template <>
545int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
546template <>
547int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
548template <>
549int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
550
551template <>
552int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
553template <>
554int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
555template <>
556int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
557template <>
558int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
559template <>
560int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
561template <>
562int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
563template <>
564int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
565
566template <>
567int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
568template <>
569int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
570template <>
571int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
572template <>
573int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
574template <>
575int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
576template <>
577int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
578template <>
579int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
580
581template <>
582int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
583template <>
584int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
585template <>
586int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
587template <>
588int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
589template <>
590int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
591template <>
592int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
593template <>
594int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
595
596template <>
597int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
598template <>
599int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
600template <>
601int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
602template <>
603int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
604template <>
605int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
606template <>
607int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
608template <>
609int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
610
611template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000612int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
613template <>
614int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
615template <>
616int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
617template <>
618int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
619template <>
620int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
621template <>
622int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
623template <>
624int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
625
626template <>
627int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
628template <>
629int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
630template <>
631int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
632template <>
633int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
634template <>
635int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
636template <>
637int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
638template <>
639int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
640
641template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700642int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
643template <>
644int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
645template <>
646int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
647template <>
648int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
649template <>
650int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
651template <>
652int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
653template <>
654int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
655
656template <>
657int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
658template <>
659int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
660template <>
661int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
662template <>
663int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
664template <>
665int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
666template <>
667int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
668template <>
669int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
670
Eric Kunzee5e26762020-10-13 16:11:07 -0700671template <>
672int Tensor0<float>::dumpTensor(FILE* out) const;
673template <>
674int Tensor1<float>::dumpTensor(FILE* out) const;
675template <>
676int Tensor2<float>::dumpTensor(FILE* out) const;
677template <>
678int Tensor3<float>::dumpTensor(FILE* out) const;
679template <>
680int Tensor4<float>::dumpTensor(FILE* out) const;
681template <>
682int Tensor5<float>::dumpTensor(FILE* out) const;
683template <>
684int Tensor6<float>::dumpTensor(FILE* out) const;
685template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000686int Tensor0<double>::dumpTensor(FILE* out) const;
687template <>
688int Tensor1<double>::dumpTensor(FILE* out) const;
689template <>
690int Tensor2<double>::dumpTensor(FILE* out) const;
691template <>
692int Tensor3<double>::dumpTensor(FILE* out) const;
693template <>
694int Tensor4<double>::dumpTensor(FILE* out) const;
695template <>
696int Tensor5<float>::dumpTensor(FILE* out) const;
697template <>
698int Tensor6<double>::dumpTensor(FILE* out) const;
699template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700700int Tensor0<int32_t>::dumpTensor(FILE* out) const;
701template <>
702int Tensor1<int32_t>::dumpTensor(FILE* out) const;
703template <>
704int Tensor2<int32_t>::dumpTensor(FILE* out) const;
705template <>
706int Tensor3<int32_t>::dumpTensor(FILE* out) const;
707template <>
708int Tensor4<int32_t>::dumpTensor(FILE* out) const;
709template <>
710int Tensor5<int32_t>::dumpTensor(FILE* out) const;
711template <>
712int Tensor6<int32_t>::dumpTensor(FILE* out) const;
713template <>
714int Tensor0<int64_t>::dumpTensor(FILE* out) const;
715template <>
716int Tensor1<int64_t>::dumpTensor(FILE* out) const;
717template <>
718int Tensor2<int64_t>::dumpTensor(FILE* out) const;
719template <>
720int Tensor3<int64_t>::dumpTensor(FILE* out) const;
721template <>
722int Tensor4<int64_t>::dumpTensor(FILE* out) const;
723template <>
724int Tensor5<int64_t>::dumpTensor(FILE* out) const;
725template <>
726int Tensor6<int64_t>::dumpTensor(FILE* out) const;
727template <>
728int Tensor0<bool>::dumpTensor(FILE* out) const;
729template <>
730int Tensor1<bool>::dumpTensor(FILE* out) const;
731template <>
732int Tensor2<bool>::dumpTensor(FILE* out) const;
733template <>
734int Tensor3<bool>::dumpTensor(FILE* out) const;
735template <>
736int Tensor4<bool>::dumpTensor(FILE* out) const;
737template <>
738int Tensor5<bool>::dumpTensor(FILE* out) const;
739template <>
740int Tensor6<bool>::dumpTensor(FILE* out) const;
741
742class TensorFactory
743{
744public:
Tai Lya4d748b2023-03-28 22:06:56 +0000745 static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700746 {
Tai Lya4d748b2023-03-28 22:06:56 +0000747 TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700748 switch (tensorDtype_)
749 {
Tai Lya4d748b2023-03-28 22:06:56 +0000750 case TOSA_REF_TYPE_FP32:
751 case TOSA_REF_TYPE_FP16:
752 case TOSA_REF_TYPE_BF16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700753 switch (rank)
754 {
755 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000756 return new Tensor0<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700757 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000758 return new Tensor1<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700759 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000760 return new Tensor2<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700761 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000762 return new Tensor3<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700763 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000764 return new Tensor4<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700765 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000766 return new Tensor5<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700767 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000768 return new Tensor6<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700769 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700770 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000771 case TOSA_REF_TYPE_INT32:
772 case TOSA_REF_TYPE_UINT8:
773 case TOSA_REF_TYPE_INT4:
774 case TOSA_REF_TYPE_INT8:
775 case TOSA_REF_TYPE_INT16:
776 case TOSA_REF_TYPE_UINT16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700777 switch (rank)
778 {
779 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000780 return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700781 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000782 return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700783 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000784 return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700785 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000786 return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700787 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000788 return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700789 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000790 return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700791 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000792 return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700793 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700794 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000795 case TOSA_REF_TYPE_INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700796 switch (rank)
797 {
798 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000799 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700800 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000801 return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700802 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000803 return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700804 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000805 return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000807 return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700808 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000809 return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700810 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000811 return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700812 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700813 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000814 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700815 switch (rank)
816 {
817 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000818 return new Tensor0<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700819 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000820 return new Tensor1<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700821 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000822 return new Tensor2<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700823 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000824 return new Tensor3<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700825 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000826 return new Tensor4<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700827 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000828 return new Tensor5<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700829 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000830 return new Tensor6<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700831 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700832 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000833 case TOSA_REF_TYPE_FP64:
834 switch (rank)
835 {
836 case 0:
837 return new Tensor0<double>(tensorName_, dtype_, shape_);
838 case 1:
839 return new Tensor1<double>(tensorName_, dtype_, shape_);
840 case 2:
841 return new Tensor2<double>(tensorName_, dtype_, shape_);
842 case 3:
843 return new Tensor3<double>(tensorName_, dtype_, shape_);
844 case 4:
845 return new Tensor4<double>(tensorName_, dtype_, shape_);
846 case 5:
847 return new Tensor5<double>(tensorName_, dtype_, shape_);
848 case 6:
849 return new Tensor6<double>(tensorName_, dtype_, shape_);
850 }
851 break;
852 case TOSA_REF_TYPE_UNKNOWN:
853 assert(0); // tensorDtype_ is uninitialized
Kevin Cheng989cb052021-04-28 16:29:44 -0700854 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700855 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700856 return nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -0700857 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700858};
859}; // namespace TosaReference
860
861#endif