// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TOSA_REFERENCE_TENSOR_H
#define TOSA_REFERENCE_TENSOR_H

#include "array_proxy.h"
#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_generated.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <list>
#include <vector>

using namespace tosa;

namespace TosaReference
{
class GraphNode;

class Tensor
{
public:
    Tensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_);

    virtual ~Tensor();

    int setIsSubgraphInput();
    int setIsSubgraphOutput();
    int setIsParentGraphOutput();

    int getIsParentGraphOutput() const
    {
        return isParentGraphOutput;
    }

    int getIsSubgraphInput() const
    {
        return isSubgraphInput;
    }

    int getIsSubgraphOutput() const
    {
        return isSubgraphOutput;
    }

    int setProducer(GraphNode* node);
    int addConsumer(GraphNode* node);

    int setIsValid()
    {
        isValid = 1;
        return 0;
    }

    int clearIsValid()
    {
        isValid = 0;
        return 0;
    }

    int getIsValid() const
    {
        return isValid;
    }

    GraphNode* getProducer()
    {
        return producer;
    }

    std::vector<GraphNode*>& getConsumers()
    {
        return consumers;
    }

    const std::string& getName() const
    {
        return tensorName;
    }

    const std::vector<int>& getShape() const
    {
        return shape;
    }

    std::string getShapeAsString() const
    {
        std::string shape_str("[");
        for (auto& dim : shape)
        {
            shape_str += (std::to_string(dim) + ", ");
        }
        shape_str.append("]");
        return shape_str;
    }
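
    // Example: getShapeAsString() on a tensor of shape { 2, 3 } returns "[2, 3, ]"
    // (each dimension, including the last, is followed by ", ").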

    const uint32_t getElementCount() const
    {
        uint32_t elements = 1;
        for (size_t i = 0; i < shape.size(); i++)
            elements *= shape[i];

        return elements;
    }
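
    // Example: a { 2, 3, 4 } tensor reports 24 elements; a rank-0 (scalar) tensor
    // reports 1, since the product over an empty shape is the initial value of 1.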

    // Comparison of rank and type with other tensors
    const int matchRank(const Tensor& ref) const
    {
        return (ref.shape.size() == shape.size()) ? 0 : 1;
    }

    const int matchType(const Tensor& ref) const
    {
        return (ref.tensorDtype == tensorDtype) ? 0 : 1;
    }

    const int matchRankType(const Tensor& ref) const
    {
        return (matchType(ref) || matchRank(ref));
    }

    const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRankType(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, the order of *this and ref matters:
                    // *this is the source tensor and ref is the target tensor
                    // (in most cases, ref is the output tensor).
                    // Where the sizes differ, this->shape[i] must be 1.
                    // (An illustrative example follows matchRankShape below.)
                    (broadcastOk && (shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }
158
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800159 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
160 {
161 if (matchRank(ref))
162 return 1;
163
164 for (size_t i = 0; i < shape.size(); i++)
165 {
166 if (shape[i] != ref.shape[i])
167 {
168 if (!broadcastOk ||
                    // For broadcasts, the order of *this and ref matters:
                    // *this is the source tensor and ref is the target tensor
                    // (in most cases, ref is the output tensor).
                    // Where the sizes differ, this->shape[i] must be 1.
                    // (See the illustrative example below.)
                    (broadcastOk && (shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }
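
    // Illustrative example (not part of the API): with broadcastOk == true, a source
    // tensor of shape { 1, 3 } matches a target ref of shape { 5, 3 }, because the
    // mismatching dimension has size 1 in *this. A source of shape { 2, 3 } against
    // the same ref returns 1 (mismatch). With broadcastOk == false, any differing
    // dimension size is reported as a mismatch.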

    // Sometimes we might want to match several semi-compatible types,
    // so just check rank and size here
    const int matchRankSize(const Tensor& ref) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
                return 1;
        }

        return 0;
    }

    // Unary check to make sure rank matches
    const int checkRequiredRank(const int exactRank) const
    {
        return (shape.size() == (size_t)exactRank) ? 0 : 1;
    }

    const int checkRequiredRank(const int minRank, const int maxRank) const
    {
        return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
    }

    const int getRank() const
    {
        return shape.size();
    }

    const DType getDtype() const
    {
        return tensorDtype;
    }

    virtual int dumpTensor(FILE* out) const = 0;
    virtual int dumpTensorParams(FILE* out) const;
    virtual int dumpTensorParams(std::ostream& out) const;

    virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
    virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;

    virtual int readFromNpyFile(const char* filename);
    virtual int writeToNpyFile(const char* filename) const;
    virtual int copyValueFrom(Tensor* tensor) = 0;

    virtual int readfromVector(const ArrayProxy<float> vals);
    virtual int readfromVector(const ArrayProxy<half_float::half> vals);
    virtual int readfromVector(const ArrayProxy<int32_t> vals);
    virtual int readfromVector(const ArrayProxy<int64_t> vals);
    virtual int readfromVector(const ArrayProxy<unsigned char> vals);

    virtual int writeToVector(ArrayProxy<float> vals);
    virtual int writeToVector(ArrayProxy<half_float::half> vals);
    virtual int writeToVector(ArrayProxy<int32_t> vals);
    virtual int writeToVector(ArrayProxy<int64_t> vals);
    virtual int writeToVector(ArrayProxy<unsigned char> vals);

    const char* bool_to_str(bool in) const
    {
        static const char* true_str = "true";
        static const char* false_str = "false";
        return in ? true_str : false_str;
    }

    virtual int allocate() = 0;
    virtual int deallocate() = 0;
    virtual bool is_allocated() = 0;

protected:
    std::string tensorName;
    DType tensorDtype;
    int isValid;
    std::vector<int> shape;
    int isSubgraphInput;
    int isSubgraphOutput;
    bool isAllocated;

    bool isParentGraphOutput;

    GraphNode* producer;
    std::vector<GraphNode*> consumers;

    // Note: the Eigen::Tensor is not declared here in Tensor.
    // Instead, the TensorTemplate subclass holds the templated tensor
    // declaration, so that the graph manipulation tools are isolated
    // from the templated tensor type.
    //
    // Operators need to be aware of the concrete TensorTemplate<EigenTensor<type, rank>>
    // type so that they can operate on the right element type and rank
    // (a usage sketch follows the class definition below).
};

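// Illustrative sketch, not part of this header: an operator that knows its element
// type and rank at compile time can downcast the generic Tensor* to the matching
// TensorTemplate and work directly on the underlying Eigen tensor. The variable
// names below are hypothetical.
//
//     using TIn = Eigen::Tensor<float, 2>;
//     auto* in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputTensor);
//     if (in && in->is_allocated())
//     {
//         TIn& data = in->getTensor();    // elements accessible as data(i, j)
//     }
//
// This only sketches the mechanism; the rank/type aliases actually used by the
// operators (e.g. the Tensor0..Tensor6 aliases seen later in this file) come from
// ops/template_types.h.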
template <class T>
class TensorTemplate : public Tensor
{
public:
    TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
        : Tensor(tensorName_, tensorDtype_, shape_)
    {
        tensor = nullptr;
    }

    virtual ~TensorTemplate()
    {
        deallocate();
    }

    virtual int allocate()
    {
        tensor = new T();
        if (tensor)
            return 0;
        else
            return 1;
    }

    virtual int deallocate()
    {
        if (tensor)
        {
            delete tensor;
        }
        tensor = nullptr;
        return 0;
    }

    virtual bool is_allocated()
    {
        if (tensor)
        {
            return true;
        }
        return false;
    }

    T& getTensor()
    {
        return *tensor;
    }

    virtual int dumpTensor(FILE* out) const;

    virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
    virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;

    virtual int copyValueFrom(Tensor* tensor);

protected:
    T* tensor;
};

// allocate() template specializations for the different tensor ranks and element types.
// They are declared here, ahead of their use by the factory, and defined in the .cc file.
template <>
int Tensor0<float>::allocate();
template <>
int Tensor1<float>::allocate();
template <>
int Tensor2<float>::allocate();
template <>
int Tensor3<float>::allocate();
template <>
int Tensor4<float>::allocate();
template <>
int Tensor5<float>::allocate();
template <>
int Tensor6<float>::allocate();

template <>
int Tensor0<int32_t>::allocate();
template <>
int Tensor1<int32_t>::allocate();
template <>
int Tensor2<int32_t>::allocate();
template <>
int Tensor3<int32_t>::allocate();
template <>
int Tensor4<int32_t>::allocate();
template <>
int Tensor5<int32_t>::allocate();
template <>
int Tensor6<int32_t>::allocate();

template <>
int Tensor0<int64_t>::allocate();
template <>
int Tensor1<int64_t>::allocate();
template <>
int Tensor2<int64_t>::allocate();
template <>
int Tensor3<int64_t>::allocate();
template <>
int Tensor4<int64_t>::allocate();
template <>
int Tensor5<int64_t>::allocate();
template <>
int Tensor6<int64_t>::allocate();

template <>
int Tensor0<bool>::allocate();
template <>
int Tensor1<bool>::allocate();
template <>
int Tensor2<bool>::allocate();
template <>
int Tensor3<bool>::allocate();
template <>
int Tensor4<bool>::allocate();
template <>
int Tensor5<bool>::allocate();
template <>
int Tensor6<bool>::allocate();

template <>
int Tensor0<float>::copyValueFrom(Tensor* src);
template <>
int Tensor1<float>::copyValueFrom(Tensor* src);
template <>
int Tensor2<float>::copyValueFrom(Tensor* src);
template <>
int Tensor3<float>::copyValueFrom(Tensor* src);
template <>
int Tensor4<float>::copyValueFrom(Tensor* src);
template <>
int Tensor5<float>::copyValueFrom(Tensor* src);
template <>
int Tensor6<float>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int32_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int64_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor1<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor2<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor3<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor4<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor5<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor6<bool>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);

template <>
int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;

template <>
int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);

template <>
int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;

template <>
int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);

template <>
int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;

template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);

template <>
int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;

// dumpTensor() template specializations for each supported element type and rank
template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
int Tensor1<float>::dumpTensor(FILE* out) const;
template <>
int Tensor2<float>::dumpTensor(FILE* out) const;
template <>
int Tensor3<float>::dumpTensor(FILE* out) const;
template <>
int Tensor4<float>::dumpTensor(FILE* out) const;
template <>
int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor1<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor2<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor3<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor4<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor5<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor6<bool>::dumpTensor(FILE* out) const;

class TensorFactory
{
public:
    static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
    {
        switch (tensorDtype_)
        {
            case DType_FP32:
            case DType_FP16:
            case DType_BF16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
                }
                break;
            case DType_INT32:
            case DType_UINT8:
            case DType_INT4:
            case DType_INT8:
            case DType_INT16:
            case DType_UINT16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
                }
                break;
            case DType_INT48:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
                }
                break;
            case DType_BOOL:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
                }
                break;
            default:
                break;
        }
        return nullptr;
    }

    static Tensor* newTensor(DType type, const std::vector<int> shape);
};
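
// Illustrative usage sketch (hypothetical tensor name and buffer, not part of this
// header): creating and filling a rank-2 FP32 tensor through the factory.
//
//     std::vector<int> shape = { 2, 3 };
//     Tensor* t = TensorFactory::newTensor("example", DType_FP32, shape, shape.size());
//     if (t && t->allocate() == 0)
//     {
//         std::vector<float> data(t->getElementCount(), 0.0f);
//         t->setTensorValueFloat(data.size(), data.data());
//     }
//
// newTensor() returns nullptr for unsupported dtype/rank combinations and
// allocate() returns 0 on success, so callers should check both before use.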
};    // namespace TosaReference

#endif