blob: f13de0e870c4bc106e939db23eac5e5b8a2ec1ad [file] [log] [blame]
// Copyright (c) 2020-2024, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
Grant Watson64285a12022-11-16 15:32:39 +000019#include "array_proxy.h"
Tai Lya4d748b2023-03-28 22:06:56 +000020#include "dtype.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "model_common.h"
22#include "ops/template_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070023#include "tosa_serialization_handler.h"
24#include <Eigen/CXX11/Tensor>
25#include <list>
26#include <vector>
27
28using namespace tosa;
29
30namespace TosaReference
31{
32class GraphNode;
33
34class Tensor
35{
36public:
Tai Lya4d748b2023-03-28 22:06:56 +000037 Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070038
39 virtual ~Tensor();
40
41 int setIsSubgraphInput();
42 int setIsSubgraphOutput();
Jerry Ge9e94af82022-10-27 09:57:00 -070043 int setIsParentGraphOutput();
44
Tai Lycf84bc92023-09-07 20:49:09 +000045 bool getIsParentGraphOutput() const
Jerry Ge9c9c8da2023-07-19 23:08:16 +000046 {
Jerry Ge9e94af82022-10-27 09:57:00 -070047 return isParentGraphOutput;
48 }
Tai Lycf84bc92023-09-07 20:49:09 +000049 int setIsVariable();
Eric Kunzee5e26762020-10-13 16:11:07 -070050
Tai Lycf84bc92023-09-07 20:49:09 +000051 bool getIsSubgraphInput() const
Eric Kunzee5e26762020-10-13 16:11:07 -070052 {
53 return isSubgraphInput;
54 }
55
Tai Lycf84bc92023-09-07 20:49:09 +000056 bool getIsSubgraphOutput() const
Eric Kunzee5e26762020-10-13 16:11:07 -070057 {
58 return isSubgraphOutput;
59 }
60
Tai Lycf84bc92023-09-07 20:49:09 +000061 bool getIsVariable() const
62 {
63 return isVariable;
64 }
65
Eric Kunzee5e26762020-10-13 16:11:07 -070066 int setProducer(GraphNode* node);
67 int addConsumer(GraphNode* node);
68
69 int setIsValid()
70 {
71 isValid = 1;
72 return 0;
73 }
74
75 int clearIsValid()
76 {
77 isValid = 0;
78 return 0;
79 }
80
81 int getIsValid() const
82 {
83 return isValid;
84 }
85
Eric Kunzee5e26762020-10-13 16:11:07 -070086 GraphNode* getProducer()
87 {
88 return producer;
89 }
90
91 std::vector<GraphNode*>& getConsumers()
92 {
93 return consumers;
94 }
95
96 const std::string& getName() const
97 {
98 return tensorName;
99 }
100
101 const std::vector<int>& getShape() const
102 {
103 return shape;
104 }
105
Jerry Ge264f7fa2023-04-21 22:49:57 +0000106 void setDimSize(size_t dim, uint32_t new_size)
107 {
108 this->shape[dim] = new_size;
109 return;
110 }
111
Jerry Ge12159fc2024-04-01 17:05:10 +0000112 void setShapeValue(std::vector<int>& shapeValue)
113 {
114 for (auto dim : shapeValue)
115 {
116 this->shapeValue.push_back(dim);
117 }
118 return;
119 }
120
121 int getShapeValueSize() const
122 {
123 return this->shapeValue.size();
124 }
125
126 std::string getShapeValueAsString() const
127 {
128 std::string shape_str("[");
129 for (auto& dim : shapeValue)
130 {
131 shape_str += (std::to_string(dim) + ", ");
132 }
133 shape_str.append("]");
134 return shape_str;
135 }
136
Eric Kunzee5e26762020-10-13 16:11:07 -0700137 std::string getShapeAsString() const
138 {
139 std::string shape_str("[");
140 for (auto& dim : shape)
141 {
142 shape_str += (std::to_string(dim) + ", ");
143 }
144 shape_str.append("]");
145 return shape_str;
146 }
147
Eric Kunzee5e26762020-10-13 16:11:07 -0700148 const uint32_t getElementCount() const
149 {
150 uint32_t elements = 1;
151 for (size_t i = 0; i < shape.size(); i++)
152 elements *= shape[i];
153
154 return elements;
155 }
156
157 // Comparison of rank and type with other tensors
158 const int matchRank(const Tensor& ref) const
159 {
160 return (ref.shape.size() == shape.size()) ? 0 : 1;
161 }
162
163 const int matchType(const Tensor& ref) const
164 {
165 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
166 }
167
168 const int matchRankType(const Tensor& ref) const
169 {
170 return (matchType(ref) || matchRank(ref));
171 }
172
173 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
174 {
175 if (matchRankType(ref))
176 return 1;
177
178 for (size_t i = 0; i < shape.size(); i++)
179 {
180 if (shape[i] != ref.shape[i])
181 {
182 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000183 // For broadcasts, the order of *this and ref matters.
184 // *this should be the source tensor.
185 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
186 // this->shape must have size 1 if they don't match
187 (broadcastOk && (shape[i] != 1)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 {
189 return 1;
190 }
191 }
192 }
193
194 return 0;
195 }
196
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800197 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
198 {
199 if (matchRank(ref))
200 return 1;
201
202 for (size_t i = 0; i < shape.size(); i++)
203 {
204 if (shape[i] != ref.shape[i])
205 {
206 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000207 // For broadcasts, the order of *this and ref matters.
208 // *this should be the source tensor.
209 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
210 // this->shape must have size 1 if they don't match
211 (broadcastOk && (shape[i] != 1)))
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800212 {
213 return 1;
214 }
215 }
216 }
217
218 return 0;
219 }
220
Eric Kunzee5e26762020-10-13 16:11:07 -0700221 // Sometimes we might want to match several semi-compatible types,
222 // so just check rank and size here
223 const int matchRankSize(const Tensor& ref) const
224 {
225 if (matchRank(ref))
226 return 1;
227
228 for (size_t i = 0; i < shape.size(); i++)
229 {
230 if (shape[i] != ref.shape[i])
231 return 1;
232 }
233
234 return 0;
235 }
236
237 // Unary check to make sure rank matches
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000238 const int checkRequiredRank(const int minRank) const
Eric Kunzee5e26762020-10-13 16:11:07 -0700239 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000240 return (shape.size() >= (size_t)minRank) ? 0 : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 }
242
243 const int checkRequiredRank(const int minRank, const int maxRank) const
244 {
245 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
246 }
247
248 const int getRank() const
249 {
250 return shape.size();
251 }
252
Tai Lya4d748b2023-03-28 22:06:56 +0000253 const TOSA_REF_TYPE getDtype() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700254 {
255 return tensorDtype;
256 }
257
Tai Lya4d748b2023-03-28 22:06:56 +0000258 const DType getSerializationDtype() const
259 {
260 return serializationDtype;
261 }
262
Eric Kunzee5e26762020-10-13 16:11:07 -0700263 virtual int dumpTensor(FILE* out) const = 0;
264 virtual int dumpTensorParams(FILE* out) const;
265 virtual int dumpTensorParams(std::ostream& out) const;
266
Jerry Ge20ab3df2024-01-26 16:56:55 +0000267 virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
268 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
269 virtual int setTensorValueUInt8(const size_t bufLen, const uint8_t* vals) = 0;
270 virtual int setTensorValueInt8(const size_t bufLen, const int8_t* vals) = 0;
271 virtual int setTensorValueUInt16(const size_t bufLen, const uint16_t* vals) = 0;
272 virtual int setTensorValueInt16(const size_t bufLen, const int16_t* vals) = 0;
273 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
274 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
275 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
276 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
277 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
278 virtual int getTensorValueUInt8(const size_t bufLen, uint8_t* ibuf) const = 0;
279 virtual int getTensorValueInt8(const size_t bufLen, int8_t* ibuf) const = 0;
280 virtual int getTensorValueUInt16(const size_t bufLen, uint16_t* ibuf) const = 0;
281 virtual int getTensorValueInt16(const size_t bufLen, int16_t* ibuf) const = 0;
282 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
283 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
284 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
286 virtual int readFromNpyFile(const char* filename);
287 virtual int writeToNpyFile(const char* filename) const;
288 virtual int copyValueFrom(Tensor* tensor) = 0;
289
Tai Lya4d748b2023-03-28 22:06:56 +0000290 virtual int readfromVector(const ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000291 virtual int readfromVector(const ArrayProxy<float> vals);
292 virtual int readfromVector(const ArrayProxy<half_float::half> vals);
Jerry Gec5291692024-01-02 22:29:08 +0000293 virtual int readfromVector(const ArrayProxy<int8_t> vals);
Jerry Ge20ab3df2024-01-26 16:56:55 +0000294 virtual int readfromVector(const ArrayProxy<uint16_t> vals);
Georgios Pinitase9059772023-12-06 18:52:30 +0000295 virtual int readfromVector(const ArrayProxy<int16_t> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000296 virtual int readfromVector(const ArrayProxy<int32_t> vals);
297 virtual int readfromVector(const ArrayProxy<int64_t> vals);
298 virtual int readfromVector(const ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100299
Tai Lya4d748b2023-03-28 22:06:56 +0000300 virtual int writeToVector(ArrayProxy<double> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000301 virtual int writeToVector(ArrayProxy<float> vals);
302 virtual int writeToVector(ArrayProxy<half_float::half> vals);
Jerry Gec5291692024-01-02 22:29:08 +0000303 virtual int writeToVector(ArrayProxy<int8_t> vals);
Jerry Ge20ab3df2024-01-26 16:56:55 +0000304 virtual int writeToVector(ArrayProxy<uint16_t> vals);
Georgios Pinitase9059772023-12-06 18:52:30 +0000305 virtual int writeToVector(ArrayProxy<int16_t> vals);
Grant Watson64285a12022-11-16 15:32:39 +0000306 virtual int writeToVector(ArrayProxy<int32_t> vals);
307 virtual int writeToVector(ArrayProxy<int64_t> vals);
308 virtual int writeToVector(ArrayProxy<unsigned char> vals);
Matthew Sloyanba5fad32022-09-26 13:31:43 +0100309
Eric Kunzee5e26762020-10-13 16:11:07 -0700310 const char* bool_to_str(bool in) const
311 {
312 static const char* true_str = "true";
313 static const char* false_str = "false";
314 return in ? true_str : false_str;
315 }
316
Tai Lycf84bc92023-09-07 20:49:09 +0000317 virtual int allocate() = 0;
318 virtual int deallocate() = 0;
319 virtual bool is_allocated() const = 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700320
321protected:
Tai Lya4d748b2023-03-28 22:06:56 +0000322 const std::string tensorName;
323 const DType serializationDtype;
Jerry Ge264f7fa2023-04-21 22:49:57 +0000324 std::vector<int> shape;
Jerry Ge12159fc2024-04-01 17:05:10 +0000325 std::vector<int> shapeValue;
Tai Lya4d748b2023-03-28 22:06:56 +0000326 const TOSA_REF_TYPE tensorDtype;
Tai Lycf84bc92023-09-07 20:49:09 +0000327 bool isValid;
328 bool isSubgraphInput;
329 bool isSubgraphOutput;
330 bool isVariable;
Eric Kunzee5e26762020-10-13 16:11:07 -0700331 bool isAllocated;
332
Jerry Ge9e94af82022-10-27 09:57:00 -0700333 bool isParentGraphOutput;
334
Eric Kunzee5e26762020-10-13 16:11:07 -0700335 GraphNode* producer;
336 std::vector<GraphNode*> consumers;
337
338 // Note: the Eigen::Tensor is not declared in Tensor
339 // Instead, the TensorTemplate class keeps the templated tensor
340 // declaration so that the graph manipulation tools are isolated
341 // from the templated tensor type.
342 //
343 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
344 // so that they can operate on the right types.
345};
346
347template <class T>
348class TensorTemplate : public Tensor
349{
350public:
Tai Lya4d748b2023-03-28 22:06:56 +0000351 TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
352 : Tensor(tensorName_, dtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700353 {
354 tensor = nullptr;
355 }
356
357 virtual ~TensorTemplate()
358 {
359 deallocate();
360 }
361
362 virtual int allocate()
363 {
364 tensor = new T();
365 if (tensor)
366 return 0;
367 else
368 return 1;
369 }
370
371 virtual int deallocate()
372 {
373 if (tensor)
374 {
Eric Kunze9a367552023-07-11 13:27:36 -0700375 DEBUG_INFO(GT, "Deallocating tensor %s", tensorName.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700376 delete tensor;
377 }
378 tensor = nullptr;
379 return 0;
380 }
381
Tai Lycf84bc92023-09-07 20:49:09 +0000382 virtual bool is_allocated() const
Eric Kunzee5e26762020-10-13 16:11:07 -0700383 {
384 if (tensor)
385 {
386 return true;
387 }
388 return false;
389 }
390
391 T& getTensor()
392 {
393 return *tensor;
394 }
395
396 virtual int dumpTensor(FILE* out) const;
397
Tai Lya4d748b2023-03-28 22:06:56 +0000398 virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700399 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
Jerry Gec5291692024-01-02 22:29:08 +0000400 virtual int setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
401 virtual int setTensorValueInt8(const size_t bufLen, const int8_t* vals);
Jerry Ge20ab3df2024-01-26 16:56:55 +0000402 virtual int setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
Georgios Pinitase9059772023-12-06 18:52:30 +0000403 virtual int setTensorValueInt16(const size_t bufLen, const int16_t* vals);
Eric Kunzee5e26762020-10-13 16:11:07 -0700404 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
405 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
406 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
Tai Lya4d748b2023-03-28 22:06:56 +0000407
408 virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700409 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
Jerry Gec5291692024-01-02 22:29:08 +0000410 virtual int getTensorValueUInt8(const size_t bufLen, uint8_t* ibuf) const;
411 virtual int getTensorValueInt8(const size_t bufLen, int8_t* ibuf) const;
Jerry Ge20ab3df2024-01-26 16:56:55 +0000412 virtual int getTensorValueUInt16(const size_t bufLen, uint16_t* ibuf) const;
Georgios Pinitase9059772023-12-06 18:52:30 +0000413 virtual int getTensorValueInt16(const size_t bufLen, int16_t* ibuf) const;
Eric Kunzee5e26762020-10-13 16:11:07 -0700414 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
415 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
416 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
417
418 virtual int copyValueFrom(Tensor* tensor);
419
420protected:
421 T* tensor;
422};
423
424// allocate() template specializations to allocate the different tensor sizes
425// Let the compiler know here before the factory uses them, but define them in the .cc file.
426template <>
427int Tensor0<float>::allocate();
428template <>
429int Tensor1<float>::allocate();
430template <>
431int Tensor2<float>::allocate();
432template <>
433int Tensor3<float>::allocate();
434template <>
435int Tensor4<float>::allocate();
436template <>
437int Tensor5<float>::allocate();
438template <>
439int Tensor6<float>::allocate();
440
441template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000442int Tensor0<double>::allocate();
443template <>
444int Tensor1<double>::allocate();
445template <>
446int Tensor2<double>::allocate();
447template <>
448int Tensor3<double>::allocate();
449template <>
450int Tensor4<double>::allocate();
451template <>
452int Tensor5<double>::allocate();
453template <>
454int Tensor6<double>::allocate();
455
456template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700457int Tensor0<int32_t>::allocate();
458template <>
459int Tensor1<int32_t>::allocate();
460template <>
461int Tensor2<int32_t>::allocate();
462template <>
463int Tensor3<int32_t>::allocate();
464template <>
465int Tensor4<int32_t>::allocate();
466template <>
467int Tensor5<int32_t>::allocate();
468template <>
469int Tensor6<int32_t>::allocate();
470
471template <>
472int Tensor0<int64_t>::allocate();
473template <>
474int Tensor1<int64_t>::allocate();
475template <>
476int Tensor2<int64_t>::allocate();
477template <>
478int Tensor3<int64_t>::allocate();
479template <>
480int Tensor4<int64_t>::allocate();
481template <>
482int Tensor5<int64_t>::allocate();
483template <>
484int Tensor6<int64_t>::allocate();
485
486template <>
487int Tensor0<bool>::allocate();
488template <>
489int Tensor1<bool>::allocate();
490template <>
491int Tensor2<bool>::allocate();
492template <>
493int Tensor3<bool>::allocate();
494template <>
495int Tensor4<bool>::allocate();
496template <>
497int Tensor5<bool>::allocate();
498template <>
499int Tensor6<bool>::allocate();
500
501template <>
502int Tensor0<float>::copyValueFrom(Tensor* src);
503template <>
504int Tensor1<float>::copyValueFrom(Tensor* src);
505template <>
506int Tensor2<float>::copyValueFrom(Tensor* src);
507template <>
508int Tensor3<float>::copyValueFrom(Tensor* src);
509template <>
510int Tensor4<float>::copyValueFrom(Tensor* src);
511template <>
512int Tensor5<float>::copyValueFrom(Tensor* src);
513template <>
514int Tensor6<float>::copyValueFrom(Tensor* src);
515
516template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000517int Tensor0<double>::copyValueFrom(Tensor* src);
518template <>
519int Tensor1<double>::copyValueFrom(Tensor* src);
520template <>
521int Tensor2<double>::copyValueFrom(Tensor* src);
522template <>
523int Tensor3<double>::copyValueFrom(Tensor* src);
524template <>
525int Tensor4<double>::copyValueFrom(Tensor* src);
526template <>
527int Tensor5<double>::copyValueFrom(Tensor* src);
528template <>
529int Tensor6<double>::copyValueFrom(Tensor* src);
530
531template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700532int Tensor0<int32_t>::copyValueFrom(Tensor* src);
533template <>
534int Tensor1<int32_t>::copyValueFrom(Tensor* src);
535template <>
536int Tensor2<int32_t>::copyValueFrom(Tensor* src);
537template <>
538int Tensor3<int32_t>::copyValueFrom(Tensor* src);
539template <>
540int Tensor4<int32_t>::copyValueFrom(Tensor* src);
541template <>
542int Tensor5<int32_t>::copyValueFrom(Tensor* src);
543template <>
544int Tensor6<int32_t>::copyValueFrom(Tensor* src);
545
546template <>
547int Tensor0<int64_t>::copyValueFrom(Tensor* src);
548template <>
549int Tensor1<int64_t>::copyValueFrom(Tensor* src);
550template <>
551int Tensor2<int64_t>::copyValueFrom(Tensor* src);
552template <>
553int Tensor3<int64_t>::copyValueFrom(Tensor* src);
554template <>
555int Tensor4<int64_t>::copyValueFrom(Tensor* src);
556template <>
557int Tensor5<int64_t>::copyValueFrom(Tensor* src);
558template <>
559int Tensor6<int64_t>::copyValueFrom(Tensor* src);
560
561template <>
562int Tensor0<bool>::copyValueFrom(Tensor* src);
563template <>
564int Tensor1<bool>::copyValueFrom(Tensor* src);
565template <>
566int Tensor2<bool>::copyValueFrom(Tensor* src);
567template <>
568int Tensor3<bool>::copyValueFrom(Tensor* src);
569template <>
570int Tensor4<bool>::copyValueFrom(Tensor* src);
571template <>
572int Tensor5<bool>::copyValueFrom(Tensor* src);
573template <>
574int Tensor6<bool>::copyValueFrom(Tensor* src);
575
576template <>
Jerry Gec5291692024-01-02 22:29:08 +0000577int Tensor0<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
578template <>
579int Tensor1<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
580template <>
581int Tensor2<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
582template <>
583int Tensor3<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
584template <>
585int Tensor4<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
586template <>
587int Tensor5<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
588template <>
589int Tensor6<int32_t>::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals);
590
591template <>
592int Tensor0<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
593template <>
594int Tensor1<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
595template <>
596int Tensor2<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
597template <>
598int Tensor3<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
599template <>
600int Tensor4<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
601template <>
602int Tensor5<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
603template <>
604int Tensor6<int32_t>::setTensorValueInt8(const size_t bufLen, const int8_t* vals);
605
606template <>
Jerry Ge20ab3df2024-01-26 16:56:55 +0000607int Tensor0<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
608template <>
609int Tensor1<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
610template <>
611int Tensor2<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
612template <>
613int Tensor3<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
614template <>
615int Tensor4<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
616template <>
617int Tensor5<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
618template <>
619int Tensor6<int32_t>::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals);
620
621template <>
Georgios Pinitase9059772023-12-06 18:52:30 +0000622int Tensor0<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
623template <>
624int Tensor1<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
625template <>
626int Tensor2<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
627template <>
628int Tensor3<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
629template <>
630int Tensor4<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
631template <>
632int Tensor5<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
633template <>
634int Tensor6<int32_t>::setTensorValueInt16(const size_t bufLen, const int16_t* vals);
635
636template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700637int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
638template <>
639int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
640template <>
641int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
642template <>
643int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
644template <>
645int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
646template <>
647int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
648template <>
649int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
650
651template <>
Jerry Gec5291692024-01-02 22:29:08 +0000652int Tensor0<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
653template <>
654int Tensor1<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
655template <>
656int Tensor2<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
657template <>
658int Tensor3<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
659template <>
660int Tensor4<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
661template <>
662int Tensor5<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
663template <>
664int Tensor6<int32_t>::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const;
665
666template <>
667int Tensor0<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
668template <>
669int Tensor1<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
670template <>
671int Tensor2<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
672template <>
673int Tensor3<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
674template <>
675int Tensor4<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
676template <>
677int Tensor5<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
678template <>
679int Tensor6<int32_t>::getTensorValueInt8(const size_t bufLen, int8_t* vals) const;
680
681template <>
Jerry Ge20ab3df2024-01-26 16:56:55 +0000682int Tensor0<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
683template <>
684int Tensor1<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
685template <>
686int Tensor2<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
687template <>
688int Tensor3<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
689template <>
690int Tensor4<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
691template <>
692int Tensor5<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
693template <>
694int Tensor6<int32_t>::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const;
695
696template <>
Georgios Pinitase9059772023-12-06 18:52:30 +0000697int Tensor0<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
698template <>
699int Tensor1<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
700template <>
701int Tensor2<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
702template <>
703int Tensor3<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
704template <>
705int Tensor4<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
706template <>
707int Tensor5<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
708template <>
709int Tensor6<int32_t>::getTensorValueInt16(const size_t bufLen, int16_t* vals) const;
710
711template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700712int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
713template <>
714int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
715template <>
716int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
717template <>
718int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
719template <>
720int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
721template <>
722int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
723template <>
724int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
725
726template <>
727int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
728template <>
729int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
730template <>
731int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
732template <>
733int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
734template <>
735int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
736template <>
737int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
738template <>
739int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
740
741template <>
742int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
743template <>
744int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
745template <>
746int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
747template <>
748int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
749template <>
750int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
751template <>
752int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
753template <>
754int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
755
756template <>
757int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
758template <>
759int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
760template <>
761int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
762template <>
763int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
764template <>
765int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
766template <>
767int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
768template <>
769int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
770
771template <>
772int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
773template <>
774int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
775template <>
776int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
777template <>
778int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
779template <>
780int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
781template <>
782int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
783template <>
784int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
785
786template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000787int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
788template <>
789int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
790template <>
791int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
792template <>
793int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
794template <>
795int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
796template <>
797int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
798template <>
799int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
800
801template <>
802int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
803template <>
804int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
805template <>
806int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
807template <>
808int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
809template <>
810int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
811template <>
812int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
813template <>
814int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
815
816template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700817int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
818template <>
819int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
820template <>
821int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
822template <>
823int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
824template <>
825int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
826template <>
827int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
828template <>
829int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
830
831template <>
832int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
833template <>
834int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
835template <>
836int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
837template <>
838int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
839template <>
840int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
841template <>
842int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
843template <>
844int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
845
Eric Kunzee5e26762020-10-13 16:11:07 -0700846template <>
847int Tensor0<float>::dumpTensor(FILE* out) const;
848template <>
849int Tensor1<float>::dumpTensor(FILE* out) const;
850template <>
851int Tensor2<float>::dumpTensor(FILE* out) const;
852template <>
853int Tensor3<float>::dumpTensor(FILE* out) const;
854template <>
855int Tensor4<float>::dumpTensor(FILE* out) const;
856template <>
857int Tensor5<float>::dumpTensor(FILE* out) const;
858template <>
859int Tensor6<float>::dumpTensor(FILE* out) const;
860template <>
Tai Lya4d748b2023-03-28 22:06:56 +0000861int Tensor0<double>::dumpTensor(FILE* out) const;
862template <>
863int Tensor1<double>::dumpTensor(FILE* out) const;
864template <>
865int Tensor2<double>::dumpTensor(FILE* out) const;
866template <>
867int Tensor3<double>::dumpTensor(FILE* out) const;
868template <>
869int Tensor4<double>::dumpTensor(FILE* out) const;
870template <>
871int Tensor5<float>::dumpTensor(FILE* out) const;
872template <>
873int Tensor6<double>::dumpTensor(FILE* out) const;
874template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700875int Tensor0<int32_t>::dumpTensor(FILE* out) const;
876template <>
877int Tensor1<int32_t>::dumpTensor(FILE* out) const;
878template <>
879int Tensor2<int32_t>::dumpTensor(FILE* out) const;
880template <>
881int Tensor3<int32_t>::dumpTensor(FILE* out) const;
882template <>
883int Tensor4<int32_t>::dumpTensor(FILE* out) const;
884template <>
885int Tensor5<int32_t>::dumpTensor(FILE* out) const;
886template <>
887int Tensor6<int32_t>::dumpTensor(FILE* out) const;
888template <>
889int Tensor0<int64_t>::dumpTensor(FILE* out) const;
890template <>
891int Tensor1<int64_t>::dumpTensor(FILE* out) const;
892template <>
893int Tensor2<int64_t>::dumpTensor(FILE* out) const;
894template <>
895int Tensor3<int64_t>::dumpTensor(FILE* out) const;
896template <>
897int Tensor4<int64_t>::dumpTensor(FILE* out) const;
898template <>
899int Tensor5<int64_t>::dumpTensor(FILE* out) const;
900template <>
901int Tensor6<int64_t>::dumpTensor(FILE* out) const;
902template <>
903int Tensor0<bool>::dumpTensor(FILE* out) const;
904template <>
905int Tensor1<bool>::dumpTensor(FILE* out) const;
906template <>
907int Tensor2<bool>::dumpTensor(FILE* out) const;
908template <>
909int Tensor3<bool>::dumpTensor(FILE* out) const;
910template <>
911int Tensor4<bool>::dumpTensor(FILE* out) const;
912template <>
913int Tensor5<bool>::dumpTensor(FILE* out) const;
914template <>
915int Tensor6<bool>::dumpTensor(FILE* out) const;
916
917class TensorFactory
918{
919public:
Tai Lya4d748b2023-03-28 22:06:56 +0000920 static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700921 {
Tai Lya4d748b2023-03-28 22:06:56 +0000922 TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700923 switch (tensorDtype_)
924 {
Tai Lya4d748b2023-03-28 22:06:56 +0000925 case TOSA_REF_TYPE_FP32:
926 case TOSA_REF_TYPE_FP16:
927 case TOSA_REF_TYPE_BF16:
Won Jeon2c34b462024-02-06 18:37:00 +0000928 case TOSA_REF_TYPE_FP8E4M3:
929 case TOSA_REF_TYPE_FP8E5M2:
Eric Kunzee5e26762020-10-13 16:11:07 -0700930 switch (rank)
931 {
932 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000933 return new Tensor0<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700934 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000935 return new Tensor1<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700936 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000937 return new Tensor2<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700938 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000939 return new Tensor3<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700940 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000941 return new Tensor4<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700942 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000943 return new Tensor5<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700944 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000945 return new Tensor6<float>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700946 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700947 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000948 case TOSA_REF_TYPE_INT32:
949 case TOSA_REF_TYPE_UINT8:
950 case TOSA_REF_TYPE_INT4:
951 case TOSA_REF_TYPE_INT8:
952 case TOSA_REF_TYPE_INT16:
953 case TOSA_REF_TYPE_UINT16:
Eric Kunzee5e26762020-10-13 16:11:07 -0700954 switch (rank)
955 {
956 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000957 return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700958 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000959 return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700960 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000961 return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700962 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000963 return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700964 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000965 return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700966 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000967 return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700968 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000969 return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700970 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700971 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000972 case TOSA_REF_TYPE_INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700973 switch (rank)
974 {
975 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +0000976 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700977 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +0000978 return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700979 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +0000980 return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700981 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +0000982 return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700983 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +0000984 return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700985 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +0000986 return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700987 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +0000988 return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700989 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700990 break;
Won Jeona21b2e82023-08-10 10:33:01 +0000991 case TOSA_REF_TYPE_SHAPE:
Tai Ly8690a082023-12-18 20:40:24 +0000992 switch (rank)
993 {
994 case 0:
995 return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
996 case 1:
997 return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
998 default:
999 assert(0); // shape tensors must have rank of 0 or 1
1000 }
1001 break;
Tai Lya4d748b2023-03-28 22:06:56 +00001002 case TOSA_REF_TYPE_BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -07001003 switch (rank)
1004 {
1005 case 0:
Tai Lya4d748b2023-03-28 22:06:56 +00001006 return new Tensor0<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001007 case 1:
Tai Lya4d748b2023-03-28 22:06:56 +00001008 return new Tensor1<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001009 case 2:
Tai Lya4d748b2023-03-28 22:06:56 +00001010 return new Tensor2<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001011 case 3:
Tai Lya4d748b2023-03-28 22:06:56 +00001012 return new Tensor3<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001013 case 4:
Tai Lya4d748b2023-03-28 22:06:56 +00001014 return new Tensor4<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001015 case 5:
Tai Lya4d748b2023-03-28 22:06:56 +00001016 return new Tensor5<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001017 case 6:
Tai Lya4d748b2023-03-28 22:06:56 +00001018 return new Tensor6<bool>(tensorName_, dtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -07001019 }
Kevin Cheng989cb052021-04-28 16:29:44 -07001020 break;
Tai Lya4d748b2023-03-28 22:06:56 +00001021 case TOSA_REF_TYPE_FP64:
1022 switch (rank)
1023 {
1024 case 0:
1025 return new Tensor0<double>(tensorName_, dtype_, shape_);
1026 case 1:
1027 return new Tensor1<double>(tensorName_, dtype_, shape_);
1028 case 2:
1029 return new Tensor2<double>(tensorName_, dtype_, shape_);
1030 case 3:
1031 return new Tensor3<double>(tensorName_, dtype_, shape_);
1032 case 4:
1033 return new Tensor4<double>(tensorName_, dtype_, shape_);
1034 case 5:
1035 return new Tensor5<double>(tensorName_, dtype_, shape_);
1036 case 6:
1037 return new Tensor6<double>(tensorName_, dtype_, shape_);
1038 }
1039 break;
1040 case TOSA_REF_TYPE_UNKNOWN:
1041 assert(0); // tensorDtype_ is uninitialized
Kevin Cheng989cb052021-04-28 16:29:44 -07001042 break;
Eric Kunzee5e26762020-10-13 16:11:07 -07001043 }
Kevin Cheng903763c2021-09-28 16:14:52 -07001044 return nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -07001045 }
Eric Kunzee5e26762020-10-13 16:11:07 -07001046};
1047}; // namespace TosaReference
1048
1049#endif