// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TOSA_REFERENCE_TENSOR_H
#define TOSA_REFERENCE_TENSOR_H

#include "array_proxy.h"
#include "dtype.h"
#include "model_common.h"
#include "ops/template_types.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <list>
#include <vector>

using namespace tosa;

namespace TosaReference
{
class GraphNode;

class Tensor
{
public:
    Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);

    virtual ~Tensor();

    int setIsSubgraphInput();
    int setIsSubgraphOutput();
    int setIsParentGraphOutput();

    bool getIsParentGraphOutput() const
    {
        return isParentGraphOutput;
    }
    int setIsVariable();

    bool getIsSubgraphInput() const
    {
        return isSubgraphInput;
    }

    bool getIsSubgraphOutput() const
    {
        return isSubgraphOutput;
    }

    bool getIsVariable() const
    {
        return isVariable;
    }

    int setProducer(GraphNode* node);
    int addConsumer(GraphNode* node);

    int setIsValid()
    {
        isValid = 1;
        return 0;
    }

    int clearIsValid()
    {
        isValid = 0;
        return 0;
    }

    int getIsValid() const
    {
        return isValid;
    }

    GraphNode* getProducer()
    {
        return producer;
    }

    std::vector<GraphNode*>& getConsumers()
    {
        return consumers;
    }

    const std::string& getName() const
    {
        return tensorName;
    }

    const std::vector<int>& getShape() const
    {
        return shape;
    }

    void setDimSize(size_t dim, uint32_t new_size)
    {
        this->shape[dim] = new_size;
        return;
    }

    std::string getShapeAsString() const
    {
        std::string shape_str("[");
        for (auto& dim : shape)
        {
            shape_str += (std::to_string(dim) + ", ");
        }
        shape_str.append("]");
        return shape_str;
    }

    const uint32_t getElementCount() const
    {
        uint32_t elements = 1;
        for (size_t i = 0; i < shape.size(); i++)
            elements *= shape[i];

        return elements;
    }

    // Comparison of rank and type with other tensors
    const int matchRank(const Tensor& ref) const
    {
        return (ref.shape.size() == shape.size()) ? 0 : 1;
    }

    const int matchType(const Tensor& ref) const
    {
        return (ref.tensorDtype == tensorDtype) ? 0 : 1;
    }

    const int matchRankType(const Tensor& ref) const
    {
        return (matchType(ref) || matchRank(ref));
    }

    const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRankType(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, the order of *this and ref matters:
                    // *this should be the source tensor and ref the target tensor
                    // (in most cases, ref is expected to be the output tensor).
                    // this->shape[i] must be 1 where the sizes don't match.
                    // See the illustrative example after this method.
                    (broadcastOk && (shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }
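
    // Illustrative example (hypothetical shapes A and B, assuming both tensors share a dtype):
    // for a source tensor A of shape [1, 3] and a target/output tensor B of shape [5, 3],
    //     A.matchRankTypeShape(B, /* broadcastOk = */ true) returns 0 (A broadcasts to B),
    //     B.matchRankTypeShape(A, /* broadcastOk = */ true) returns 1 (5 cannot broadcast to 1),
    // because only dimensions of size 1 in *this may be broadcast to the target size.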

    const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
            {
                if (!broadcastOk ||
                    // For broadcasts, the order of *this and ref matters:
                    // *this should be the source tensor and ref the target tensor
                    // (in most cases, ref is expected to be the output tensor).
                    // this->shape[i] must be 1 where the sizes don't match.
                    (broadcastOk && (shape[i] != 1)))
                {
                    return 1;
                }
            }
        }

        return 0;
    }

    // Sometimes we might want to match several semi-compatible types,
    // so just check rank and size here
    const int matchRankSize(const Tensor& ref) const
    {
        if (matchRank(ref))
            return 1;

        for (size_t i = 0; i < shape.size(); i++)
        {
            if (shape[i] != ref.shape[i])
                return 1;
        }

        return 0;
    }

    // Unary checks to make sure the rank is at least minRank, or within [minRank, maxRank]
    const int checkRequiredRank(const int minRank) const
    {
        return (shape.size() >= (size_t)minRank) ? 0 : 1;
    }

    const int checkRequiredRank(const int minRank, const int maxRank) const
    {
        return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
    }

    const int getRank() const
    {
        return shape.size();
    }

    const TOSA_REF_TYPE getDtype() const
    {
        return tensorDtype;
    }

    const DType getSerializationDtype() const
    {
        return serializationDtype;
    }

    virtual int dumpTensor(FILE* out) const = 0;
    virtual int dumpTensorParams(FILE* out) const;
    virtual int dumpTensorParams(std::ostream& out) const;

    virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
    virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
    virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
    virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;

    virtual int readFromNpyFile(const char* filename);
    virtual int writeToNpyFile(const char* filename) const;
    virtual int copyValueFrom(Tensor* tensor) = 0;

    virtual int readfromVector(const ArrayProxy<double> vals);
    virtual int readfromVector(const ArrayProxy<float> vals);
    virtual int readfromVector(const ArrayProxy<half_float::half> vals);
    virtual int readfromVector(const ArrayProxy<int32_t> vals);
    virtual int readfromVector(const ArrayProxy<int64_t> vals);
    virtual int readfromVector(const ArrayProxy<unsigned char> vals);

    virtual int writeToVector(ArrayProxy<double> vals);
    virtual int writeToVector(ArrayProxy<float> vals);
    virtual int writeToVector(ArrayProxy<half_float::half> vals);
    virtual int writeToVector(ArrayProxy<int32_t> vals);
    virtual int writeToVector(ArrayProxy<int64_t> vals);
    virtual int writeToVector(ArrayProxy<unsigned char> vals);

    const char* bool_to_str(bool in) const
    {
        static const char* true_str  = "true";
        static const char* false_str = "false";
        return in ? true_str : false_str;
    }

    virtual int allocate() = 0;
    virtual int deallocate() = 0;
    virtual bool is_allocated() const = 0;

protected:
    const std::string tensorName;
    const DType serializationDtype;
    std::vector<int> shape;
    const TOSA_REF_TYPE tensorDtype;
    bool isValid;
    bool isSubgraphInput;
    bool isSubgraphOutput;
    bool isVariable;
    bool isAllocated;

    bool isParentGraphOutput;

    GraphNode* producer;
    std::vector<GraphNode*> consumers;
297
298 // Note: the Eigen::Tensor is not declared in Tensor
299 // Instead, the TensorTemplate class keeps the templated tensor
300 // declaration so that the graph manipulation tools are isolated
301 // from the templated tensor type.
302 //
303 // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
304 // so that they can operate on the right types.
};

template <class T>
class TensorTemplate : public Tensor
{
public:
    TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
        : Tensor(tensorName_, dtype_, shape_)
    {
        tensor = nullptr;
    }

    virtual ~TensorTemplate()
    {
        deallocate();
    }

    virtual int allocate()
    {
        tensor = new T();
        if (tensor)
            return 0;
        else
            return 1;
    }

    virtual int deallocate()
    {
        if (tensor)
        {
            DEBUG_INFO(GT, "Deallocating tensor %s", tensorName.c_str());
            delete tensor;
        }
        tensor = nullptr;
        return 0;
    }

    virtual bool is_allocated() const
    {
        if (tensor)
        {
            return true;
        }
        return false;
    }

    T& getTensor()
    {
        return *tensor;
    }

    virtual int dumpTensor(FILE* out) const;

    virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
    virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
    virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
    virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
    virtual int setTensorValueBool(const size_t bufLen, const bool* vals);

    virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
    virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
    virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
    virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
    virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;

    virtual int copyValueFrom(Tensor* tensor);

protected:
    T* tensor;
};
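
// Illustrative sketch only (not a declaration in this header): code that knows the
// concrete element type and rank, as the operators do, can reach the underlying Eigen
// object through the typed subclass. The name someTensor below is hypothetical.
//
//     using TIn = Eigen::Tensor<float, 2>;
//     auto* in = dynamic_cast<TensorTemplate<TIn>*>(someTensor);
//     if (in && in->is_allocated())
//     {
//         TIn& data = in->getTensor();    // direct Eigen access for compute
//     }
//
// The Tensor0..Tensor6 names used below are rank-specific forms of TensorTemplate
// that come from the included ops/template_types.h.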

// allocate() template specializations to allocate the different tensor sizes
// Let the compiler know here before the factory uses them, but define them in the .cc file.
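//
// A minimal sketch of what one such out-of-line definition might look like (illustrative
// only; the real definitions live in the implementation file and size the Eigen tensor
// from the stored shape):
//
//     template <>
//     int Tensor2<float>::allocate()
//     {
//         tensor = new Eigen::Tensor<float, 2>(shape[0], shape[1]);
//         return tensor ? 0 : 1;
//     }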
template <>
int Tensor0<float>::allocate();
template <>
int Tensor1<float>::allocate();
template <>
int Tensor2<float>::allocate();
template <>
int Tensor3<float>::allocate();
template <>
int Tensor4<float>::allocate();
template <>
int Tensor5<float>::allocate();
template <>
int Tensor6<float>::allocate();

template <>
int Tensor0<double>::allocate();
template <>
int Tensor1<double>::allocate();
template <>
int Tensor2<double>::allocate();
template <>
int Tensor3<double>::allocate();
template <>
int Tensor4<double>::allocate();
template <>
int Tensor5<double>::allocate();
template <>
int Tensor6<double>::allocate();

template <>
int Tensor0<int32_t>::allocate();
template <>
int Tensor1<int32_t>::allocate();
template <>
int Tensor2<int32_t>::allocate();
template <>
int Tensor3<int32_t>::allocate();
template <>
int Tensor4<int32_t>::allocate();
template <>
int Tensor5<int32_t>::allocate();
template <>
int Tensor6<int32_t>::allocate();

template <>
int Tensor0<int64_t>::allocate();
template <>
int Tensor1<int64_t>::allocate();
template <>
int Tensor2<int64_t>::allocate();
template <>
int Tensor3<int64_t>::allocate();
template <>
int Tensor4<int64_t>::allocate();
template <>
int Tensor5<int64_t>::allocate();
template <>
int Tensor6<int64_t>::allocate();

template <>
int Tensor0<bool>::allocate();
template <>
int Tensor1<bool>::allocate();
template <>
int Tensor2<bool>::allocate();
template <>
int Tensor3<bool>::allocate();
template <>
int Tensor4<bool>::allocate();
template <>
int Tensor5<bool>::allocate();
template <>
int Tensor6<bool>::allocate();

template <>
int Tensor0<float>::copyValueFrom(Tensor* src);
template <>
int Tensor1<float>::copyValueFrom(Tensor* src);
template <>
int Tensor2<float>::copyValueFrom(Tensor* src);
template <>
int Tensor3<float>::copyValueFrom(Tensor* src);
template <>
int Tensor4<float>::copyValueFrom(Tensor* src);
template <>
int Tensor5<float>::copyValueFrom(Tensor* src);
template <>
int Tensor6<float>::copyValueFrom(Tensor* src);

template <>
int Tensor0<double>::copyValueFrom(Tensor* src);
template <>
int Tensor1<double>::copyValueFrom(Tensor* src);
template <>
int Tensor2<double>::copyValueFrom(Tensor* src);
template <>
int Tensor3<double>::copyValueFrom(Tensor* src);
template <>
int Tensor4<double>::copyValueFrom(Tensor* src);
template <>
int Tensor5<double>::copyValueFrom(Tensor* src);
template <>
int Tensor6<double>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int32_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int64_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor1<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor2<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor3<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor4<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor5<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor6<bool>::copyValueFrom(Tensor* src);

template <>
int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);

template <>
int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;

template <>
int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);

template <>
int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;

template <>
int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);

template <>
int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;

template <>
int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
template <>
int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);

template <>
int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
template <>
int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;

template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);

template <>
int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;

template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
int Tensor1<float>::dumpTensor(FILE* out) const;
template <>
int Tensor2<float>::dumpTensor(FILE* out) const;
template <>
int Tensor3<float>::dumpTensor(FILE* out) const;
template <>
int Tensor4<float>::dumpTensor(FILE* out) const;
template <>
int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
int Tensor0<double>::dumpTensor(FILE* out) const;
template <>
int Tensor1<double>::dumpTensor(FILE* out) const;
template <>
int Tensor2<double>::dumpTensor(FILE* out) const;
template <>
int Tensor3<double>::dumpTensor(FILE* out) const;
template <>
int Tensor4<double>::dumpTensor(FILE* out) const;
template <>
int Tensor5<double>::dumpTensor(FILE* out) const;
template <>
int Tensor6<double>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor1<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor2<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor3<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor4<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor5<bool>::dumpTensor(FILE* out) const;
template <>
int Tensor6<bool>::dumpTensor(FILE* out) const;

class TensorFactory
{
public:
    static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
    {
        TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
        switch (tensorDtype_)
        {
            case TOSA_REF_TYPE_FP32:
            case TOSA_REF_TYPE_FP16:
            case TOSA_REF_TYPE_BF16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<float>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<float>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<float>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<float>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<float>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<float>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<float>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_INT32:
            case TOSA_REF_TYPE_UINT8:
            case TOSA_REF_TYPE_INT4:
            case TOSA_REF_TYPE_INT8:
            case TOSA_REF_TYPE_INT16:
            case TOSA_REF_TYPE_UINT16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_INT48:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_SHAPE:
                assert(rank == 0);
                return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
            case TOSA_REF_TYPE_BOOL:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<bool>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<bool>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<bool>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<bool>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<bool>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<bool>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<bool>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_FP64:
                switch (rank)
                {
                    case 0:
                        return new Tensor0<double>(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1<double>(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2<double>(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3<double>(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4<double>(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5<double>(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6<double>(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_UNKNOWN:
                assert(0);    // tensorDtype_ is uninitialized
                break;
        }
        return nullptr;
    }
};
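
// Illustrative usage only (hypothetical names and values): the factory picks the C++
// storage type from the serialization dtype and the Eigen rank from the requested rank,
// and the caller owns the returned pointer.
//
//     std::vector<int> shape = { 2, 3 };
//     Tensor* t = TensorFactory::newTensor("example_fp32", DType_FP32, shape, shape.size());
//     if (t && t->allocate() == 0)
//     {
//         // ... read or write values, then deallocate()/delete when done
//     }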
};    // namespace TosaReference

#endif