blob: d49e0047d13c8d4c7619d86e61555614f2efd229 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
#include "array_proxy.h"
#include "dtype.h"
#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <list>
#include <new>
#include <string>
#include <vector>
27
28using namespace tosa;
29
30namespace TosaReference
31{
32class GraphNode;
33
34class Tensor
35{
36public:
Tai Lya4d748b2023-03-28 22:06:56 +000037 Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070038
39 virtual ~Tensor();
40
41 int setIsSubgraphInput();
42 int setIsSubgraphOutput();
Jerry Ge9e94af82022-10-27 09:57:00 -070043 int setIsParentGraphOutput();
44
45 int getIsParentGraphOutput() const {
46 return isParentGraphOutput;
47 }
Eric Kunzee5e26762020-10-13 16:11:07 -070048
49 int getIsSubgraphInput() const
50 {
51 return isSubgraphInput;
52 }
53
54 int getIsSubgraphOutput() const
55 {
56 return isSubgraphOutput;
57 }
58
59 int setProducer(GraphNode* node);
60 int addConsumer(GraphNode* node);
61
62 int setIsValid()
63 {
64 isValid = 1;
65 return 0;
66 }
67
68 int clearIsValid()
69 {
70 isValid = 0;
71 return 0;
72 }
73
74 int getIsValid() const
75 {
76 return isValid;
77 }
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 GraphNode* getProducer()
80 {
81 return producer;
82 }
83
84 std::vector<GraphNode*>& getConsumers()
85 {
86 return consumers;
87 }
88
89 const std::string& getName() const
90 {
91 return tensorName;
92 }
93
94 const std::vector<int>& getShape() const
95 {
96 return shape;
97 }
98
Jerry Ge264f7fa2023-04-21 22:49:57 +000099 void setDimSize(size_t dim, uint32_t new_size)
100 {
101 this->shape[dim] = new_size;
102 return;
103 }
104
Eric Kunzee5e26762020-10-13 16:11:07 -0700105 std::string getShapeAsString() const
106 {
107 std::string shape_str("[");
108 for (auto& dim : shape)
109 {
110 shape_str += (std::to_string(dim) + ", ");
111 }
112 shape_str.append("]");
113 return shape_str;
114 }
115
Eric Kunzee5e26762020-10-13 16:11:07 -0700116 const uint32_t getElementCount() const
117 {
118 uint32_t elements = 1;
119 for (size_t i = 0; i < shape.size(); i++)
120 elements *= shape[i];
121
122 return elements;
123 }
124
125 // Comparison of rank and type with other tensors
126 const int matchRank(const Tensor& ref) const
127 {
128 return (ref.shape.size() == shape.size()) ? 0 : 1;
129 }
130
131 const int matchType(const Tensor& ref) const
132 {
133 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
134 }
135
136 const int matchRankType(const Tensor& ref) const
137 {
138 return (matchType(ref) || matchRank(ref));
139 }
140
141 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
142 {
143 if (matchRankType(ref))
144 return 1;
145
146 for (size_t i = 0; i < shape.size(); i++)
147 {
148 if (shape[i] != ref.shape[i])
149 {
150 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000151 // For broadcasts, the order of *this and ref matters.
152 // *this should be the source tensor.
153 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
154 // this->shape must have size 1 if they don't match
155 (broadcastOk && (shape[i] != 1)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 {
157 return 1;
158 }
159 }
160 }
161
162 return 0;
163 }
164
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800165 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
166 {
167 if (matchRank(ref))
168 return 1;
169
170 for (size_t i = 0; i < shape.size(); i++)
171 {
172 if (shape[i] != ref.shape[i])
173 {
174 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000175 // For broadcasts, the order of *this and ref matters.
176 // *this should be the source tensor.
177 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
178 // this->shape must have size 1 if they don't match
179 (broadcastOk && (shape[i] != 1)))
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800180 {
181 return 1;
182 }
183 }
184 }
185
186 return 0;
187 }
188
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 // Sometimes we might want to match several semi-compatible types,
190 // so just check rank and size here
191 const int matchRankSize(const Tensor& ref) const
192 {
193 if (matchRank(ref))
194 return 1;
195
196 for (size_t i = 0; i < shape.size(); i++)
197 {
198 if (shape[i] != ref.shape[i])
199 return 1;
200 }
201
202 return 0;
203 }
204
205 // Unary check to make sure rank matches
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000206 const int checkRequiredRank(const int minRank) const
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000208 return (shape.size() >= (size_t)minRank) ? 0 : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700209 }
210
211 const int checkRequiredRank(const int minRank, const int maxRank) const
212 {
213 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
214 }
215
216 const int getRank() const
217 {
218 return shape.size();
219 }
220
Tai Lya4d748b2023-03-28 22:06:56 +0000221 const TOSA_REF_TYPE getDtype() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 {
223 return tensorDtype;
224 }
225
Tai Lya4d748b2023-03-28 22:06:56 +0000226 const DType getSerializationDtype() const
227 {
228 return serializationDtype;
229 }
230
Eric Kunzee5e26762020-10-13 16:11:07 -0700231 virtual int dumpTensor(FILE* out) const = 0;
232 virtual int dumpTensorParams(FILE* out) const;
233 virtual int dumpTensorParams(std::ostream& out) const;
234
Tai Lya4d748b2023-03-28 22:06:56 +0000235 virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700236 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
237 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
238 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
239 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
Tai Lya4d748b2023-03-28 22:06:56 +0000240 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
242 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
243 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
244 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
245
246 virtual int readFromNpyFile(const char* filename);
247 virtual int writeToNpyFile(const char* filename) const;
248 virtual int copyValueFrom(Tensor* tensor) = 0;
249
Tai Lya4d748b2023-03-28 22:06:56 +0000250 virtual int readfromVector(const ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000251 virtual int readfromVector(const ArrayProxy<float> vals);
252 virtual int readfromVector(const ArrayProxy<half_float::half> vals);
253 virtual int readfromVector(const ArrayProxy<int32_t> vals);
254 virtual int readfromVector(const ArrayProxy<int64_t> vals);
255 virtual int readfromVector(const ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100256
Tai Lya4d748b2023-03-28 22:06:56 +0000257 virtual int writeToVector(ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000258 virtual int writeToVector(ArrayProxy<float> vals);
259 virtual int writeToVector(ArrayProxy<half_float::half> vals);
260 virtual int writeToVector(ArrayProxy<int32_t> vals);
261 virtual int writeToVector(ArrayProxy<int64_t> vals);
262 virtual int writeToVector(ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100263
Eric Kunzee5e26762020-10-13 16:11:07 -0700264 const char* bool_to_str(bool in) const
265 {
266 static const char* true_str = "true";
267 static const char* false_str = "false";
268 return in ? true_str : false_str;
269 }
270
271 virtual int allocate() = 0;
272 virtual int deallocate() = 0;
273 virtual bool is_allocated() = 0;
274
275protected:
Tai Lya4d748b2023-03-28 22:06:56 +0000276 const std::string tensorName;
277 const DType serializationDtype;
Jerry Ge264f7fa2023-04-21 22:49:57 +0000278 std::vector<int> shape;
Tai Lya4d748b2023-03-28 22:06:56 +0000279 const TOSA_REF_TYPE tensorDtype;
Eric Kunzee5e26762020-10-13 16:11:07 -0700280 int isValid;
Eric Kunzee5e26762020-10-13 16:11:07 -0700281 int isSubgraphInput;
282 int isSubgraphOutput;
283 bool isAllocated;
284
Jerry Ge9e94af82022-10-27 09:57:00 -0700285 bool isParentGraphOutput;
286
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 GraphNode* producer;
288 std::vector<GraphNode*> consumers;
289
290 // Note: the Eigen::Tensor is not declared in Tensor
291 // Instead, the TensorTemplate class keeps the templated tensor
292 // declaration so that the graph manipulation tools are isolated
293 // from the templated tensor type.
294 //
295 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
296 // so that they can operate on the right types.
297};
298
299template <class T>
300class TensorTemplate : public Tensor
301{
302public:
Tai Lya4d748b2023-03-28 22:06:56 +0000303 TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
304 : Tensor(tensorName_, dtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700305 {
306 tensor = nullptr;
307 }
308
309 virtual ~TensorTemplate()
310 {
311 deallocate();
312 }
313
314 virtual int allocate()
315 {
316 tensor = new T();
317 if (tensor)
318 return 0;
319 else
320 return 1;
321 }
322
323 virtual int deallocate()
324 {
325 if (tensor)
326 {
Eric Kunze9a367552023-07-11 13:27:36 -0700327 DEBUG_INFO(GT, "Deallocating tensor %s", tensorName.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700328 delete tensor;
329 }
330 tensor = nullptr;
331 return 0;
332 }
333
334 virtual bool is_allocated()
335 {
336 if (tensor)
337 {
338 return true;
339 }
340 return false;
341 }
342
343 T& getTensor()
344 {
345 return *tensor;
346 }
347
348 virtual int dumpTensor(FILE* out) const;
349
Tai Lya4d748b2023-03-28 22:06:56 +0000350 virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700351 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
352 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
353 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
354 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
Tai Lya4d748b2023-03-28 22:06:56 +0000355
356 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700357 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
358 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
359 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
360 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
361
362 virtual int copyValueFrom(Tensor* tensor);
363
364protected:
365 T* tensor;
366};
367
368// allocate() template specializations to allocate the different tensor sizes
369// Let the compiler know here before the factory uses them, but define them in the .cc file.
370template <>
371int Tensor0<float>::allocate();
372template <>
373int Tensor1<float>::allocate();
374template <>
375int Tensor2<float>::allocate();
376template <>
377int Tensor3<float>::allocate();
378template <>
379int Tensor4<float>::allocate();
380template <>
381int Tensor5<float>::allocate();
382template <>
383int Tensor6<float>::allocate();
384
385template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000386int Tensor0<double>::allocate();
387template <>
388int Tensor1<double>::allocate();
389template <>
390int Tensor2<double>::allocate();
391template <>
392int Tensor3<double>::allocate();
393template <>
394int Tensor4<double>::allocate();
395template <>
396int Tensor5<double>::allocate();
397template <>
398int Tensor6<double>::allocate();
399
400template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700401int Tensor0<int32_t>::allocate();
402template <>
403int Tensor1<int32_t>::allocate();
404template <>
405int Tensor2<int32_t>::allocate();
406template <>
407int Tensor3<int32_t>::allocate();
408template <>
409int Tensor4<int32_t>::allocate();
410template <>
411int Tensor5<int32_t>::allocate();
412template <>
413int Tensor6<int32_t>::allocate();
414
415template <>
416int Tensor0<int64_t>::allocate();
417template <>
418int Tensor1<int64_t>::allocate();
419template <>
420int Tensor2<int64_t>::allocate();
421template <>
422int Tensor3<int64_t>::allocate();
423template <>
424int Tensor4<int64_t>::allocate();
425template <>
426int Tensor5<int64_t>::allocate();
427template <>
428int Tensor6<int64_t>::allocate();
429
430template <>
431int Tensor0<bool>::allocate();
432template <>
433int Tensor1<bool>::allocate();
434template <>
435int Tensor2<bool>::allocate();
436template <>
437int Tensor3<bool>::allocate();
438template <>
439int Tensor4<bool>::allocate();
440template <>
441int Tensor5<bool>::allocate();
442template <>
443int Tensor6<bool>::allocate();
444
445template <>
446int Tensor0<float>::copyValueFrom(Tensor* src);
447template <>
448int Tensor1<float>::copyValueFrom(Tensor* src);
449template <>
450int Tensor2<float>::copyValueFrom(Tensor* src);
451template <>
452int Tensor3<float>::copyValueFrom(Tensor* src);
453template <>
454int Tensor4<float>::copyValueFrom(Tensor* src);
455template <>
456int Tensor5<float>::copyValueFrom(Tensor* src);
457template <>
458int Tensor6<float>::copyValueFrom(Tensor* src);
459
460template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000461int Tensor0<double>::copyValueFrom(Tensor* src);
462template <>
463int Tensor1<double>::copyValueFrom(Tensor* src);
464template <>
465int Tensor2<double>::copyValueFrom(Tensor* src);
466template <>
467int Tensor3<double>::copyValueFrom(Tensor* src);
468template <>
469int Tensor4<double>::copyValueFrom(Tensor* src);
470template <>
471int Tensor5<double>::copyValueFrom(Tensor* src);
472template <>
473int Tensor6<double>::copyValueFrom(Tensor* src);
474
475template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700476int Tensor0<int32_t>::copyValueFrom(Tensor* src);
477template <>
478int Tensor1<int32_t>::copyValueFrom(Tensor* src);
479template <>
480int Tensor2<int32_t>::copyValueFrom(Tensor* src);
481template <>
482int Tensor3<int32_t>::copyValueFrom(Tensor* src);
483template <>
484int Tensor4<int32_t>::copyValueFrom(Tensor* src);
485template <>
486int Tensor5<int32_t>::copyValueFrom(Tensor* src);
487template <>
488int Tensor6<int32_t>::copyValueFrom(Tensor* src);
489
490template <>
491int Tensor0<int64_t>::copyValueFrom(Tensor* src);
492template <>
493int Tensor1<int64_t>::copyValueFrom(Tensor* src);
494template <>
495int Tensor2<int64_t>::copyValueFrom(Tensor* src);
496template <>
497int Tensor3<int64_t>::copyValueFrom(Tensor* src);
498template <>
499int Tensor4<int64_t>::copyValueFrom(Tensor* src);
500template <>
501int Tensor5<int64_t>::copyValueFrom(Tensor* src);
502template <>
503int Tensor6<int64_t>::copyValueFrom(Tensor* src);
504
505template <>
506int Tensor0<bool>::copyValueFrom(Tensor* src);
507template <>
508int Tensor1<bool>::copyValueFrom(Tensor* src);
509template <>
510int Tensor2<bool>::copyValueFrom(Tensor* src);
511template <>
512int Tensor3<bool>::copyValueFrom(Tensor* src);
513template <>
514int Tensor4<bool>::copyValueFrom(Tensor* src);
515template <>
516int Tensor5<bool>::copyValueFrom(Tensor* src);
517template <>
518int Tensor6<bool>::copyValueFrom(Tensor* src);
519
520template <>
521int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
522template <>
523int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
524template <>
525int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
526template <>
527int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
528template <>
529int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
530template <>
531int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
532template <>
533int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
534
535template <>
536int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
537template <>
538int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
539template <>
540int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
541template <>
542int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
543template <>
544int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
545template <>
546int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
547template <>
548int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
549
550template <>
551int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
552template <>
553int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
554template <>
555int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
556template <>
557int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
558template <>
559int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
560template <>
561int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
562template <>
563int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
564
565template <>
566int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
567template <>
568int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
569template <>
570int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
571template <>
572int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
573template <>
574int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
575template <>
576int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
577template <>
578int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
579
580template <>
581int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
582template <>
583int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
584template <>
585int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
586template <>
587int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
588template <>
589int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
590template <>
591int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
592template <>
593int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
594
595template <>
596int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
597template <>
598int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
599template <>
600int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
601template <>
602int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
603template <>
604int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
605template <>
606int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
607template <>
608int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
609
610template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000611int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
612template <>
613int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
614template <>
615int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
616template <>
617int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
618template <>
619int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
620template <>
621int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
622template <>
623int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
624
625template <>
626int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
627template <>
628int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
629template <>
630int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
631template <>
632int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
633template <>
634int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
635template <>
636int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
637template <>
638int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
639
640template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700641int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
642template <>
643int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
644template <>
645int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
646template <>
647int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
648template <>
649int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
650template <>
651int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
652template <>
653int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
654
655template <>
656int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
657template <>
658int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
659template <>
660int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
661template <>
662int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
663template <>
664int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
665template <>
666int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
667template <>
668int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
669
Eric Kunzee5e26762020-10-13 16:11:07 -0700670template <>
671int Tensor0<float>::dumpTensor(FILE* out) const;
672template <>
673int Tensor1<float>::dumpTensor(FILE* out) const;
674template <>
675int Tensor2<float>::dumpTensor(FILE* out) const;
676template <>
677int Tensor3<float>::dumpTensor(FILE* out) const;
678template <>
679int Tensor4<float>::dumpTensor(FILE* out) const;
680template <>
681int Tensor5<float>::dumpTensor(FILE* out) const;
682template <>
683int Tensor6<float>::dumpTensor(FILE* out) const;
684template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000685int Tensor0<double>::dumpTensor(FILE* out) const;
686template <>
687int Tensor1<double>::dumpTensor(FILE* out) const;
688template <>
689int Tensor2<double>::dumpTensor(FILE* out) const;
690template <>
691int Tensor3<double>::dumpTensor(FILE* out) const;
692template <>
693int Tensor4<double>::dumpTensor(FILE* out) const;
694template <>
695int Tensor5<float>::dumpTensor(FILE* out) const;
696template <>
697int Tensor6<double>::dumpTensor(FILE* out) const;
698template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700699int Tensor0<int32_t>::dumpTensor(FILE* out) const;
700template <>
701int Tensor1<int32_t>::dumpTensor(FILE* out) const;
702template <>
703int Tensor2<int32_t>::dumpTensor(FILE* out) const;
704template <>
705int Tensor3<int32_t>::dumpTensor(FILE* out) const;
706template <>
707int Tensor4<int32_t>::dumpTensor(FILE* out) const;
708template <>
709int Tensor5<int32_t>::dumpTensor(FILE* out) const;
710template <>
711int Tensor6<int32_t>::dumpTensor(FILE* out) const;
712template <>
713int Tensor0<int64_t>::dumpTensor(FILE* out) const;
714template <>
715int Tensor1<int64_t>::dumpTensor(FILE* out) const;
716template <>
717int Tensor2<int64_t>::dumpTensor(FILE* out) const;
718template <>
719int Tensor3<int64_t>::dumpTensor(FILE* out) const;
720template <>
721int Tensor4<int64_t>::dumpTensor(FILE* out) const;
722template <>
723int Tensor5<int64_t>::dumpTensor(FILE* out) const;
724template <>
725int Tensor6<int64_t>::dumpTensor(FILE* out) const;
726template <>
727int Tensor0<bool>::dumpTensor(FILE* out) const;
728template <>
729int Tensor1<bool>::dumpTensor(FILE* out) const;
730template <>
731int Tensor2<bool>::dumpTensor(FILE* out) const;
732template <>
733int Tensor3<bool>::dumpTensor(FILE* out) const;
734template <>
735int Tensor4<bool>::dumpTensor(FILE* out) const;
736template <>
737int Tensor5<bool>::dumpTensor(FILE* out) const;
738template <>
739int Tensor6<bool>::dumpTensor(FILE* out) const;
740
741class TensorFactory
742{
743public:
Tai Lya4d748b2023-03-28 22:06:56 +0000744 static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700745 {
Tai Lya4d748b2023-03-28 22:06:56 +0000746 TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700747 switch (tensorDtype_)
748 {
Tai Lya4d748b2023-03-28 22:06:56 +0000749 case TOSA_REF_TYPE_FP32:
750 case TOSA_REF_TYPE_FP16:
751 case TOSA_REF_TYPE_BF16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700752 switch (rank)
753 {
754 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000755 return new Tensor0<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700756 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000757 return new Tensor1<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700758 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000759 return new Tensor2<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700760 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000761 return new Tensor3<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700762 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000763 return new Tensor4<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700764 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000765 return new Tensor5<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700766 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000767 return new Tensor6<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700768 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700769 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000770 case TOSA_REF_TYPE_INT32:
771 case TOSA_REF_TYPE_UINT8:
772 case TOSA_REF_TYPE_INT4:
773 case TOSA_REF_TYPE_INT8:
774 case TOSA_REF_TYPE_INT16:
775 case TOSA_REF_TYPE_UINT16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700776 switch (rank)
777 {
778 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000779 return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700780 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000781 return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700782 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000783 return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700784 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000785 return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700786 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000787 return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700788 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000789 return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700790 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000791 return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700792 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700793 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000794 case TOSA_REF_TYPE_INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700795 switch (rank)
796 {
797 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000798 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700799 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000800 return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700801 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000802 return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700803 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000804 return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700805 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000806 return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700807 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000808 return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700809 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000810 return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700811 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700812 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000813 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700814 switch (rank)
815 {
816 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000817 return new Tensor0<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700818 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000819 return new Tensor1<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700820 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000821 return new Tensor2<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700822 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000823 return new Tensor3<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700824 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000825 return new Tensor4<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700826 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000827 return new Tensor5<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700828 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000829 return new Tensor6<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700830 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700831 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000832 case TOSA_REF_TYPE_FP64:
833 switch (rank)
834 {
835 case 0:
836 return new Tensor0<double>(tensorName_, dtype_, shape_);
837 case 1:
838 return new Tensor1<double>(tensorName_, dtype_, shape_);
839 case 2:
840 return new Tensor2<double>(tensorName_, dtype_, shape_);
841 case 3:
842 return new Tensor3<double>(tensorName_, dtype_, shape_);
843 case 4:
844 return new Tensor4<double>(tensorName_, dtype_, shape_);
845 case 5:
846 return new Tensor5<double>(tensorName_, dtype_, shape_);
847 case 6:
848 return new Tensor6<double>(tensorName_, dtype_, shape_);
849 }
850 break;
851 case TOSA_REF_TYPE_UNKNOWN:
852 assert(0); // tensorDtype_ is uninitialized
Kevin Cheng989cb052021-04-28 16:29:44 -0700853 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700854 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700855 return nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -0700856 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700857};
858}; // namespace TosaReference
859
860#endif