blob: 55e04e876d2b529981a7e72d9c620e565728e6b8 [file] [log] [blame]
Georgios Pinitas8a5146f2021-01-12 15:51:07 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_ACL_HPP_
25#define ARM_COMPUTE_ACL_HPP_
26
27#include "arm_compute/Acl.h"
28
29#include <cstdlib>
30#include <memory>
31#include <string>
Georgios Pinitas3f26ef42021-02-23 10:01:33 +000032#include <vector>
Georgios Pinitas8a5146f2021-01-12 15:51:07 +000033
34#if defined(ARM_COMPUTE_EXCEPTIONS_ENABLED)
35#include <exception>
36#endif /* defined(ARM_COMPUTE_EXCEPTIONS_ENABLED) */
37
// Helper Macros
/** Explicitly discards the value of @p x so that unused-variable / unused-parameter warnings are silenced */
#define ARM_COMPUTE_IGNORE_UNUSED(x) static_cast<void>(x)
40
41namespace acl
42{
// Forward declarations of wrapper classes that are defined later in this header
class Context;
class Queue;
class Tensor;
class TensorPack;
Georgios Pinitas8a5146f2021-01-12 15:51:07 +000048
49/**< Status code enum */
50enum class StatusCode
51{
52 Success = AclSuccess,
53 RuntimeError = AclRuntimeError,
54 OutOfMemory = AclOutOfMemory,
55 Unimplemented = AclUnimplemented,
56 UnsupportedTarget = AclUnsupportedTarget,
57 InvalidArgument = AclInvalidArgument,
58 InvalidTarget = AclInvalidTarget,
59 UnsupportedConfig = AclUnsupportedConfig,
60 InvalidObjectState = AclInvalidObjectState,
61};
62
63/**< Utility namespace containing helpers functions */
64namespace detail
65{
66/** Construct to handle destruction of objects
67 *
68 * @tparam T Object base type
69 */
70template <typename T>
71struct ObjectDeleter
72{
73};
74
75#define OBJECT_DELETER(obj, func) \
76 template <> \
77 struct ObjectDeleter<obj> \
78 \
79 { \
80 static inline AclStatus Destroy(obj v) \
81 { \
82 return func(v); \
83 } \
84 };
85
86OBJECT_DELETER(AclContext, AclDestroyContext)
Georgios Pinitasc3c352e2021-03-18 10:59:40 +000087OBJECT_DELETER(AclQueue, AclDestroyQueue)
Georgios Pinitas3f26ef42021-02-23 10:01:33 +000088OBJECT_DELETER(AclTensor, AclDestroyTensor)
89OBJECT_DELETER(AclTensorPack, AclDestroyTensorPack)
Georgios Pinitas41648142021-08-03 08:24:00 +010090OBJECT_DELETER(AclOperator, AclDestroyOperator)
Georgios Pinitas8a5146f2021-01-12 15:51:07 +000091
92#undef OBJECT_DELETER
93
/** Convert a strongly typed enum value into its plain C enum counterpart
 *
 * The value is first reduced to the underlying integral type of @p SE and then cast to @p E.
 *
 * @tparam E  Plain C enum to convert to
 * @tparam SE Strongly typed enum to convert from
 *
 * @param[in] v Value to convert
 *
 * @return The value of @p v expressed as an @p E
 */
template <typename E, typename SE>
constexpr E as_cenum(SE v) noexcept
{
    using raw_type = typename std::underlying_type<SE>::type;
    return static_cast<E>(static_cast<raw_type>(v));
}
108
/** Convert a plain C enum value into a strongly typed enum
 *
 * @tparam SE Strongly typed enum to convert to
 * @tparam E  Type of the value to convert from (typically a plain C enum)
 *
 * @param[in] val Value to convert
 *
 * @return The value of @p val expressed as an @p SE
 */
template <typename SE, typename E>
constexpr SE as_enum(E val) noexcept
{
    return static_cast<SE>(val);
}
123
124/** Object base class for library objects
125 *
126 * Class is defining basic common interface for all the library objects
127 *
128 * @tparam T Object type to be templated on
129 */
130template <typename T>
131class ObjectBase
132{
133public:
134 /** Destructor */
135 ~ObjectBase() = default;
136 /** Copy constructor */
137 ObjectBase(const ObjectBase<T> &) = default;
138 /** Move Constructor */
139 ObjectBase(ObjectBase<T> &&) = default;
140 /** Copy assignment operator */
141 ObjectBase<T> &operator=(const ObjectBase<T> &) = default;
142 /** Move assignment operator */
143 ObjectBase<T> &operator=(ObjectBase<T> &&) = default;
144 /** Reset object value
145 *
146 * @param [in] val Value to set
147 */
148 void reset(T *val)
149 {
150 _object.reset(val, detail::ObjectDeleter<T *>::Destroy);
151 }
152 /** Access uderlying object
153 *
154 * @return Underlying object
155 */
156 const T *get() const
157 {
158 return _object.get();
159 }
160 /** Access uderlying object
161 *
162 * @return Underlying object
163 */
164 T *get()
165 {
166 return _object.get();
167 }
168
169protected:
170 /** Constructor */
171 ObjectBase() = default;
172
173protected:
174 std::shared_ptr<T> _object{ nullptr }; /**< Library object */
175};
176
177/** Equality operator for library object
178 *
179 * @tparam T Parameter to template on
180 *
181 * @param[in] lhs Left hand-side argument
182 * @param[in] rhs Right hand-side argument
183 *
184 * @return True if objects are equal, else false
185 */
186template <typename T>
187bool operator==(const ObjectBase<T> &lhs, const ObjectBase<T> &rhs)
188{
189 return lhs.get() == rhs.get();
190}
191
192/** Inequality operator for library object
193 *
194 * @tparam T Parameter to template on
195 *
196 * @param[in] lhs Left hand-side argument
197 * @param[in] rhs Right hand-side argument
198 *
199 * @return True if objects are equal, else false
200 */
201template <typename T>
202bool operator!=(const ObjectBase<T> &lhs, const ObjectBase<T> &rhs)
203{
204 return !(lhs == rhs);
205}
206} // namespace detail
207
208#if defined(ARM_COMPUTE_EXCEPTIONS_ENABLED)
209/** Status class
210 *
211 * Class is an extension of std::exception and contains the underlying
212 * status construct and an error explanatory message to be reported.
213 *
214 * @note Class is visible only when exceptions are enabled during compilation
215 */
216class Status : public std::exception
217{
218public:
219 /** Constructor
220 *
221 * @param[in] status Status returned
222 * @param[in] msg Error message to be bound with the exception
223 */
224 Status(StatusCode status, const std::string &msg)
225 : _status(status), _msg(msg)
226 {
227 }
228 /** Returns an explanatory exception message
229 *
230 * @return Status message
231 */
232 const char *what() const noexcept override
233 {
234 return _msg.c_str();
235 }
236 /** Underlying status accessor
237 *
238 * @return Status code
239 */
240 StatusCode status() const
241 {
242 return _status;
243 }
244 /** Explicit status converter
245 *
246 * @return Status code
247 */
248 explicit operator StatusCode() const
249 {
250 return _status;
251 }
252
253private:
254 StatusCode _status; /**< Status code */
255 std::string _msg; /**< Status message */
256};
257
258/** Reports an error status and throws an exception object in case of failure
259 *
260 * @note This implementation is used when exceptions are enabled during compilation
261 *
262 * @param[in] status Status to report
263 * @param[in] msg Explanatory error messaged
264 *
265 * @return Status code
266 */
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000267static inline void report_status(StatusCode status, const std::string &msg)
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000268{
269 if(status != StatusCode::Success)
270 {
271 throw Status(status, msg);
272 }
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000273}
274#else /* defined(ARM_COMPUTE_EXCEPTIONS_ENABLED) */
275/** Reports a status code
276 *
277 * @note This implementation is used when exceptions are disabled during compilation
278 * @note Message is surpressed and not reported in this case
279 *
280 * @param[in] status Status to report
281 * @param[in] msg Explanatory error messaged
282 *
283 * @return Status code
284 */
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000285static inline void report_status(StatusCode status, const std::string &msg)
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000286{
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000287 ARM_COMPUTE_IGNORE_UNUSED(status);
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000288 ARM_COMPUTE_IGNORE_UNUSED(msg);
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000289}
290#endif /* defined(ARM_COMPUTE_EXCEPTIONS_ENABLED) */
291
292/**< Target enum */
293enum class Target
294{
295 Cpu = AclCpu, /**< Cpu target that leverages SIMD */
296 GpuOcl = AclGpuOcl /**< Gpu target that leverages OpenCL */
297};
298
299/**< Available execution modes */
300enum class ExecutionMode
301{
302 FastRerun = AclPreferFastRerun, /**< Prefer minimum latency in consecutive runs, might introduce higher startup times */
303 FastStart = AclPreferFastStart, /**< Prefer minimizing startup time */
304};
305
306/** Context class
307 *
308 * Context acts as a central aggregate service for further objects created from it.
309 * It provides, internally, common facilities in order to avoid the use of global
310 * statically initialized objects that can lead to important side-effect under
311 * specific execution contexts.
312 *
313 * For example context contains allocators for object creation, for further backing memory allocation,
314 * any serialization interfaces and other modules that affect the construction of objects,
315 * like program caches for OpenCL.
316 */
317class Context : public detail::ObjectBase<AclContext_>
318{
319public:
320 /**< Context options */
321 struct Options
322 {
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000323 static constexpr int32_t num_threads_auto = -1; /**< Allow runtime to specify number of threads */
324
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000325 /** Default Constructor
326 *
327 * @note By default no precision loss is enabled for operators
328 * @note By default the preferred execution mode is to favor multiple consecutive reruns of an operator
329 */
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000330 Options()
331 : Options(ExecutionMode::FastRerun /* mode */,
332 AclCpuCapabilitiesAuto /* caps */,
333 false /* enable_fast_math */,
334 nullptr /* kernel_config */,
335 num_threads_auto /* max_compute_units */,
336 nullptr /* allocator */)
337 {
338 }
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000339 /** Constructor
340 *
341 * @param[in] mode Execution mode to be used
342 * @param[in] caps Capabilities to be used
343 * @param[in] enable_fast_math Allow precision loss in favor of performance
344 * @param[in] kernel_config Kernel configuration file containing construction tuning meta-data
345 * @param[in] max_compute_units Max compute units that are expected to used
346 * @param[in] allocator Allocator to be used for internal memory allocation
347 */
348 Options(ExecutionMode mode,
349 AclTargetCapabilities caps,
350 bool enable_fast_math,
351 const char *kernel_config,
352 int32_t max_compute_units,
353 AclAllocator *allocator)
354 {
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000355 copts.mode = detail::as_cenum<AclExecutionMode>(mode);
356 copts.capabilities = caps;
357 copts.enable_fast_math = enable_fast_math;
358 copts.kernel_config_file = kernel_config;
359 copts.max_compute_units = max_compute_units;
360 copts.allocator = allocator;
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000361 }
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000362
363 AclContextOptions copts{};
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000364 };
365
366public:
367 /** Constructor
368 *
369 * @note Serves as a simpler delegate constructor
370 * @note As context options, default conservative options will be used
371 *
372 * @param[in] target Target to create context for
373 * @param[out] status Status information if requested
374 */
375 explicit Context(Target target, StatusCode *status = nullptr)
376 : Context(target, Options(), status)
377 {
378 }
379 /** Constructor
380 *
381 * @param[in] target Target to create context for
382 * @param[in] options Context construction options
383 * @param[out] status Status information if requested
384 */
385 Context(Target target, const Options &options, StatusCode *status = nullptr)
386 {
387 AclContext ctx;
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000388 const auto st = detail::as_enum<StatusCode>(AclCreateContext(&ctx, detail::as_cenum<AclTarget>(target), &options.copts));
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000389 reset(ctx);
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000390 report_status(st, "[Compute Library] Failed to create context");
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000391 if(status)
392 {
393 *status = st;
394 }
395 }
396};
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000397
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000398/**< Available tuning modes */
399enum class TuningMode
400{
401 Rapid = AclRapid,
402 Normal = AclNormal,
403 Exhaustive = AclExhaustive
404};
405
406/** Queue class
407 *
408 * Queue is responsible for the execution related aspects, with main responsibilities those of
409 * scheduling and tuning operators.
410 *
411 * Multiple queues can be created from the same context, and the same operator can be scheduled on each concurrently.
412 *
413 * @note An operator might depend on the maximum possible compute units that are provided in the context,
414 * thus in cases where the number of the scheduling units of the queue are greater might lead to errors.
415 */
416class Queue : public detail::ObjectBase<AclQueue_>
417{
418public:
419 /**< Queue options */
420 struct Options
421 {
422 /** Default Constructor
423 *
424 * As default options, no tuning will be performed, and the number of scheduling units will
425 * depends on internal device discovery functionality
426 */
427 Options()
428 : opts{ AclTuningModeNone, 0 } {};
429 /** Constructor
430 *
431 * @param[in] mode Tuning mode to be used
432 * @param[in] compute_units Number of scheduling units to be used
433 */
434 Options(TuningMode mode, int32_t compute_units)
435 : opts{ detail::as_cenum<AclTuningMode>(mode), compute_units }
436 {
437 }
438
439 AclQueueOptions opts;
440 };
441
442public:
443 /** Constructor
444 *
445 * @note Serves as a simpler delegate constructor
446 * @note As queue options, default conservative options will be used
447 *
448 * @param[in] ctx Context to create queue for
449 * @param[out] status Status information if requested
450 */
451 explicit Queue(Context &ctx, StatusCode *status = nullptr)
452 : Queue(ctx, Options(), status)
453 {
454 }
455 /** Constructor
456 *
457 * @note As queue options, default conservative options will be used
458 *
459 * @param[in] ctx Context from where the queue will be created from
460 * @param[in] options Queue options to be used
461 * @param[out] status Status information if requested
462 */
463 explicit Queue(Context &ctx, const Options &options = Options(), StatusCode *status = nullptr)
464 {
465 AclQueue queue;
466 const auto st = detail::as_enum<StatusCode>(AclCreateQueue(&queue, ctx.get(), &options.opts));
467 reset(queue);
468 report_status(st, "[Compute Library] Failed to create queue!");
469 if(status)
470 {
471 *status = st;
472 }
473 }
474 /** Block until all the tasks of the queue have been marked as finished
475 *
476 * @return Status code
477 */
478 StatusCode finish()
479 {
480 return detail::as_enum<StatusCode>(AclQueueFinish(_object.get()));
481 }
482};
483
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000484/**< Data type enumeration */
485enum class DataType
486{
487 Unknown = AclDataTypeUnknown,
488 UInt8 = AclUInt8,
489 Int8 = AclInt8,
490 UInt16 = AclUInt16,
491 Int16 = AclInt16,
492 UInt32 = AclUint32,
493 Int32 = AclInt32,
494 Float16 = AclFloat16,
495 BFloat16 = AclBFloat16,
496 Float32 = AclFloat32,
497};
498
499/** Tensor Descriptor class
500 *
501 * Structure that contains all the required meta-data to represent a tensor
502 */
503class TensorDescriptor
504{
505public:
506 /** Constructor
507 *
508 * @param[in] shape Shape of the tensor
509 * @param[in] data_type Data type of the tensor
510 */
511 TensorDescriptor(const std::vector<int32_t> &shape, DataType data_type)
512 : _shape(shape), _data_type(data_type)
513 {
514 _cdesc.ndims = _shape.size();
515 _cdesc.shape = _shape.data();
516 _cdesc.data_type = detail::as_cenum<AclDataType>(_data_type);
517 _cdesc.strides = nullptr;
518 _cdesc.boffset = 0;
519 }
Sang-Hoon Parkc6fcfb42021-03-31 15:18:16 +0100520 /** Constructor
521 *
522 * @param[in] desc C-type descriptor
523 */
524 explicit TensorDescriptor(const AclTensorDescriptor &desc)
525 {
526 _cdesc = desc;
527 _data_type = detail::as_enum<DataType>(desc.data_type);
528 _shape.reserve(desc.ndims);
529 for(int32_t d = 0; d < desc.ndims; ++d)
530 {
531 _shape.emplace_back(desc.shape[d]);
532 }
533 }
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000534 /** Get underlying C tensor descriptor
535 *
536 * @return Underlying structure
537 */
538 const AclTensorDescriptor *get() const
539 {
540 return &_cdesc;
541 }
Sang-Hoon Parkc6fcfb42021-03-31 15:18:16 +0100542 /** Operator to compare two TensorDescriptor
543 *
544 * @param[in] other The instance to compare against
545 *
546 * @return True if two instances have the same shape and data type
547 */
548 bool operator==(const TensorDescriptor &other)
549 {
550 bool is_same = true;
551
552 is_same &= _data_type == other._data_type;
553 is_same &= _shape.size() == other._shape.size();
554
555 if(is_same)
556 {
557 for(uint32_t d = 0; d < _shape.size(); ++d)
558 {
559 is_same &= _shape[d] == other._shape[d];
560 }
561 }
562
563 return is_same;
564 }
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000565
566private:
567 std::vector<int32_t> _shape{};
568 DataType _data_type{};
569 AclTensorDescriptor _cdesc{};
570};
571
572/** Import memory types */
573enum class ImportType
574{
575 Host = AclImportMemoryType::AclHostPtr
576};
577
578/** Tensor class
579 *
580 * Tensor is an mathematical construct that can represent an N-Dimensional space.
581 *
582 * @note Maximum dimensionality support is 6 internally at the moment
583 */
584class Tensor : public detail::ObjectBase<AclTensor_>
585{
586public:
587 /** Constructor
588 *
589 * @note Tensor memory is allocated
590 *
591 * @param[in] ctx Context from where the tensor will be created from
592 * @param[in] desc Tensor descriptor to be used
593 * @param[out] status Status information if requested
594 */
595 Tensor(Context &ctx, const TensorDescriptor &desc, StatusCode *status = nullptr)
596 : Tensor(ctx, desc, true, status)
597 {
598 }
599 /** Constructor
600 *
601 * @param[in] ctx Context from where the tensor will be created from
602 * @param[in] desc Tensor descriptor to be used
603 * @param[in] allocate Flag to indicate if the tensor needs to be allocated
604 * @param[out] status Status information if requested
605 */
606 Tensor(Context &ctx, const TensorDescriptor &desc, bool allocate, StatusCode *status)
607 {
608 AclTensor tensor;
609 const auto st = detail::as_enum<StatusCode>(AclCreateTensor(&tensor, ctx.get(), desc.get(), allocate));
610 reset(tensor);
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000611 report_status(st, "[Compute Library] Failed to create tensor!");
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000612 if(status)
613 {
614 *status = st;
615 }
616 }
617 /** Maps the backing memory of a given tensor that can be used by the host to access any contents
618 *
619 * @return A valid non-zero pointer in case of success else nullptr
620 */
621 void *map()
622 {
623 void *handle = nullptr;
624 const auto st = detail::as_enum<StatusCode>(AclMapTensor(_object.get(), &handle));
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000625 report_status(st, "[Compute Library] Failed to map the tensor and extract the tensor's backing memory!");
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000626 return handle;
627 }
628 /** Unmaps tensor's memory
629 *
630 * @param[in] handle Handle to unmap
631 *
632 * @return Status code
633 */
634 StatusCode unmap(void *handle)
635 {
636 const auto st = detail::as_enum<StatusCode>(AclUnmapTensor(_object.get(), handle));
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000637 report_status(st, "[Compute Library] Failed to unmap the tensor!");
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000638 return st;
639 }
640 /** Import external memory to a given tensor object
641 *
642 * @param[in] handle External memory handle
643 * @param[in] type Type of memory to be imported
644 *
645 * @return Status code
646 */
647 StatusCode import(void *handle, ImportType type)
648 {
649 const auto st = detail::as_enum<StatusCode>(AclTensorImport(_object.get(), handle, detail::as_cenum<AclImportMemoryType>(type)));
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000650 report_status(st, "[Compute Library] Failed to import external memory to tensor!");
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000651 return st;
652 }
Sang-Hoon Parkc6fcfb42021-03-31 15:18:16 +0100653 /** Get the size of the tensor in byte
654 *
655 * @note The size isn't based on allocated memory, but based on information in its descriptor (dimensions, data type, etc.).
656 *
657 * @return The size of the tensor in byte
658 */
659 uint64_t get_size()
660 {
661 uint64_t size{ 0 };
662 const auto st = detail::as_enum<StatusCode>(AclGetTensorSize(_object.get(), &size));
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000663 report_status(st, "[Compute Library] Failed to get the size of the tensor");
Sang-Hoon Parkc6fcfb42021-03-31 15:18:16 +0100664 return size;
665 }
666 /** Get the descriptor of this tensor
667 *
668 * @return The descriptor describing the characteristics of this tensor
669 */
670 TensorDescriptor get_descriptor()
671 {
672 AclTensorDescriptor desc;
673 const auto st = detail::as_enum<StatusCode>(AclGetTensorDescriptor(_object.get(), &desc));
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000674 report_status(st, "[Compute Library] Failed to get the descriptor of the tensor");
Sang-Hoon Parkc6fcfb42021-03-31 15:18:16 +0100675 return TensorDescriptor(desc);
676 }
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000677};
678
679/** Tensor pack class
680 *
681 * Pack is a utility construct that is used to create a collection of tensors that can then
682 * be passed into operator as inputs.
683 */
684class TensorPack : public detail::ObjectBase<AclTensorPack_>
685{
686public:
687 /** Pack pair construct */
688 struct PackPair
689 {
690 /** Constructor
691 *
692 * @param[in] tensor_ Tensor to pack
693 * @param[in] slot_id_ Slot identification of the tensor in respect with the operator
694 */
695 PackPair(Tensor *tensor_, int32_t slot_id_)
696 : tensor(tensor_), slot_id(slot_id_)
697 {
698 }
699
700 Tensor *tensor{ nullptr }; /**< Tensor object */
701 int32_t slot_id{ AclSlotUnknown }; /**< Slot id in respect with the operator */
702 };
703
704public:
705 /** Constructor
706 *
707 * @param[in] ctx Context from where the tensor pack will be created from
708 * @param[out] status Status information if requested
709 */
710 explicit TensorPack(Context &ctx, StatusCode *status = nullptr)
711 {
712 AclTensorPack pack;
713 const auto st = detail::as_enum<StatusCode>(AclCreateTensorPack(&pack, ctx.get()));
714 reset(pack);
Georgios Pinitasc3c352e2021-03-18 10:59:40 +0000715 report_status(st, "[Compute Library] Failure during tensor pack creation");
Georgios Pinitas3f26ef42021-02-23 10:01:33 +0000716 if(status)
717 {
718 *status = st;
719 }
720 }
721 /** Add tensor to tensor pack
722 *
723 * @param[in] slot_id Slot id of the tensor in respect with the operator
724 * @param[in] tensor Tensor to be added in the pack
725 *
726 * @return Status code
727 */
728 StatusCode add(Tensor &tensor, int32_t slot_id)
729 {
730 return detail::as_enum<StatusCode>(AclPackTensor(_object.get(), tensor.get(), slot_id));
731 }
732 /** Add a list of tensors to a tensor pack
733 *
734 * @param[in] packed Pair packs to be added
735 *
736 * @return Status code
737 */
738 StatusCode add(std::initializer_list<PackPair> packed)
739 {
740 const size_t size = packed.size();
741 std::vector<int32_t> slots(size);
742 std::vector<AclTensor> tensors(size);
743 int i = 0;
744 for(auto &p : packed)
745 {
746 slots[i] = p.slot_id;
747 tensors[i] = AclTensor(p.tensor);
748 ++i;
749 }
750 return detail::as_enum<StatusCode>(AclPackTensors(_object.get(), tensors.data(), slots.data(), size));
751 }
752};
Georgios Pinitas06ac6e42021-07-05 08:08:52 +0100753
754/** Operator class
755 *
756 * Operators are the basic algorithmic blocks responsible for performing distinct operations
757 */
758class Operator : public detail::ObjectBase<AclOperator_>
759{
760public:
761 /** Run an operator on a given input list
762 *
763 * @param[in,out] queue Queue to scheduler the operator on
764 * @param pack Tensor list to be used as input
765 *
766 * @return Status Code
767 */
768 StatusCode run(Queue &queue, TensorPack &pack)
769 {
770 return detail::as_cenum<StatusCode>(AclRunOperator(_object.get(), queue.get(), pack.get()));
771 }
772
773protected:
774 /** Constructor */
775 Operator() = default;
776};
Georgios Pinitas41648142021-08-03 08:24:00 +0100777
778/// Operators
779using ActivationDesc = AclActivationDescriptor;
780class Activation : public Operator
781{
782public:
783 Activation(Context &ctx, const TensorDescriptor &src, const TensorDescriptor &dst, const ActivationDesc &desc, StatusCode *status = nullptr)
784 {
785 AclOperator op;
786 const auto st = detail::as_enum<StatusCode>(AclActivation(&op, ctx.get(), src.get(), dst.get(), desc));
787 reset(op);
788 report_status(st, "[Compute Library] Failure during Activation operator creation");
789 if(status)
790 {
791 *status = st;
792 }
793 }
794};
Georgios Pinitas8a5146f2021-01-12 15:51:07 +0000795} // namespace acl
796#undef ARM_COMPUTE_IGNORE_UNUSED
797#endif /* ARM_COMPUTE_ACL_HPP_ */