blob: e97554fbaae4e7ce91038f4004ad81b7b040577e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef TOSA_REFERENCE_TENSOR_H
17#define TOSA_REFERENCE_TENSOR_H
18
19#include "model_common.h"
20#include "ops/template_types.h"
21#include "tosa_generated.h"
22#include "tosa_serialization_handler.h"
23#include <Eigen/CXX11/Tensor>
24#include <list>
25#include <vector>
26
27using namespace tosa;
28
29namespace TosaReference
30{
31class GraphNode;
32
33class Tensor
34{
35public:
Kevin Cheng989cb052021-04-28 16:29:44 -070036 Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -070037
38 virtual ~Tensor();
39
40 int setIsSubgraphInput();
41 int setIsSubgraphOutput();
42
43 int getIsSubgraphInput() const
44 {
45 return isSubgraphInput;
46 }
47
48 int getIsSubgraphOutput() const
49 {
50 return isSubgraphOutput;
51 }
52
53 int setProducer(GraphNode* node);
54 int addConsumer(GraphNode* node);
55
56 int setIsValid()
57 {
58 isValid = 1;
59 return 0;
60 }
61
62 int clearIsValid()
63 {
64 isValid = 0;
65 return 0;
66 }
67
68 int getIsValid() const
69 {
70 return isValid;
71 }
72
Eric Kunzee5e26762020-10-13 16:11:07 -070073 GraphNode* getProducer()
74 {
75 return producer;
76 }
77
78 std::vector<GraphNode*>& getConsumers()
79 {
80 return consumers;
81 }
82
83 const std::string& getName() const
84 {
85 return tensorName;
86 }
87
88 const std::vector<int>& getShape() const
89 {
90 return shape;
91 }
92
93 std::string getShapeAsString() const
94 {
95 std::string shape_str("[");
96 for (auto& dim : shape)
97 {
98 shape_str += (std::to_string(dim) + ", ");
99 }
100 shape_str.append("]");
101 return shape_str;
102 }
103
Eric Kunzee5e26762020-10-13 16:11:07 -0700104 const uint32_t getElementCount() const
105 {
106 uint32_t elements = 1;
107 for (size_t i = 0; i < shape.size(); i++)
108 elements *= shape[i];
109
110 return elements;
111 }
112
113 // Comparison of rank and type with other tensors
114 const int matchRank(const Tensor& ref) const
115 {
116 return (ref.shape.size() == shape.size()) ? 0 : 1;
117 }
118
119 const int matchType(const Tensor& ref) const
120 {
121 return (ref.tensorDtype == tensorDtype) ? 0 : 1;
122 }
123
124 const int matchRankType(const Tensor& ref) const
125 {
126 return (matchType(ref) || matchRank(ref));
127 }
128
129 const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
130 {
131 if (matchRankType(ref))
132 return 1;
133
134 for (size_t i = 0; i < shape.size(); i++)
135 {
136 if (shape[i] != ref.shape[i])
137 {
138 if (!broadcastOk ||
139 // For broadcasts, at least one operand must have size 1
140 // if they don't both match
141 (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
142 {
143 return 1;
144 }
145 }
146 }
147
148 return 0;
149 }
150
151 // Sometimes we might want to match several semi-compatible types,
152 // so just check rank and size here
153 const int matchRankSize(const Tensor& ref) const
154 {
155 if (matchRank(ref))
156 return 1;
157
158 for (size_t i = 0; i < shape.size(); i++)
159 {
160 if (shape[i] != ref.shape[i])
161 return 1;
162 }
163
164 return 0;
165 }
166
167 // Unary check to make sure rank matches
168 const int checkRequiredRank(const int exactRank) const
169 {
170 return (shape.size() == (size_t)exactRank) ? 0 : 1;
171 }
172
173 const int checkRequiredRank(const int minRank, const int maxRank) const
174 {
175 return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
176 }
177
178 const int getRank() const
179 {
180 return shape.size();
181 }
182
183 const DType getDtype() const
184 {
185 return tensorDtype;
186 }
187
188 virtual int dumpTensor(FILE* out) const = 0;
189 virtual int dumpTensorParams(FILE* out) const;
190 virtual int dumpTensorParams(std::ostream& out) const;
191
192 virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
193 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
194 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
195 virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
196 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
197 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
198 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
199 virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
200
201 virtual int readFromNpyFile(const char* filename);
202 virtual int writeToNpyFile(const char* filename) const;
203 virtual int copyValueFrom(Tensor* tensor) = 0;
204
205 const char* bool_to_str(bool in) const
206 {
207 static const char* true_str = "true";
208 static const char* false_str = "false";
209 return in ? true_str : false_str;
210 }
211
212 virtual int allocate() = 0;
213 virtual int deallocate() = 0;
214 virtual bool is_allocated() = 0;
215
216protected:
217 std::string tensorName;
218 DType tensorDtype;
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 int isValid;
220 std::vector<int> shape;
221 int isSubgraphInput;
222 int isSubgraphOutput;
223 bool isAllocated;
224
225 GraphNode* producer;
226 std::vector<GraphNode*> consumers;
227
228 // Note: the Eigen::Tensor is not declared in Tensor
229 // Instead, the TensorTemplate class keeps the templated tensor
230 // declaration so that the graph manipulation tools are isolated
231 // from the templated tensor type.
232 //
233 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
234 // so that they can operate on the right types.
235};
236
237template <class T>
238class TensorTemplate : public Tensor
239{
240public:
Kevin Cheng989cb052021-04-28 16:29:44 -0700241 TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800242 : Tensor(tensorName_, tensorDtype_, shape_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700243 {
244 tensor = nullptr;
245 }
246
247 virtual ~TensorTemplate()
248 {
249 deallocate();
250 }
251
252 virtual int allocate()
253 {
254 tensor = new T();
255 if (tensor)
256 return 0;
257 else
258 return 1;
259 }
260
261 virtual int deallocate()
262 {
263 if (tensor)
264 {
265 delete tensor;
266 }
267 tensor = nullptr;
268 return 0;
269 }
270
271 virtual bool is_allocated()
272 {
273 if (tensor)
274 {
275 return true;
276 }
277 return false;
278 }
279
280 T& getTensor()
281 {
282 return *tensor;
283 }
284
285 virtual int dumpTensor(FILE* out) const;
286
287 virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
288 virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
289 virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
290 virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
291 virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
292 virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
293 virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
294 virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
295
296 virtual int copyValueFrom(Tensor* tensor);
297
298protected:
299 T* tensor;
300};
301
302// allocate() template specializations to allocate the different tensor sizes
303// Let the compiler know here before the factory uses them, but define them in the .cc file.
304template <>
305int Tensor0<float>::allocate();
306template <>
307int Tensor1<float>::allocate();
308template <>
309int Tensor2<float>::allocate();
310template <>
311int Tensor3<float>::allocate();
312template <>
313int Tensor4<float>::allocate();
314template <>
315int Tensor5<float>::allocate();
316template <>
317int Tensor6<float>::allocate();
318
319template <>
320int Tensor0<int32_t>::allocate();
321template <>
322int Tensor1<int32_t>::allocate();
323template <>
324int Tensor2<int32_t>::allocate();
325template <>
326int Tensor3<int32_t>::allocate();
327template <>
328int Tensor4<int32_t>::allocate();
329template <>
330int Tensor5<int32_t>::allocate();
331template <>
332int Tensor6<int32_t>::allocate();
333
334template <>
335int Tensor0<int64_t>::allocate();
336template <>
337int Tensor1<int64_t>::allocate();
338template <>
339int Tensor2<int64_t>::allocate();
340template <>
341int Tensor3<int64_t>::allocate();
342template <>
343int Tensor4<int64_t>::allocate();
344template <>
345int Tensor5<int64_t>::allocate();
346template <>
347int Tensor6<int64_t>::allocate();
348
349template <>
350int Tensor0<bool>::allocate();
351template <>
352int Tensor1<bool>::allocate();
353template <>
354int Tensor2<bool>::allocate();
355template <>
356int Tensor3<bool>::allocate();
357template <>
358int Tensor4<bool>::allocate();
359template <>
360int Tensor5<bool>::allocate();
361template <>
362int Tensor6<bool>::allocate();
363
364template <>
365int Tensor0<float>::copyValueFrom(Tensor* src);
366template <>
367int Tensor1<float>::copyValueFrom(Tensor* src);
368template <>
369int Tensor2<float>::copyValueFrom(Tensor* src);
370template <>
371int Tensor3<float>::copyValueFrom(Tensor* src);
372template <>
373int Tensor4<float>::copyValueFrom(Tensor* src);
374template <>
375int Tensor5<float>::copyValueFrom(Tensor* src);
376template <>
377int Tensor6<float>::copyValueFrom(Tensor* src);
378
379template <>
380int Tensor0<int32_t>::copyValueFrom(Tensor* src);
381template <>
382int Tensor1<int32_t>::copyValueFrom(Tensor* src);
383template <>
384int Tensor2<int32_t>::copyValueFrom(Tensor* src);
385template <>
386int Tensor3<int32_t>::copyValueFrom(Tensor* src);
387template <>
388int Tensor4<int32_t>::copyValueFrom(Tensor* src);
389template <>
390int Tensor5<int32_t>::copyValueFrom(Tensor* src);
391template <>
392int Tensor6<int32_t>::copyValueFrom(Tensor* src);
393
394template <>
395int Tensor0<int64_t>::copyValueFrom(Tensor* src);
396template <>
397int Tensor1<int64_t>::copyValueFrom(Tensor* src);
398template <>
399int Tensor2<int64_t>::copyValueFrom(Tensor* src);
400template <>
401int Tensor3<int64_t>::copyValueFrom(Tensor* src);
402template <>
403int Tensor4<int64_t>::copyValueFrom(Tensor* src);
404template <>
405int Tensor5<int64_t>::copyValueFrom(Tensor* src);
406template <>
407int Tensor6<int64_t>::copyValueFrom(Tensor* src);
408
409template <>
410int Tensor0<bool>::copyValueFrom(Tensor* src);
411template <>
412int Tensor1<bool>::copyValueFrom(Tensor* src);
413template <>
414int Tensor2<bool>::copyValueFrom(Tensor* src);
415template <>
416int Tensor3<bool>::copyValueFrom(Tensor* src);
417template <>
418int Tensor4<bool>::copyValueFrom(Tensor* src);
419template <>
420int Tensor5<bool>::copyValueFrom(Tensor* src);
421template <>
422int Tensor6<bool>::copyValueFrom(Tensor* src);
423
424template <>
425int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
426template <>
427int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
428template <>
429int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
430template <>
431int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
432template <>
433int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
434template <>
435int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
436template <>
437int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
438
439template <>
440int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
441template <>
442int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
443template <>
444int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
445template <>
446int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
447template <>
448int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
449template <>
450int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
451template <>
452int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
453
454template <>
455int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
456template <>
457int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
458template <>
459int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
460template <>
461int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
462template <>
463int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
464template <>
465int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
466template <>
467int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
468
469template <>
470int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
471template <>
472int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
473template <>
474int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
475template <>
476int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
477template <>
478int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
479template <>
480int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
481template <>
482int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
483
484template <>
485int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
486template <>
487int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
488template <>
489int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
490template <>
491int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
492template <>
493int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
494template <>
495int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
496template <>
497int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
498
499template <>
500int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
501template <>
502int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
503template <>
504int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
505template <>
506int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
507template <>
508int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
509template <>
510int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
511template <>
512int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
513
514template <>
515int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
516template <>
517int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
518template <>
519int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
520template <>
521int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
522template <>
523int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
524template <>
525int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
526template <>
527int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
528
529template <>
530int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
531template <>
532int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
533template <>
534int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
535template <>
536int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
537template <>
538int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
539template <>
540int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
541template <>
542int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
543
544// assume we only dump float type tensor now
545template <>
546int Tensor0<float>::dumpTensor(FILE* out) const;
547template <>
548int Tensor1<float>::dumpTensor(FILE* out) const;
549template <>
550int Tensor2<float>::dumpTensor(FILE* out) const;
551template <>
552int Tensor3<float>::dumpTensor(FILE* out) const;
553template <>
554int Tensor4<float>::dumpTensor(FILE* out) const;
555template <>
556int Tensor5<float>::dumpTensor(FILE* out) const;
557template <>
558int Tensor6<float>::dumpTensor(FILE* out) const;
559template <>
560int Tensor0<int32_t>::dumpTensor(FILE* out) const;
561template <>
562int Tensor1<int32_t>::dumpTensor(FILE* out) const;
563template <>
564int Tensor2<int32_t>::dumpTensor(FILE* out) const;
565template <>
566int Tensor3<int32_t>::dumpTensor(FILE* out) const;
567template <>
568int Tensor4<int32_t>::dumpTensor(FILE* out) const;
569template <>
570int Tensor5<int32_t>::dumpTensor(FILE* out) const;
571template <>
572int Tensor6<int32_t>::dumpTensor(FILE* out) const;
573template <>
574int Tensor0<int64_t>::dumpTensor(FILE* out) const;
575template <>
576int Tensor1<int64_t>::dumpTensor(FILE* out) const;
577template <>
578int Tensor2<int64_t>::dumpTensor(FILE* out) const;
579template <>
580int Tensor3<int64_t>::dumpTensor(FILE* out) const;
581template <>
582int Tensor4<int64_t>::dumpTensor(FILE* out) const;
583template <>
584int Tensor5<int64_t>::dumpTensor(FILE* out) const;
585template <>
586int Tensor6<int64_t>::dumpTensor(FILE* out) const;
587template <>
588int Tensor0<bool>::dumpTensor(FILE* out) const;
589template <>
590int Tensor1<bool>::dumpTensor(FILE* out) const;
591template <>
592int Tensor2<bool>::dumpTensor(FILE* out) const;
593template <>
594int Tensor3<bool>::dumpTensor(FILE* out) const;
595template <>
596int Tensor4<bool>::dumpTensor(FILE* out) const;
597template <>
598int Tensor5<bool>::dumpTensor(FILE* out) const;
599template <>
600int Tensor6<bool>::dumpTensor(FILE* out) const;
601
602class TensorFactory
603{
604public:
Kevin Cheng989cb052021-04-28 16:29:44 -0700605 static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
Eric Kunzee5e26762020-10-13 16:11:07 -0700606 {
607 switch (tensorDtype_)
608 {
609 case DType_FLOAT:
610 switch (rank)
611 {
612 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800613 return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700614 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800615 return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700616 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800617 return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700618 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800619 return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700620 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800621 return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700622 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800623 return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700624 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800625 return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700626 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700627 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700628 case DType_INT32:
Eric Kunzee5e26762020-10-13 16:11:07 -0700629 case DType_UINT8:
630 case DType_INT4:
631 case DType_INT8:
632 case DType_INT16:
633 switch (rank)
634 {
635 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800636 return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700637 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800638 return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800640 return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700641 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800642 return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700643 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800644 return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700645 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800646 return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700647 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800648 return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700649 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700650 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700651 case DType_INT48:
652 switch (rank)
653 {
654 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800655 return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700656 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800657 return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700658 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700660 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800661 return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700662 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800663 return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700664 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800665 return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700666 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800667 return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700668 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700669 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700670 case DType_BOOL:
671 switch (rank)
672 {
673 case 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800674 return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700675 case 1:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800676 return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700677 case 2:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800678 return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700679 case 3:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800680 return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700681 case 4:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800682 return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700683 case 5:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800684 return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700685 case 6:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800686 return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700687 }
Kevin Cheng989cb052021-04-28 16:29:44 -0700688 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700689 default:
Kevin Cheng989cb052021-04-28 16:29:44 -0700690 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700691 }
692
Kevin Cheng989cb052021-04-28 16:29:44 -0700693 std::string shape_str("[");
694 for (auto& dim : shape_)
695 {
696 shape_str += (std::to_string(dim) + ", ");
697 }
698 shape_str.append("]");
699
700 FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d, shape=%s", tensorName_.c_str(),
701 EnumNamesDType()[tensorDtype_], rank, shape_str.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700702 }
703
704 static Tensor* newTensor(DType type, const std::vector<int> shape);
705};
706}; // namespace TosaReference
707
708#endif