blob: 21cf148a7de3ac6377c8981e5a7569592ef7956d [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
Grant Watson64285a12022-11-16 15:32:39 +000019#include "array_proxy.h"
Tai Lya4d748b2023-03-28 22:06:56 +000020#include "dtype.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "model_common.h"
22#include "ops/template_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070023#include "tosa_serialization_handler.h"
24#include <Eigen/CXX11/Tensor>
25#include <list>
26#include <vector>
27
28using namespace tosa;
29
30namespace TosaReference
31{
32class GraphNode;
33
34class Tensor
35{
36public:
Tai Lya4d748b2023-03-28 22:06:56 +000037 Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070038
39 virtual ~Tensor();
40
41 int setIsSubgraphInput();
42 int setIsSubgraphOutput();
Jerry Ge9e94af82022-10-27 09:57:00 -070043 int setIsParentGraphOutput();
44
45 int getIsParentGraphOutput() const {
46 return isParentGraphOutput;
47 }
Eric Kunzee5e26762020-10-13 16:11:07 -070048
49 int getIsSubgraphInput() const
50 {
51 return isSubgraphInput;
52 }
53
54 int getIsSubgraphOutput() const
55 {
56 return isSubgraphOutput;
57 }
58
59 int setProducer(GraphNode* node);
60 int addConsumer(GraphNode* node);
61
62 int setIsValid()
63 {
64 isValid = 1;
65 return 0;
66 }
67
68 int clearIsValid()
69 {
70 isValid = 0;
71 return 0;
72 }
73
74 int getIsValid() const
75 {
76 return isValid;
77 }
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 GraphNode* getProducer()
80 {
81 return producer;
82 }
83
84 std::vector<GraphNode*>& getConsumers()
85 {
86 return consumers;
87 }
88
89 const std::string& getName() const
90 {
91 return tensorName;
92 }
93
94 const std::vector<int>& getShape() const
95 {
96 return shape;
97 }
98
Jerry Ge264f7fa2023-04-21 22:49:57 +000099 void setDimSize(size_t dim, uint32_t new_size)
100 {
101 this->shape[dim] = new_size;
102 return;
103 }
104
Eric Kunzee5e26762020-10-13 16:11:07 -0700105 std::string getShapeAsString() const
106 {
107 std::string shape_str("[");
108 for (auto& dim : shape)
109 {
110 shape_str += (std::to_string(dim) + ", ");
111 }
112 shape_str.append("]");
113 return shape_str;
114 }
115
Eric Kunzee5e26762020-10-13 16:11:07 -0700116 const uint32_t getElementCount() const
117 {
118 uint32_t elements = 1;
119 for (size_t i = 0; i < shape.size(); i++)
120 elements *= shape[i];
121
122 return elements;
123 }
124
125 // Comparison of rank and type with other tensors
126 const int matchRank(const Tensor& ref) const
127 {
128 return (ref.shape.size() == shape.size()) ? 0 : 1;
129 }
130
131 const int matchType(const Tensor& ref) const
132 {
133 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
134 }
135
136 const int matchRankType(const Tensor& ref) const
137 {
138 return (matchType(ref) || matchRank(ref));
139 }
140
141 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
142 {
143 if (matchRankType(ref))
144 return 1;
145
146 for (size_t i = 0; i < shape.size(); i++)
147 {
148 if (shape[i] != ref.shape[i])
149 {
150 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000151 // For broadcasts, the order of *this and ref matters.
152 // *this should be the source tensor.
153 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
154 // this->shape must have size 1 if they don't match
155 (broadcastOk && (shape[i] != 1)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 {
157 return 1;
158 }
159 }
160 }
161
162 return 0;
163 }
164
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800165 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
166 {
167 if (matchRank(ref))
168 return 1;
169
170 for (size_t i = 0; i < shape.size(); i++)
171 {
172 if (shape[i] != ref.shape[i])
173 {
174 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000175 // For broadcasts, the order of *this and ref matters.
176 // *this should be the source tensor.
177 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
178 // this->shape must have size 1 if they don't match
179 (broadcastOk && (shape[i] != 1)))
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800180 {
181 return 1;
182 }
183 }
184 }
185
186 return 0;
187 }
188
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 // Sometimes we might want to match several semi-compatible types,
190 // so just check rank and size here
191 const int matchRankSize(const Tensor& ref) const
192 {
193 if (matchRank(ref))
194 return 1;
195
196 for (size_t i = 0; i < shape.size(); i++)
197 {
198 if (shape[i] != ref.shape[i])
199 return 1;
200 }
201
202 return 0;
203 }
204
205 // Unary check to make sure rank matches
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000206 const int checkRequiredRank(const int minRank) const
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000208 return (shape.size() >= (size_t)minRank) ? 0 : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700209 }
210
211 const int checkRequiredRank(const int minRank, const int maxRank) const
212 {
213 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
214 }
215
216 const int getRank() const
217 {
218 return shape.size();
219 }
220
Tai Lya4d748b2023-03-28 22:06:56 +0000221 const TOSA_REF_TYPE getDtype() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 {
223 return tensorDtype;
224 }
225
Tai Lya4d748b2023-03-28 22:06:56 +0000226 const DType getSerializationDtype() const
227 {
228 return serializationDtype;
229 }
230
Eric Kunzee5e26762020-10-13 16:11:07 -0700231 virtual int dumpTensor(FILE* out) const = 0;
232 virtual int dumpTensorParams(FILE* out) const;
233 virtual int dumpTensorParams(std::ostream& out) const;
234
Tai Lya4d748b2023-03-28 22:06:56 +0000235 virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700236 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
237 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
238 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
239 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
Tai Lya4d748b2023-03-28 22:06:56 +0000240 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
242 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
243 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
244 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
245
246 virtual int readFromNpyFile(const char* filename);
247 virtual int writeToNpyFile(const char* filename) const;
248 virtual int copyValueFrom(Tensor* tensor) = 0;
249
Tai Lya4d748b2023-03-28 22:06:56 +0000250 virtual int readfromVector(const ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000251 virtual int readfromVector(const ArrayProxy<float> vals);
252 virtual int readfromVector(const ArrayProxy<half_float::half> vals);
253 virtual int readfromVector(const ArrayProxy<int32_t> vals);
254 virtual int readfromVector(const ArrayProxy<int64_t> vals);
255 virtual int readfromVector(const ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100256
Tai Lya4d748b2023-03-28 22:06:56 +0000257 virtual int writeToVector(ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000258 virtual int writeToVector(ArrayProxy<float> vals);
259 virtual int writeToVector(ArrayProxy<half_float::half> vals);
260 virtual int writeToVector(ArrayProxy<int32_t> vals);
261 virtual int writeToVector(ArrayProxy<int64_t> vals);
262 virtual int writeToVector(ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100263
Eric Kunzee5e26762020-10-13 16:11:07 -0700264 const char* bool_to_str(bool in) const
265 {
266 static const char* true_str = "true";
267 static const char* false_str = "false";
268 return in ? true_str : false_str;
269 }
270
271 virtual int allocate() = 0;
272 virtual int deallocate() = 0;
273 virtual bool is_allocated() = 0;
274
275protected:
Tai Lya4d748b2023-03-28 22:06:56 +0000276 const std::string tensorName;
277 const DType serializationDtype;
Jerry Ge264f7fa2023-04-21 22:49:57 +0000278 std::vector<int> shape;
Tai Lya4d748b2023-03-28 22:06:56 +0000279 const TOSA_REF_TYPE tensorDtype;
Eric Kunzee5e26762020-10-13 16:11:07 -0700280 int isValid;
Eric Kunzee5e26762020-10-13 16:11:07 -0700281 int isSubgraphInput;
282 int isSubgraphOutput;
283 bool isAllocated;
284
Jerry Ge9e94af82022-10-27 09:57:00 -0700285 bool isParentGraphOutput;
286
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 GraphNode* producer;
288 std::vector<GraphNode*> consumers;
289
290 // Note: the Eigen::Tensor is not declared in Tensor
291 // Instead, the TensorTemplate class keeps the templated tensor
292 // declaration so that the graph manipulation tools are isolated
293 // from the templated tensor type.
294 //
295 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
296 // so that they can operate on the right types.
297};
298
299template <class T>
300class TensorTemplate : public Tensor
301{
302public:
Tai Lya4d748b2023-03-28 22:06:56 +0000303 TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
304 : Tensor(tensorName_, dtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700305 {
306 tensor = nullptr;
307 }
308
309 virtual ~TensorTemplate()
310 {
311 deallocate();
312 }
313
314 virtual int allocate()
315 {
316 tensor = new T();
317 if (tensor)
318 return 0;
319 else
320 return 1;
321 }
322
323 virtual int deallocate()
324 {
325 if (tensor)
326 {
327 delete tensor;
328 }
329 tensor = nullptr;
330 return 0;
331 }
332
333 virtual bool is_allocated()
334 {
335 if (tensor)
336 {
337 return true;
338 }
339 return false;
340 }
341
342 T& getTensor()
343 {
344 return *tensor;
345 }
346
347 virtual int dumpTensor(FILE* out) const;
348
Tai Lya4d748b2023-03-28 22:06:56 +0000349 virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700350 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
351 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
352 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
353 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
Tai Lya4d748b2023-03-28 22:06:56 +0000354
355 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700356 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
357 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
358 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
359 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
360
361 virtual int copyValueFrom(Tensor* tensor);
362
363protected:
364 T* tensor;
365};
366
367// allocate() template specializations to allocate the different tensor sizes
368// Let the compiler know here before the factory uses them, but define them in the .cc file.
369template <>
370int Tensor0<float>::allocate();
371template <>
372int Tensor1<float>::allocate();
373template <>
374int Tensor2<float>::allocate();
375template <>
376int Tensor3<float>::allocate();
377template <>
378int Tensor4<float>::allocate();
379template <>
380int Tensor5<float>::allocate();
381template <>
382int Tensor6<float>::allocate();
383
384template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000385int Tensor0<double>::allocate();
386template <>
387int Tensor1<double>::allocate();
388template <>
389int Tensor2<double>::allocate();
390template <>
391int Tensor3<double>::allocate();
392template <>
393int Tensor4<double>::allocate();
394template <>
395int Tensor5<double>::allocate();
396template <>
397int Tensor6<double>::allocate();
398
399template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700400int Tensor0<int32_t>::allocate();
401template <>
402int Tensor1<int32_t>::allocate();
403template <>
404int Tensor2<int32_t>::allocate();
405template <>
406int Tensor3<int32_t>::allocate();
407template <>
408int Tensor4<int32_t>::allocate();
409template <>
410int Tensor5<int32_t>::allocate();
411template <>
412int Tensor6<int32_t>::allocate();
413
414template <>
415int Tensor0<int64_t>::allocate();
416template <>
417int Tensor1<int64_t>::allocate();
418template <>
419int Tensor2<int64_t>::allocate();
420template <>
421int Tensor3<int64_t>::allocate();
422template <>
423int Tensor4<int64_t>::allocate();
424template <>
425int Tensor5<int64_t>::allocate();
426template <>
427int Tensor6<int64_t>::allocate();
428
429template <>
430int Tensor0<bool>::allocate();
431template <>
432int Tensor1<bool>::allocate();
433template <>
434int Tensor2<bool>::allocate();
435template <>
436int Tensor3<bool>::allocate();
437template <>
438int Tensor4<bool>::allocate();
439template <>
440int Tensor5<bool>::allocate();
441template <>
442int Tensor6<bool>::allocate();
443
444template <>
445int Tensor0<float>::copyValueFrom(Tensor* src);
446template <>
447int Tensor1<float>::copyValueFrom(Tensor* src);
448template <>
449int Tensor2<float>::copyValueFrom(Tensor* src);
450template <>
451int Tensor3<float>::copyValueFrom(Tensor* src);
452template <>
453int Tensor4<float>::copyValueFrom(Tensor* src);
454template <>
455int Tensor5<float>::copyValueFrom(Tensor* src);
456template <>
457int Tensor6<float>::copyValueFrom(Tensor* src);
458
459template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000460int Tensor0<double>::copyValueFrom(Tensor* src);
461template <>
462int Tensor1<double>::copyValueFrom(Tensor* src);
463template <>
464int Tensor2<double>::copyValueFrom(Tensor* src);
465template <>
466int Tensor3<double>::copyValueFrom(Tensor* src);
467template <>
468int Tensor4<double>::copyValueFrom(Tensor* src);
469template <>
470int Tensor5<double>::copyValueFrom(Tensor* src);
471template <>
472int Tensor6<double>::copyValueFrom(Tensor* src);
473
474template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700475int Tensor0<int32_t>::copyValueFrom(Tensor* src);
476template <>
477int Tensor1<int32_t>::copyValueFrom(Tensor* src);
478template <>
479int Tensor2<int32_t>::copyValueFrom(Tensor* src);
480template <>
481int Tensor3<int32_t>::copyValueFrom(Tensor* src);
482template <>
483int Tensor4<int32_t>::copyValueFrom(Tensor* src);
484template <>
485int Tensor5<int32_t>::copyValueFrom(Tensor* src);
486template <>
487int Tensor6<int32_t>::copyValueFrom(Tensor* src);
488
489template <>
490int Tensor0<int64_t>::copyValueFrom(Tensor* src);
491template <>
492int Tensor1<int64_t>::copyValueFrom(Tensor* src);
493template <>
494int Tensor2<int64_t>::copyValueFrom(Tensor* src);
495template <>
496int Tensor3<int64_t>::copyValueFrom(Tensor* src);
497template <>
498int Tensor4<int64_t>::copyValueFrom(Tensor* src);
499template <>
500int Tensor5<int64_t>::copyValueFrom(Tensor* src);
501template <>
502int Tensor6<int64_t>::copyValueFrom(Tensor* src);
503
504template <>
505int Tensor0<bool>::copyValueFrom(Tensor* src);
506template <>
507int Tensor1<bool>::copyValueFrom(Tensor* src);
508template <>
509int Tensor2<bool>::copyValueFrom(Tensor* src);
510template <>
511int Tensor3<bool>::copyValueFrom(Tensor* src);
512template <>
513int Tensor4<bool>::copyValueFrom(Tensor* src);
514template <>
515int Tensor5<bool>::copyValueFrom(Tensor* src);
516template <>
517int Tensor6<bool>::copyValueFrom(Tensor* src);
518
519template <>
520int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
521template <>
522int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
523template <>
524int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
525template <>
526int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
527template <>
528int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
529template <>
530int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
531template <>
532int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
533
534template <>
535int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
536template <>
537int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
538template <>
539int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
540template <>
541int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
542template <>
543int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
544template <>
545int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
546template <>
547int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
548
549template <>
550int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
551template <>
552int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
553template <>
554int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
555template <>
556int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
557template <>
558int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
559template <>
560int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
561template <>
562int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
563
564template <>
565int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
566template <>
567int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
568template <>
569int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
570template <>
571int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
572template <>
573int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
574template <>
575int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
576template <>
577int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
578
579template <>
580int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
581template <>
582int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
583template <>
584int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
585template <>
586int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
587template <>
588int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
589template <>
590int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
591template <>
592int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
593
594template <>
595int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
596template <>
597int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
598template <>
599int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
600template <>
601int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
602template <>
603int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
604template <>
605int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
606template <>
607int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
608
609template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000610int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
611template <>
612int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
613template <>
614int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
615template <>
616int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
617template <>
618int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
619template <>
620int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
621template <>
622int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
623
624template <>
625int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
626template <>
627int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
628template <>
629int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
630template <>
631int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
632template <>
633int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
634template <>
635int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
636template <>
637int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
638
639template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700640int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
641template <>
642int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
643template <>
644int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
645template <>
646int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
647template <>
648int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
649template <>
650int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
651template <>
652int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
653
654template <>
655int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
656template <>
657int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
658template <>
659int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
660template <>
661int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
662template <>
663int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
664template <>
665int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
666template <>
667int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
668
Eric Kunzee5e26762020-10-13 16:11:07 -0700669template <>
670int Tensor0<float>::dumpTensor(FILE* out) const;
671template <>
672int Tensor1<float>::dumpTensor(FILE* out) const;
673template <>
674int Tensor2<float>::dumpTensor(FILE* out) const;
675template <>
676int Tensor3<float>::dumpTensor(FILE* out) const;
677template <>
678int Tensor4<float>::dumpTensor(FILE* out) const;
679template <>
680int Tensor5<float>::dumpTensor(FILE* out) const;
681template <>
682int Tensor6<float>::dumpTensor(FILE* out) const;
683template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000684int Tensor0<double>::dumpTensor(FILE* out) const;
685template <>
686int Tensor1<double>::dumpTensor(FILE* out) const;
687template <>
688int Tensor2<double>::dumpTensor(FILE* out) const;
689template <>
690int Tensor3<double>::dumpTensor(FILE* out) const;
691template <>
692int Tensor4<double>::dumpTensor(FILE* out) const;
693template <>
694int Tensor5<float>::dumpTensor(FILE* out) const;
695template <>
696int Tensor6<double>::dumpTensor(FILE* out) const;
697template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700698int Tensor0<int32_t>::dumpTensor(FILE* out) const;
699template <>
700int Tensor1<int32_t>::dumpTensor(FILE* out) const;
701template <>
702int Tensor2<int32_t>::dumpTensor(FILE* out) const;
703template <>
704int Tensor3<int32_t>::dumpTensor(FILE* out) const;
705template <>
706int Tensor4<int32_t>::dumpTensor(FILE* out) const;
707template <>
708int Tensor5<int32_t>::dumpTensor(FILE* out) const;
709template <>
710int Tensor6<int32_t>::dumpTensor(FILE* out) const;
711template <>
712int Tensor0<int64_t>::dumpTensor(FILE* out) const;
713template <>
714int Tensor1<int64_t>::dumpTensor(FILE* out) const;
715template <>
716int Tensor2<int64_t>::dumpTensor(FILE* out) const;
717template <>
718int Tensor3<int64_t>::dumpTensor(FILE* out) const;
719template <>
720int Tensor4<int64_t>::dumpTensor(FILE* out) const;
721template <>
722int Tensor5<int64_t>::dumpTensor(FILE* out) const;
723template <>
724int Tensor6<int64_t>::dumpTensor(FILE* out) const;
725template <>
726int Tensor0<bool>::dumpTensor(FILE* out) const;
727template <>
728int Tensor1<bool>::dumpTensor(FILE* out) const;
729template <>
730int Tensor2<bool>::dumpTensor(FILE* out) const;
731template <>
732int Tensor3<bool>::dumpTensor(FILE* out) const;
733template <>
734int Tensor4<bool>::dumpTensor(FILE* out) const;
735template <>
736int Tensor5<bool>::dumpTensor(FILE* out) const;
737template <>
738int Tensor6<bool>::dumpTensor(FILE* out) const;
739
740class TensorFactory
741{
742public:
Tai Lya4d748b2023-03-28 22:06:56 +0000743 static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700744 {
Tai Lya4d748b2023-03-28 22:06:56 +0000745 TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700746 switch (tensorDtype_)
747 {
Tai Lya4d748b2023-03-28 22:06:56 +0000748 case TOSA_REF_TYPE_FP32:
749 case TOSA_REF_TYPE_FP16:
750 case TOSA_REF_TYPE_BF16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700751 switch (rank)
752 {
753 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000754 return new Tensor0<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700755 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000756 return new Tensor1<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700757 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000758 return new Tensor2<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700759 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000760 return new Tensor3<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700761 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000762 return new Tensor4<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700763 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000764 return new Tensor5<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700765 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000766 return new Tensor6<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700767 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700768 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000769 case TOSA_REF_TYPE_INT32:
770 case TOSA_REF_TYPE_UINT8:
771 case TOSA_REF_TYPE_INT4:
772 case TOSA_REF_TYPE_INT8:
773 case TOSA_REF_TYPE_INT16:
774 case TOSA_REF_TYPE_UINT16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700775 switch (rank)
776 {
777 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000778 return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700779 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000780 return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700781 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000782 return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700783 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000784 return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700785 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000786 return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700787 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000788 return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700789 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000790 return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700791 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700792 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000793 case TOSA_REF_TYPE_INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700794 switch (rank)
795 {
796 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000797 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700798 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000799 return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700800 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000801 return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700802 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000803 return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700804 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000805 return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000807 return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700808 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000809 return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700810 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700811 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000812 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700813 switch (rank)
814 {
815 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000816 return new Tensor0<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700817 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000818 return new Tensor1<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700819 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000820 return new Tensor2<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700821 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000822 return new Tensor3<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700823 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000824 return new Tensor4<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700825 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000826 return new Tensor5<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700827 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000828 return new Tensor6<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700829 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700830 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000831 case TOSA_REF_TYPE_FP64:
832 switch (rank)
833 {
834 case 0:
835 return new Tensor0<double>(tensorName_, dtype_, shape_);
836 case 1:
837 return new Tensor1<double>(tensorName_, dtype_, shape_);
838 case 2:
839 return new Tensor2<double>(tensorName_, dtype_, shape_);
840 case 3:
841 return new Tensor3<double>(tensorName_, dtype_, shape_);
842 case 4:
843 return new Tensor4<double>(tensorName_, dtype_, shape_);
844 case 5:
845 return new Tensor5<double>(tensorName_, dtype_, shape_);
846 case 6:
847 return new Tensor6<double>(tensorName_, dtype_, shape_);
848 }
849 break;
850 case TOSA_REF_TYPE_UNKNOWN:
851 assert(0); // tensorDtype_ is uninitialized
Kevin Cheng989cb052021-04-28 16:29:44 -0700852 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700853 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700854 return nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -0700855 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700856};
857}; // namespace TosaReference
858
859#endif