
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TOSA_REFERENCE_TENSOR_H
#define TOSA_REFERENCE_TENSOR_H

#include "array_proxy.h"
#include "dtype.h"
#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <cassert>
#include <list>
#include <vector>

using namespace tosa;

namespace TosaReference
{
class GraphNode;

class Tensor
{
public:
    Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);

    virtual ~Tensor();

    int setIsSubgraphInput();
    int setIsSubgraphOutput();
    int setIsParentGraphOutput();

    int getIsParentGraphOutput() const
    {
        return isParentGraphOutput;
    }

    int getIsSubgraphInput() const
    {
        return isSubgraphInput;
    }

    int getIsSubgraphOutput() const
    {
        return isSubgraphOutput;
    }

    int setProducer(GraphNode* node);
    int addConsumer(GraphNode* node);

    int setIsValid()
    {
        isValid = 1;
        return 0;
    }

    int clearIsValid()
    {
        isValid = 0;
        return 0;
    }

    int getIsValid() const
    {
        return isValid;
    }

    GraphNode* getProducer()
    {
        return producer;
    }

    std::vector<GraphNode*>& getConsumers()
    {
        return consumers;
    }

    const std::string& getName() const
    {
        return tensorName;
    }

    const std::vector<int>& getShape() const
    {
        return shape;
    }

    std::string getShapeAsString() const
    {
        std::string shape_str("[");
        for (auto& dim : shape)
        {
            shape_str += (std::to_string(dim) + ", ");
        }
        shape_str.append("]");
        return shape_str;
    }

    const uint32_t getElementCount() const
    {
        uint32_t elements = 1;
        for (size_t i = 0; i < shape.size(); i++)
            elements *= shape[i];

        return elements;
    }

    // Comparison of rank and type with other tensors.
    // Each check returns 0 on a match and 1 on a mismatch.
    const int matchRank(const Tensor& ref) const
    {
        return (ref.shape.size() == shape.size()) ? 0 : 1;
    }

    const int matchType(const Tensor& ref) const
    {
        return (ref.tensorDtype == tensorDtype) ? 0 : 1;
    }

    const int matchRankType(const Tensor& ref) const
    {
        return (matchType(ref) || matchRank(ref));
    }

    const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRankType(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, the order of *this and ref matters:
                    // *this should be the source tensor and ref the target
                    // tensor (in most cases, ref is the output tensor).
                    // Where the dimensions differ, this->shape must have size 1.
                    (broadcastOk && (shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }

    const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, the order of *this and ref matters:
                    // *this should be the source tensor and ref the target
                    // tensor (in most cases, ref is the output tensor).
                    // Where the dimensions differ, this->shape must have
                    // size 1 (see the example below).
                    (broadcastOk && (shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }

    // Sometimes we might want to match several semi-compatible types,
    // so just check rank and size here.
    const int matchRankSize(const Tensor& ref) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
                return 1;
        }

        return 0;
    }

    // Unary checks to make sure the rank meets the minimum,
    // or falls within [minRank, maxRank].
    const int checkRequiredRank(const int minRank) const
    {
        return (shape.size() >= (size_t)minRank) ? 0 : 1;
    }

    const int checkRequiredRank(const int minRank, const int maxRank) const
    {
        return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
    }

    const int getRank() const
    {
        return shape.size();
    }

    const TOSA_REF_TYPE getDtype() const
    {
        return tensorDtype;
    }

    const DType getSerializationDtype() const
    {
        return serializationDtype;
    }

    virtual int dumpTensor(FILE* out) const = 0;
    virtual int dumpTensorParams(FILE* out) const;
    virtual int dumpTensorParams(std::ostream& out) const;

    virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
    virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
    virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
    virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;

    virtual int readFromNpyFile(const char* filename);
    virtual int writeToNpyFile(const char* filename) const;
    virtual int copyValueFrom(Tensor* tensor) = 0;

    virtual int readfromVector(const ArrayProxy<double> vals);
    virtual int readfromVector(const ArrayProxy<float> vals);
    virtual int readfromVector(const ArrayProxy<half_float::half> vals);
    virtual int readfromVector(const ArrayProxy<int32_t> vals);
    virtual int readfromVector(const ArrayProxy<int64_t> vals);
    virtual int readfromVector(const ArrayProxy<unsigned char> vals);

    virtual int writeToVector(ArrayProxy<double> vals);
    virtual int writeToVector(ArrayProxy<float> vals);
    virtual int writeToVector(ArrayProxy<half_float::half> vals);
    virtual int writeToVector(ArrayProxy<int32_t> vals);
    virtual int writeToVector(ArrayProxy<int64_t> vals);
    virtual int writeToVector(ArrayProxy<unsigned char> vals);

    const char* bool_to_str(bool in) const
    {
        static const char* true_str  = "true";
        static const char* false_str = "false";
        return in ? true_str : false_str;
    }

    virtual int allocate() = 0;
    virtual int deallocate() = 0;
    virtual bool is_allocated() = 0;

protected:
    const std::string tensorName;
    const DType serializationDtype;
    const std::vector<int> shape;
    const TOSA_REF_TYPE tensorDtype;
    int isValid;
    int isSubgraphInput;
    int isSubgraphOutput;
    bool isAllocated;

    bool isParentGraphOutput;

    GraphNode* producer;
    std::vector<GraphNode*> consumers;

    // Note: the Eigen::Tensor is not declared in Tensor.
    // Instead, the TensorTemplate class keeps the templated tensor
    // declaration so that the graph manipulation tools are isolated
    // from the templated tensor type.
    //
    // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
    // so that they can operate on the right types.
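    //
    // For example, an op could recover its typed Eigen tensor like this
    // (an illustrative sketch only, assuming the Tensor2<float> alias from
    // ops/template_types.h that the declarations below rely on):
    //
    //   TosaReference::Tensor* base = ...;    // a rank-2 FP32 tensor from the graph
    //   auto* typed = dynamic_cast<TosaReference::Tensor2<float>*>(base);
    //   if (typed)
    //   {
    //       Eigen::Tensor<float, 2>& data = typed->getTensor();
    //   }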
};

template <class T>
class TensorTemplate : public Tensor
{
public:
    TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
        : Tensor(tensorName_, dtype_, shape_)
    {
        tensor = nullptr;
    }

    virtual ~TensorTemplate()
    {
        deallocate();
    }

    virtual int allocate()
    {
        tensor = new T();
        if (tensor)
            return 0;
        else
            return 1;
    }

    virtual int deallocate()
    {
        if (tensor)
        {
            delete tensor;
        }
        tensor = nullptr;
        return 0;
    }

    virtual bool is_allocated()
    {
        if (tensor)
        {
            return true;
        }
        return false;
    }

    T& getTensor()
    {
        return *tensor;
    }

    virtual int dumpTensor(FILE* out) const;

    virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
    virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals);

    virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
    virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;

    virtual int copyValueFrom(Tensor* tensor);

protected:
    T* tensor;
};

// allocate() template specializations to allocate the different tensor sizes.
// Declare them here so the compiler knows about them before the factory uses
// them; they are defined in the .cc file.
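//
// A plausible shape for one of these definitions in the .cc file (an
// illustrative sketch only; the real definitions may differ):
//
//   template <>
//   int Tensor2<float>::allocate()
//   {
//       assert(shape.size() == 2);
//       tensor = new Eigen::Tensor<float, 2>(shape[0], shape[1]);
//       return tensor ? 0 : 1;
//   }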
template <>
int Tensor0<float>::allocate();
template <>
int Tensor1<float>::allocate();
template <>
int Tensor2<float>::allocate();
template <>
int Tensor3<float>::allocate();
template <>
int Tensor4<float>::allocate();
template <>
int Tensor5<float>::allocate();
template <>
int Tensor6<float>::allocate();

template <>
int Tensor0<double>::allocate();
template <>
int Tensor1<double>::allocate();
template <>
int Tensor2<double>::allocate();
template <>
int Tensor3<double>::allocate();
template <>
int Tensor4<double>::allocate();
template <>
int Tensor5<double>::allocate();
template <>
int Tensor6<double>::allocate();

template <>
int Tensor0<int32_t>::allocate();
template <>
int Tensor1<int32_t>::allocate();
template <>
int Tensor2<int32_t>::allocate();
template <>
int Tensor3<int32_t>::allocate();
template <>
int Tensor4<int32_t>::allocate();
template <>
int Tensor5<int32_t>::allocate();
template <>
int Tensor6<int32_t>::allocate();

template <>
int Tensor0<int64_t>::allocate();
template <>
int Tensor1<int64_t>::allocate();
template <>
int Tensor2<int64_t>::allocate();
template <>
int Tensor3<int64_t>::allocate();
template <>
int Tensor4<int64_t>::allocate();
template <>
int Tensor5<int64_t>::allocate();
template <>
int Tensor6<int64_t>::allocate();

template <>
int Tensor0<bool>::allocate();
template <>
int Tensor1<bool>::allocate();
template <>
int Tensor2<bool>::allocate();
template <>
int Tensor3<bool>::allocate();
template <>
int Tensor4<bool>::allocate();
template <>
int Tensor5<bool>::allocate();
template <>
int Tensor6<bool>::allocate();

template <>
int Tensor0<float>::copyValueFrom(Tensor* src);
template <>
int Tensor1<float>::copyValueFrom(Tensor* src);
template <>
int Tensor2<float>::copyValueFrom(Tensor* src);
template <>
int Tensor3<float>::copyValueFrom(Tensor* src);
template <>
int Tensor4<float>::copyValueFrom(Tensor* src);
template <>
int Tensor5<float>::copyValueFrom(Tensor* src);
template <>
int Tensor6<float>::copyValueFrom(Tensor* src);

template <>
int Tensor0<double>::copyValueFrom(Tensor* src);
template <>
int Tensor1<double>::copyValueFrom(Tensor* src);
template <>
int Tensor2<double>::copyValueFrom(Tensor* src);
template <>
int Tensor3<double>::copyValueFrom(Tensor* src);
template <>
int Tensor4<double>::copyValueFrom(Tensor* src);
template <>
int Tensor5<double>::copyValueFrom(Tensor* src);
template <>
int Tensor6<double>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int32_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int64_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor1<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor2<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor3<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor4<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor5<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor6<bool>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);

template <>
int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;

template <>
int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);

template <>
int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;

template <>
int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);

template <>
int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;

template <>
int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);

template <>
int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;

template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);

template <>
int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;

template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
int Tensor1<float>::dumpTensor(FILE* out) const;
template <>
int Tensor2<float>::dumpTensor(FILE* out) const;
template <>
int Tensor3<float>::dumpTensor(FILE* out) const;
template <>
int Tensor4<float>::dumpTensor(FILE* out) const;
template <>
int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
int Tensor0<double>::dumpTensor(FILE* out) const;
template <>
int Tensor1<double>::dumpTensor(FILE* out) const;
template <>
int Tensor2<double>::dumpTensor(FILE* out) const;
template <>
int Tensor3<double>::dumpTensor(FILE* out) const;
template <>
int Tensor4<double>::dumpTensor(FILE* out) const;
template <>
int Tensor5<double>::dumpTensor(FILE* out) const;
template <>
int Tensor6<double>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor1<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor2<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor3<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor4<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor5<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor6<bool>::dumpTensor(FILE* out) const;

class TensorFactory
{
public:
    static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
    {
        TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
        switch (tensorDtype_)
        {
            case TOSA_REF_TYPE_FP32:
            case TOSA_REF_TYPE_FP16:
            case TOSA_REF_TYPE_BF16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<float>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<float>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<float>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<float>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<float>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<float>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<float>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_INT32:
            case TOSA_REF_TYPE_UINT8:
            case TOSA_REF_TYPE_INT4:
            case TOSA_REF_TYPE_INT8:
            case TOSA_REF_TYPE_INT16:
            case TOSA_REF_TYPE_UINT16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_INT48:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_BOOL:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<bool>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<bool>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<bool>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<bool>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<bool>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<bool>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<bool>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_FP64:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<double>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<double>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<double>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<double>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<double>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<double>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<double>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_UNKNOWN:
                assert(0);    // tensorDtype_ is uninitialized
                break;
        }
        return nullptr;
    }
};
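
// Example of typical factory usage (an illustrative sketch only, assuming the
// DType_FP32 enumerator from the serialization library; not part of this API):
//
//   std::vector<int> shape = { 2, 3 };
//   TosaReference::Tensor* t =
//       TosaReference::TensorFactory::newTensor("example", tosa::DType_FP32, shape, shape.size());
//   if (t && !t->allocate())
//   {
//       std::vector<float> vals(t->getElementCount(), 0.0f);
//       t->setTensorValueFloat(vals.size(), vals.data());
//   }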
};    // namespace TosaReference

#endif