// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TOSA_REFERENCE_TENSOR_H
#define TOSA_REFERENCE_TENSOR_H

#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_generated.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <list>
#include <vector>

using namespace tosa;

namespace TosaReference
{
class GraphNode;

class Tensor
{
public:
    Tensor(std::string tensorName_,
           DType tensorDtype_,
           const std::vector<Usage>& tensorUsage_,
           const std::vector<Format>& tensorFormat_,
           std::vector<int> shape_,
           int isConst_);

    virtual ~Tensor();

    int setIsSubgraphInput();
    int setIsSubgraphOutput();

    int getIsSubgraphInput() const
    {
        return isSubgraphInput;
    }

    int getIsSubgraphOutput() const
    {
        return isSubgraphOutput;
    }

    int setProducer(GraphNode* node);
    int addConsumer(GraphNode* node);

    int setIsValid()
    {
        isValid = 1;
        return 0;
    }

    int clearIsValid()
    {
        isValid = 0;
        return 0;
    }

    int getIsValid() const
    {
        return isValid;
    }

    int getIsConst() const
    {
        return isConst;
    }

    GraphNode* getProducer()
    {
        return producer;
    }

    std::vector<GraphNode*>& getConsumers()
    {
        return consumers;
    }

    const std::string& getName() const
    {
        return tensorName;
    }

    const std::vector<int>& getShape() const
    {
        return shape;
    }

    std::string getShapeAsString() const
    {
        std::string shape_str("[");
        for (auto& dim : shape)
        {
            shape_str += (std::to_string(dim) + ", ");
        }
        shape_str.append("]");
        return shape_str;
    }

    const std::vector<Usage>& getUsage() const
    {
        return tensorUsage;
    }

    bool hasUsage(Usage usage) const
    {
        for (auto& usg : tensorUsage)
        {
            if (usg == usage)
            {
                return true;
            }
        }
        return false;
    }

    std::string getUsageAsString() const
    {
        std::string usage_str("[");
        for (auto& usg : tensorUsage)
        {
            usage_str += (std::string(EnumNamesUsage()[usg]) + ", ");
        }
        usage_str.append("]");
        return usage_str;
    }

    const std::vector<Format>& getFormat() const
    {
        return tensorFormat;
    }

    bool hasFormat(Format format) const
    {
        for (auto& fmt : tensorFormat)
        {
            if (fmt == format)
            {
                return true;
            }
        }
        return false;
    }

    std::string getFormatAsString() const
    {
        std::string format_str("[");
        for (auto& fmt : tensorFormat)
        {
            format_str += (std::string(EnumNamesFormat()[fmt]) + ", ");
        }
        format_str.append("]");
        return format_str;
    }

    const uint32_t getElementCount() const
    {
        uint32_t elements = 1;
        for (size_t i = 0; i < shape.size(); i++)
            elements *= shape[i];

        return elements;
    }

    // Comparison of rank and type with other tensors
    const int matchRank(const Tensor& ref) const
    {
        return (ref.shape.size() == shape.size()) ? 0 : 1;
    }

    const int matchType(const Tensor& ref) const
    {
        return (ref.tensorDtype == tensorDtype) ? 0 : 1;
    }

    const int matchRankType(const Tensor& ref) const
    {
        return (matchType(ref) || matchRank(ref));
    }

    const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRankType(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, at least one operand must have size 1
                    // if they don't both match
                    (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }
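
    // Example for matchRankTypeShape() (values are illustrative only): with
    // broadcastOk == true, a {1, 5} tensor matches a {4, 5} reference because
    // every mismatched dimension has size 1 on one side, while {2, 5} against
    // {4, 5} still returns 1.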

    // Sometimes we might want to match several semi-compatible types,
    // so just check rank and size here
    const int matchRankSize(const Tensor& ref) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
                return 1;
        }

        return 0;
    }
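
    // Example for matchRankSize() (illustrative): a DType_INT32 tensor can be
    // checked against a DType_INT8 reference of the same shape, since the
    // element type is deliberately ignored here.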

    // Unary check to make sure rank matches
    const int checkRequiredRank(const int exactRank) const
    {
        return (shape.size() == (size_t)exactRank) ? 0 : 1;
    }

    const int checkRequiredRank(const int minRank, const int maxRank) const
    {
        return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
    }

    const int getRank() const
    {
        return shape.size();
    }

    const DType getDtype() const
    {
        return tensorDtype;
    }

    virtual int dumpTensor(FILE* out) const = 0;
    virtual int dumpTensorParams(FILE* out) const;
    virtual int dumpTensorParams(std::ostream& out) const;

    virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
    virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;

    virtual int readFromNpyFile(const char* filename);
    virtual int writeToNpyFile(const char* filename) const;
    virtual int copyValueFrom(Tensor* tensor) = 0;

    const char* bool_to_str(bool in) const
    {
        static const char* true_str  = "true";
        static const char* false_str = "false";
        return in ? true_str : false_str;
    }

    virtual int allocate()      = 0;
    virtual int deallocate()    = 0;
    virtual bool is_allocated() = 0;

protected:
    std::string tensorName;
    DType tensorDtype;
    std::vector<Usage> tensorUsage;
    std::vector<Format> tensorFormat;
    int isConst;
    int isValid;
    std::vector<int> shape;
    int isSubgraphInput;
    int isSubgraphOutput;
    bool isAllocated;

    GraphNode* producer;
    std::vector<GraphNode*> consumers;

    // Note: the Eigen::Tensor is not declared in Tensor
    // Instead, the TensorTemplate class keeps the templated tensor
    // declaration so that the graph manipulation tools are isolated
    // from the templated tensor type.
    //
    // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
    // so that they can operate on the right types.
};

template <class T>
class TensorTemplate : public Tensor
{
public:
    TensorTemplate(std::string tensorName_,
                   DType tensorDtype_,
                   const std::vector<Usage>& tensorUsage_,
                   const std::vector<Format>& tensorFormat_,
                   std::vector<int> shape_,
                   int isConst_)
        : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_)
    {
        tensor = nullptr;
    }

    virtual ~TensorTemplate()
    {
        deallocate();
    }

    virtual int allocate()
    {
        tensor = new T();
        if (tensor)
            return 0;
        else
            return 1;
    }

    virtual int deallocate()
    {
        if (tensor)
        {
            delete tensor;
        }
        tensor = nullptr;
        return 0;
    }

    virtual bool is_allocated()
    {
        if (tensor)
        {
            return true;
        }
        return false;
    }

    T& getTensor()
    {
        return *tensor;
    }

    virtual int dumpTensor(FILE* out) const;

    virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
    virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;

    virtual int copyValueFrom(Tensor* tensor);

protected:
    T* tensor;
};
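
// Graph-level code only handles Tensor*, while operators downcast to the concrete
// TensorTemplate instantiation to reach the Eigen tensor, as the note at the end of
// class Tensor describes. A minimal, illustrative sketch of that pattern (the
// variable names are hypothetical, and it assumes Tensor4<float> from
// ops/template_types.h wraps a rank-4 Eigen::Tensor<float>):
//
//   Tensor* base = ...;   // e.g. as returned by TensorFactory::newTensor()
//   auto* in = dynamic_cast<Tensor4<float>*>(base);
//   if (in && in->is_allocated())
//   {
//       auto& data = in->getTensor();   // the underlying Eigen tensor
//       float v    = data(0, 0, 0, 0);
//   }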

// allocate() template specializations to allocate the different tensor sizes
// Let the compiler know here before the factory uses them, but define them in the .cc file.
// (See the sketch after these declarations for what one such definition might look like.)
template <>
int Tensor0<float>::allocate();
template <>
int Tensor1<float>::allocate();
template <>
int Tensor2<float>::allocate();
template <>
int Tensor3<float>::allocate();
template <>
int Tensor4<float>::allocate();
template <>
int Tensor5<float>::allocate();
template <>
int Tensor6<float>::allocate();

template <>
int Tensor0<int32_t>::allocate();
template <>
int Tensor1<int32_t>::allocate();
template <>
int Tensor2<int32_t>::allocate();
template <>
int Tensor3<int32_t>::allocate();
template <>
int Tensor4<int32_t>::allocate();
template <>
int Tensor5<int32_t>::allocate();
template <>
int Tensor6<int32_t>::allocate();

template <>
int Tensor0<int64_t>::allocate();
template <>
int Tensor1<int64_t>::allocate();
template <>
int Tensor2<int64_t>::allocate();
template <>
int Tensor3<int64_t>::allocate();
template <>
int Tensor4<int64_t>::allocate();
template <>
int Tensor5<int64_t>::allocate();
template <>
int Tensor6<int64_t>::allocate();

template <>
int Tensor0<bool>::allocate();
template <>
int Tensor1<bool>::allocate();
template <>
int Tensor2<bool>::allocate();
template <>
int Tensor3<bool>::allocate();
template <>
int Tensor4<bool>::allocate();
template <>
int Tensor5<bool>::allocate();
template <>
int Tensor6<bool>::allocate();
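
// A shape-aware allocate() definition in the .cc file might look roughly like the
// sketch below (illustrative only, assuming Tensor2<float> wraps Eigen::Tensor<float, 2>;
// the real definitions are not part of this header):
//
//   template <>
//   int Tensor2<float>::allocate()
//   {
//       // size the Eigen tensor from the shape recorded at construction time
//       tensor = new Eigen::Tensor<float, 2>(shape[0], shape[1]);
//       return tensor ? 0 : 1;
//   }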

template <>
int Tensor0<float>::copyValueFrom(Tensor* src);
template <>
int Tensor1<float>::copyValueFrom(Tensor* src);
template <>
int Tensor2<float>::copyValueFrom(Tensor* src);
template <>
int Tensor3<float>::copyValueFrom(Tensor* src);
template <>
int Tensor4<float>::copyValueFrom(Tensor* src);
template <>
int Tensor5<float>::copyValueFrom(Tensor* src);
template <>
int Tensor6<float>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int32_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int64_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor1<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor2<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor3<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor4<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor5<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor6<bool>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);

template <>
int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;

template <>
int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);

template <>
int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;

template <>
int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);

template <>
int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;

template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);

template <>
int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;

// dumpTensor() specializations for each supported element type and rank
template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
int Tensor1<float>::dumpTensor(FILE* out) const;
template <>
int Tensor2<float>::dumpTensor(FILE* out) const;
template <>
int Tensor3<float>::dumpTensor(FILE* out) const;
template <>
int Tensor4<float>::dumpTensor(FILE* out) const;
template <>
int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor1<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor2<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor3<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor4<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor5<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor6<bool>::dumpTensor(FILE* out) const;

class TensorFactory
{
public:
    static Tensor* newTensor(std::string tensorName_,
                             DType tensorDtype_,
                             const std::vector<Usage>& tensorUsage_,
                             const std::vector<Format>& tensorFormat_,
                             std::vector<int> shape_,
                             int isConst_,
                             const uint32_t rank)
    {
        switch (tensorDtype_)
        {
            case DType_FLOAT:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    case 1:
                        return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    case 2:
                        return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    case 3:
                        return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    case 4:
                        return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    case 5:
                        return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    case 6:
                        return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                  isConst_);
                    default:
                        goto done;
                }
            case DType_INT32:
            case DType_AINT8:
            case DType_UINT8:
            case DType_INT4:
            case DType_INT8:
            case DType_INT16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 1:
                        return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 2:
                        return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 3:
                        return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 4:
                        return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 5:
                        return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 6:
                        return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    default:
                        goto done;
                }
            case DType_INT48:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 1:
                        return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 2:
                        return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 3:
                        return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 4:
                        return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 5:
                        return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    case 6:
                        return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                    isConst_);
                    default:
                        goto done;
                }
            case DType_BOOL:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    case 1:
                        return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    case 2:
                        return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    case 3:
                        return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    case 4:
                        return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    case 5:
                        return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    case 6:
                        return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
                                                 isConst_);
                    default:
                        goto done;
                }
            default:
                goto done;
        }

    done:
        FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_],
                    rank);
    }

    static Tensor* newTensor(DType type, const std::vector<int> shape);
};
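
// Example of creating and allocating a tensor through the factory (illustrative
// only; the Usage/Format enum values are assumed to come from tosa_generated.h,
// and the name and shape are made up):
//
//   std::vector<int> shape = { 1, 16, 16, 8 };
//   Tensor* t = TensorFactory::newTensor("act0", DType_FLOAT, { Usage_ACTIVATION }, { Format_NHWC },
//                                        shape, /* isConst_ */ 0, shape.size());
//   if (t && !t->allocate())
//   {
//       t->readFromNpyFile("act0.npy");
//   }
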
};    // namespace TosaReference

#endif