blob: d857dc81125e15094454d51346c610a0153e3670 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
19#include "model_common.h"
20#include "ops/template_types.h"
21#include "tosa_generated.h"
22#include "tosa_serialization_handler.h"
23#include <Eigen/CXX11/Tensor>
24#include <list>
25#include <vector>
26
27using namespace tosa;
28
29namespace TosaReference
30{
31class GraphNode;
32
33class Tensor
34{
35public:
Kevin Cheng989cb052021-04-28 16:29:44 -070036 Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070037
38 virtual ~Tensor();
39
40 int setIsSubgraphInput();
41 int setIsSubgraphOutput();
42
43 int getIsSubgraphInput() const
44 {
45 return isSubgraphInput;
46 }
47
48 int getIsSubgraphOutput() const
49 {
50 return isSubgraphOutput;
51 }
52
53 int setProducer(GraphNode* node);
54 int addConsumer(GraphNode* node);
55
56 int setIsValid()
57 {
58 isValid = 1;
59 return 0;
60 }
61
62 int clearIsValid()
63 {
64 isValid = 0;
65 return 0;
66 }
67
68 int getIsValid() const
69 {
70 return isValid;
71 }
72
Eric Kunzee5e26762020-10-13 16:11:07 -070073 GraphNode* getProducer()
74 {
75 return producer;
76 }
77
78 std::vector<GraphNode*>& getConsumers()
79 {
80 return consumers;
81 }
82
83 const std::string& getName() const
84 {
85 return tensorName;
86 }
87
88 const std::vector<int>& getShape() const
89 {
90 return shape;
91 }
92
93 std::string getShapeAsString() const
94 {
95 std::string shape_str("[");
96 for (auto& dim : shape)
97 {
98 shape_str += (std::to_string(dim) + ", ");
99 }
100 shape_str.append("]");
101 return shape_str;
102 }
103
Eric Kunzee5e26762020-10-13 16:11:07 -0700104 const uint32_t getElementCount() const
105 {
106 uint32_t elements = 1;
107 for (size_t i = 0; i < shape.size(); i++)
108 elements *= shape[i];
109
110 return elements;
111 }
112
113 // Comparison of rank and type with other tensors
114 const int matchRank(const Tensor& ref) const
115 {
116 return (ref.shape.size() == shape.size()) ? 0 : 1;
117 }
118
119 const int matchType(const Tensor& ref) const
120 {
121 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
122 }
123
124 const int matchRankType(const Tensor& ref) const
125 {
126 return (matchType(ref) || matchRank(ref));
127 }
128
129 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
130 {
131 if (matchRankType(ref))
132 return 1;
133
134 for (size_t i = 0; i < shape.size(); i++)
135 {
136 if (shape[i] != ref.shape[i])
137 {
138 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000139 // For broadcasts, the order of *this and ref matters.
140 // *this should be the source tensor.
141 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
142 // this->shape must have size 1 if they don't match
143 (broadcastOk && (shape[i] != 1)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 {
145 return 1;
146 }
147 }
148 }
149
150 return 0;
151 }
152
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800153 const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
154 {
155 if (matchRank(ref))
156 return 1;
157
158 for (size_t i = 0; i < shape.size(); i++)
159 {
160 if (shape[i] != ref.shape[i])
161 {
162 if (!broadcastOk ||
Kevin Cheng2131a4d2021-11-11 19:35:30 +0000163 // For broadcasts, the order of *this and ref matters.
164 // *this should be the source tensor.
165 // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
166 // this->shape must have size 1 if they don't match
167 (broadcastOk && (shape[i] != 1)))
Kevin Cheng1c3c8472021-11-08 11:19:10 -0800168 {
169 return 1;
170 }
171 }
172 }
173
174 return 0;
175 }
176
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 // Sometimes we might want to match several semi-compatible types,
178 // so just check rank and size here
179 const int matchRankSize(const Tensor& ref) const
180 {
181 if (matchRank(ref))
182 return 1;
183
184 for (size_t i = 0; i < shape.size(); i++)
185 {
186 if (shape[i] != ref.shape[i])
187 return 1;
188 }
189
190 return 0;
191 }
192
193 // Unary check to make sure rank matches
194 const int checkRequiredRank(const int exactRank) const
195 {
196 return (shape.size() == (size_t)exactRank) ? 0 : 1;
197 }
198
199 const int checkRequiredRank(const int minRank, const int maxRank) const
200 {
201 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
202 }
203
204 const int getRank() const
205 {
206 return shape.size();
207 }
208
209 const DType getDtype() const
210 {
211 return tensorDtype;
212 }
213
214 virtual int dumpTensor(FILE* out) const = 0;
215 virtual int dumpTensorParams(FILE* out) const;
216 virtual int dumpTensorParams(std::ostream& out) const;
217
218 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
219 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
220 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
221 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
222 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
223 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
224 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
225 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
226
227 virtual int readFromNpyFile(const char* filename);
228 virtual int writeToNpyFile(const char* filename) const;
229 virtual int copyValueFrom(Tensor* tensor) = 0;
230
231 const char* bool_to_str(bool in) const
232 {
233 static const char* true_str = "true";
234 static const char* false_str = "false";
235 return in ? true_str : false_str;
236 }
237
238 virtual int allocate() = 0;
239 virtual int deallocate() = 0;
240 virtual bool is_allocated() = 0;
241
242protected:
243 std::string tensorName;
244 DType tensorDtype;
Eric Kunzee5e26762020-10-13 16:11:07 -0700245 int isValid;
246 std::vector<int> shape;
247 int isSubgraphInput;
248 int isSubgraphOutput;
249 bool isAllocated;
250
251 GraphNode* producer;
252 std::vector<GraphNode*> consumers;
253
254 // Note: the Eigen::Tensor is not declared in Tensor
255 // Instead, the TensorTemplate class keeps the templated tensor
256 // declaration so that the graph manipulation tools are isolated
257 // from the templated tensor type.
258 //
259 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
260 // so that they can operate on the right types.
261};
262
263template <class T>
264class TensorTemplate : public Tensor
265{
266public:
Kevin Cheng989cb052021-04-28 16:29:44 -0700267 TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800268 : Tensor(tensorName_, tensorDtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700269 {
270 tensor = nullptr;
271 }
272
273 virtual ~TensorTemplate()
274 {
275 deallocate();
276 }
277
278 virtual int allocate()
279 {
280 tensor = new T();
281 if (tensor)
282 return 0;
283 else
284 return 1;
285 }
286
287 virtual int deallocate()
288 {
289 if (tensor)
290 {
291 delete tensor;
292 }
293 tensor = nullptr;
294 return 0;
295 }
296
297 virtual bool is_allocated()
298 {
299 if (tensor)
300 {
301 return true;
302 }
303 return false;
304 }
305
306 T& getTensor()
307 {
308 return *tensor;
309 }
310
311 virtual int dumpTensor(FILE* out) const;
312
313 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
314 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
315 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
316 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
317 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
318 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
319 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
320 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
321
322 virtual int copyValueFrom(Tensor* tensor);
323
324protected:
325 T* tensor;
326};
327
328// allocate() template specializations to allocate the different tensor sizes
329// Let the compiler know here before the factory uses them, but define them in the .cc file.
330template <>
331int Tensor0<float>::allocate();
332template <>
333int Tensor1<float>::allocate();
334template <>
335int Tensor2<float>::allocate();
336template <>
337int Tensor3<float>::allocate();
338template <>
339int Tensor4<float>::allocate();
340template <>
341int Tensor5<float>::allocate();
342template <>
343int Tensor6<float>::allocate();
344
345template <>
346int Tensor0<int32_t>::allocate();
347template <>
348int Tensor1<int32_t>::allocate();
349template <>
350int Tensor2<int32_t>::allocate();
351template <>
352int Tensor3<int32_t>::allocate();
353template <>
354int Tensor4<int32_t>::allocate();
355template <>
356int Tensor5<int32_t>::allocate();
357template <>
358int Tensor6<int32_t>::allocate();
359
360template <>
361int Tensor0<int64_t>::allocate();
362template <>
363int Tensor1<int64_t>::allocate();
364template <>
365int Tensor2<int64_t>::allocate();
366template <>
367int Tensor3<int64_t>::allocate();
368template <>
369int Tensor4<int64_t>::allocate();
370template <>
371int Tensor5<int64_t>::allocate();
372template <>
373int Tensor6<int64_t>::allocate();
374
375template <>
376int Tensor0<bool>::allocate();
377template <>
378int Tensor1<bool>::allocate();
379template <>
380int Tensor2<bool>::allocate();
381template <>
382int Tensor3<bool>::allocate();
383template <>
384int Tensor4<bool>::allocate();
385template <>
386int Tensor5<bool>::allocate();
387template <>
388int Tensor6<bool>::allocate();
389
390template <>
391int Tensor0<float>::copyValueFrom(Tensor* src);
392template <>
393int Tensor1<float>::copyValueFrom(Tensor* src);
394template <>
395int Tensor2<float>::copyValueFrom(Tensor* src);
396template <>
397int Tensor3<float>::copyValueFrom(Tensor* src);
398template <>
399int Tensor4<float>::copyValueFrom(Tensor* src);
400template <>
401int Tensor5<float>::copyValueFrom(Tensor* src);
402template <>
403int Tensor6<float>::copyValueFrom(Tensor* src);
404
405template <>
406int Tensor0<int32_t>::copyValueFrom(Tensor* src);
407template <>
408int Tensor1<int32_t>::copyValueFrom(Tensor* src);
409template <>
410int Tensor2<int32_t>::copyValueFrom(Tensor* src);
411template <>
412int Tensor3<int32_t>::copyValueFrom(Tensor* src);
413template <>
414int Tensor4<int32_t>::copyValueFrom(Tensor* src);
415template <>
416int Tensor5<int32_t>::copyValueFrom(Tensor* src);
417template <>
418int Tensor6<int32_t>::copyValueFrom(Tensor* src);
419
420template <>
421int Tensor0<int64_t>::copyValueFrom(Tensor* src);
422template <>
423int Tensor1<int64_t>::copyValueFrom(Tensor* src);
424template <>
425int Tensor2<int64_t>::copyValueFrom(Tensor* src);
426template <>
427int Tensor3<int64_t>::copyValueFrom(Tensor* src);
428template <>
429int Tensor4<int64_t>::copyValueFrom(Tensor* src);
430template <>
431int Tensor5<int64_t>::copyValueFrom(Tensor* src);
432template <>
433int Tensor6<int64_t>::copyValueFrom(Tensor* src);
434
435template <>
436int Tensor0<bool>::copyValueFrom(Tensor* src);
437template <>
438int Tensor1<bool>::copyValueFrom(Tensor* src);
439template <>
440int Tensor2<bool>::copyValueFrom(Tensor* src);
441template <>
442int Tensor3<bool>::copyValueFrom(Tensor* src);
443template <>
444int Tensor4<bool>::copyValueFrom(Tensor* src);
445template <>
446int Tensor5<bool>::copyValueFrom(Tensor* src);
447template <>
448int Tensor6<bool>::copyValueFrom(Tensor* src);
449
450template <>
451int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
452template <>
453int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
454template <>
455int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
456template <>
457int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
458template <>
459int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
460template <>
461int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
462template <>
463int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
464
465template <>
466int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
467template <>
468int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
469template <>
470int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
471template <>
472int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
473template <>
474int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
475template <>
476int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
477template <>
478int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
479
480template <>
481int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
482template <>
483int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
484template <>
485int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
486template <>
487int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
488template <>
489int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
490template <>
491int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
492template <>
493int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
494
495template <>
496int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
497template <>
498int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
499template <>
500int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
501template <>
502int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
503template <>
504int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
505template <>
506int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
507template <>
508int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
509
510template <>
511int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
512template <>
513int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
514template <>
515int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
516template <>
517int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
518template <>
519int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
520template <>
521int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
522template <>
523int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
524
525template <>
526int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
527template <>
528int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
529template <>
530int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
531template <>
532int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
533template <>
534int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
535template <>
536int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
537template <>
538int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
539
540template <>
541int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
542template <>
543int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
544template <>
545int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
546template <>
547int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
548template <>
549int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
550template <>
551int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
552template <>
553int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
554
555template <>
556int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
557template <>
558int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
559template <>
560int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
561template <>
562int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
563template <>
564int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
565template <>
566int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
567template <>
568int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
569
570// assume we only dump float type tensor now
571template <>
572int Tensor0<float>::dumpTensor(FILE* out) const;
573template <>
574int Tensor1<float>::dumpTensor(FILE* out) const;
575template <>
576int Tensor2<float>::dumpTensor(FILE* out) const;
577template <>
578int Tensor3<float>::dumpTensor(FILE* out) const;
579template <>
580int Tensor4<float>::dumpTensor(FILE* out) const;
581template <>
582int Tensor5<float>::dumpTensor(FILE* out) const;
583template <>
584int Tensor6<float>::dumpTensor(FILE* out) const;
585template <>
586int Tensor0<int32_t>::dumpTensor(FILE* out) const;
587template <>
588int Tensor1<int32_t>::dumpTensor(FILE* out) const;
589template <>
590int Tensor2<int32_t>::dumpTensor(FILE* out) const;
591template <>
592int Tensor3<int32_t>::dumpTensor(FILE* out) const;
593template <>
594int Tensor4<int32_t>::dumpTensor(FILE* out) const;
595template <>
596int Tensor5<int32_t>::dumpTensor(FILE* out) const;
597template <>
598int Tensor6<int32_t>::dumpTensor(FILE* out) const;
599template <>
600int Tensor0<int64_t>::dumpTensor(FILE* out) const;
601template <>
602int Tensor1<int64_t>::dumpTensor(FILE* out) const;
603template <>
604int Tensor2<int64_t>::dumpTensor(FILE* out) const;
605template <>
606int Tensor3<int64_t>::dumpTensor(FILE* out) const;
607template <>
608int Tensor4<int64_t>::dumpTensor(FILE* out) const;
609template <>
610int Tensor5<int64_t>::dumpTensor(FILE* out) const;
611template <>
612int Tensor6<int64_t>::dumpTensor(FILE* out) const;
613template <>
614int Tensor0<bool>::dumpTensor(FILE* out) const;
615template <>
616int Tensor1<bool>::dumpTensor(FILE* out) const;
617template <>
618int Tensor2<bool>::dumpTensor(FILE* out) const;
619template <>
620int Tensor3<bool>::dumpTensor(FILE* out) const;
621template <>
622int Tensor4<bool>::dumpTensor(FILE* out) const;
623template <>
624int Tensor5<bool>::dumpTensor(FILE* out) const;
625template <>
626int Tensor6<bool>::dumpTensor(FILE* out) const;
627
628class TensorFactory
629{
630public:
Kevin Cheng989cb052021-04-28 16:29:44 -0700631 static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700632 {
633 switch (tensorDtype_)
634 {
635 case DType_FLOAT:
636 switch (rank)
637 {
638 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800639 return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700640 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800641 return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700642 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800643 return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700644 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800645 return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700646 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800647 return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700648 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800649 return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700650 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800651 return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700652 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700653 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700654 case DType_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700655 case DType_UINT8:
656 case DType_INT4:
657 case DType_INT8:
658 case DType_INT16:
659 switch (rank)
660 {
661 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800662 return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700663 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800664 return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700665 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800666 return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700667 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800668 return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700669 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800670 return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700671 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800672 return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700673 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800674 return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700675 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700676 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700677 case DType_INT48:
678 switch (rank)
679 {
680 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800681 return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700682 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800683 return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800685 return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700686 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800687 return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700688 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800689 return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700690 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800691 return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700692 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800693 return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700694 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700695 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700696 case DType_BOOL:
697 switch (rank)
698 {
699 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800700 return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700701 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800702 return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700703 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800704 return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700705 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800706 return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700707 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800708 return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700709 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800710 return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700711 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800712 return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700713 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700714 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700715 default:
Kevin Cheng989cb052021-04-28 16:29:44 -0700716 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700717 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700718 return nullptr;
Eric Kunzee5e26762020-10-13 16:11:07 -0700719 }
720
721 static Tensor* newTensor(DType type, const std::vector<int> shape);
722};
723}; // namespace TosaReference
724
725#endif