// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TOSA_REFERENCE_TENSOR_H
#define TOSA_REFERENCE_TENSOR_H

#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_generated.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <list>
#include <vector>

using namespace tosa;

namespace TosaReference
{
class GraphNode;

class Tensor
{
public:
    Tensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_);

    virtual ~Tensor();

    int setIsSubgraphInput();
    int setIsSubgraphOutput();

    int getIsSubgraphInput() const
    {
        return isSubgraphInput;
    }

    int getIsSubgraphOutput() const
    {
        return isSubgraphOutput;
    }

    int setProducer(GraphNode* node);
    int addConsumer(GraphNode* node);

    int setIsValid()
    {
        isValid = 1;
        return 0;
    }

    int clearIsValid()
    {
        isValid = 0;
        return 0;
    }

    int getIsValid() const
    {
        return isValid;
    }

    GraphNode* getProducer()
    {
        return producer;
    }

    std::vector<GraphNode*>& getConsumers()
    {
        return consumers;
    }

    const std::string& getName() const
    {
        return tensorName;
    }

    const std::vector<int>& getShape() const
    {
        return shape;
    }

    std::string getShapeAsString() const
    {
        std::string shape_str("[");
        for (auto& dim : shape)
        {
            shape_str += (std::to_string(dim) + ", ");
        }
        shape_str.append("]");
        return shape_str;
    }

    const uint32_t getElementCount() const
    {
        uint32_t elements = 1;
        for (size_t i = 0; i < shape.size(); i++)
            elements *= shape[i];

        return elements;
    }

    // Comparison of rank and type with other tensors
    const int matchRank(const Tensor& ref) const
    {
        return (ref.shape.size() == shape.size()) ? 0 : 1;
    }

    const int matchType(const Tensor& ref) const
    {
        return (ref.tensorDtype == tensorDtype) ? 0 : 1;
    }

    const int matchRankType(const Tensor& ref) const
    {
        return (matchType(ref) || matchRank(ref));
    }

    const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRankType(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, at least one operand must have size 1
                    // if they don't both match
                    (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }
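    // Informal example for matchRankTypeShape() above: with broadcastOk == true,
    // shapes [1, 3] and [4, 3] (same rank and dtype) are compatible because the
    // mismatched dimension is 1 on one side, while [2, 3] and [4, 3] are not,
    // so the function returns 1 for the latter pair.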

    const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, at least one operand must have size 1
                    // if they don't both match
                    (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }

    // Sometimes we might want to match several semi-compatible types,
    // so just check rank and size here
    const int matchRankSize(const Tensor& ref) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
                return 1;
        }

        return 0;
    }

    // Unary check to make sure rank matches
    const int checkRequiredRank(const int exactRank) const
    {
        return (shape.size() == (size_t)exactRank) ? 0 : 1;
    }

    const int checkRequiredRank(const int minRank, const int maxRank) const
    {
        return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
    }

    const int getRank() const
    {
        return shape.size();
    }

    const DType getDtype() const
    {
        return tensorDtype;
    }

    virtual int dumpTensor(FILE* out) const = 0;
    virtual int dumpTensorParams(FILE* out) const;
    virtual int dumpTensorParams(std::ostream& out) const;

    virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
    virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;

    virtual int readFromNpyFile(const char* filename);
    virtual int writeToNpyFile(const char* filename) const;
    virtual int copyValueFrom(Tensor* tensor) = 0;

    const char* bool_to_str(bool in) const
    {
        static const char* true_str = "true";
        static const char* false_str = "false";
        return in ? true_str : false_str;
    }

    virtual int allocate() = 0;
    virtual int deallocate() = 0;
    virtual bool is_allocated() = 0;

protected:
    std::string tensorName;
    DType tensorDtype;
    int isValid;
    std::vector<int> shape;
    int isSubgraphInput;
    int isSubgraphOutput;
    bool isAllocated;

    GraphNode* producer;
    std::vector<GraphNode*> consumers;

    // Note: the Eigen::Tensor is not declared in Tensor
    // Instead, the TensorTemplate class keeps the templated tensor
    // declaration so that the graph manipulation tools are isolated
    // from the templated tensor type.
    //
    // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
    // so that they can operate on the right types.
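    //
    // Informal usage sketch (hypothetical operator code; "inputTensor" is
    // assumed to be a Tensor* holding rank-2 FP32 data):
    //
    //   if (auto* in2 = dynamic_cast<Tensor2<float>*>(inputTensor))
    //   {
    //       float v = in2->getTensor()(0, 0);    // typed Eigen element access
    //   }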
};

template <class T>
class TensorTemplate : public Tensor
{
public:
    TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
        : Tensor(tensorName_, tensorDtype_, shape_)
    {
        tensor = nullptr;
    }

    virtual ~TensorTemplate()
    {
        deallocate();
    }

    virtual int allocate()
    {
        tensor = new T();
        if (tensor)
            return 0;
        else
            return 1;
    }

    virtual int deallocate()
    {
        if (tensor)
        {
            delete tensor;
        }
        tensor = nullptr;
        return 0;
    }

    virtual bool is_allocated()
    {
        if (tensor)
        {
            return true;
        }
        return false;
    }

    T& getTensor()
    {
        return *tensor;
    }

    virtual int dumpTensor(FILE* out) const;

    virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
    virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;

    virtual int copyValueFrom(Tensor* tensor);

protected:
    T* tensor;
};

// allocate() template specializations to allocate the different tensor sizes
// Let the compiler know here before the factory uses them, but define them in the .cc file.
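// Note: Tensor0<T> .. Tensor6<T> below are the rank-0 .. rank-6 TensorTemplate
// instantiations over Eigen tensors (see ops/template_types.h).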
template <>
int Tensor0<float>::allocate();
template <>
int Tensor1<float>::allocate();
template <>
int Tensor2<float>::allocate();
template <>
int Tensor3<float>::allocate();
template <>
int Tensor4<float>::allocate();
template <>
int Tensor5<float>::allocate();
template <>
int Tensor6<float>::allocate();

template <>
int Tensor0<int32_t>::allocate();
template <>
int Tensor1<int32_t>::allocate();
template <>
int Tensor2<int32_t>::allocate();
template <>
int Tensor3<int32_t>::allocate();
template <>
int Tensor4<int32_t>::allocate();
template <>
int Tensor5<int32_t>::allocate();
template <>
int Tensor6<int32_t>::allocate();

template <>
int Tensor0<int64_t>::allocate();
template <>
int Tensor1<int64_t>::allocate();
template <>
int Tensor2<int64_t>::allocate();
template <>
int Tensor3<int64_t>::allocate();
template <>
int Tensor4<int64_t>::allocate();
template <>
int Tensor5<int64_t>::allocate();
template <>
int Tensor6<int64_t>::allocate();

template <>
int Tensor0<bool>::allocate();
template <>
int Tensor1<bool>::allocate();
template <>
int Tensor2<bool>::allocate();
template <>
int Tensor3<bool>::allocate();
template <>
int Tensor4<bool>::allocate();
template <>
int Tensor5<bool>::allocate();
template <>
int Tensor6<bool>::allocate();

template <>
int Tensor0<float>::copyValueFrom(Tensor* src);
template <>
int Tensor1<float>::copyValueFrom(Tensor* src);
template <>
int Tensor2<float>::copyValueFrom(Tensor* src);
template <>
int Tensor3<float>::copyValueFrom(Tensor* src);
template <>
int Tensor4<float>::copyValueFrom(Tensor* src);
template <>
int Tensor5<float>::copyValueFrom(Tensor* src);
template <>
int Tensor6<float>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int32_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int64_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor1<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor2<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor3<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor4<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor5<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor6<bool>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);

template <>
int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;

template <>
int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);

template <>
int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;

template <>
int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);

template <>
int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;

template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);

template <>
int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;

// dumpTensor() template specializations for the supported element types
template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
int Tensor1<float>::dumpTensor(FILE* out) const;
template <>
int Tensor2<float>::dumpTensor(FILE* out) const;
template <>
int Tensor3<float>::dumpTensor(FILE* out) const;
template <>
int Tensor4<float>::dumpTensor(FILE* out) const;
template <>
int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor1<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor2<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor3<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor4<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor5<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor6<bool>::dumpTensor(FILE* out) const;

class TensorFactory
{
public:
    static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
    {
        switch (tensorDtype_)
        {
            case DType_FLOAT:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
                }
                break;
            case DType_INT32:
            case DType_UINT8:
            case DType_INT4:
            case DType_INT8:
            case DType_INT16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
                }
                break;
            case DType_INT48:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
                }
                break;
            case DType_BOOL:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
                    case 1:
                        return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
                    case 2:
                        return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
                    case 3:
                        return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
                    case 4:
                        return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
                    case 5:
                        return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
                    case 6:
                        return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
                }
                break;
            default:
                break;
        }
        return nullptr;
    }
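    // Usage sketch for newTensor() above (hypothetical values): create,
    // allocate and fill a rank-1 FP32 tensor with four elements.
    //
    //   Tensor* t = TensorFactory::newTensor("example", DType_FLOAT, { 4 }, 1);
    //   if (t && (t->allocate() == 0))
    //   {
    //       const float vals[4] = { 0.0f, 1.0f, 2.0f, 3.0f };
    //       t->setTensorValueFloat(4, vals);
    //   }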

    static Tensor* newTensor(DType type, const std::vector<int> shape);
};
}; // namespace TosaReference

#endif