blob: af09153b2d96f0a4d06a6ca0ab1ffe844f110940 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OPS_SCATTER_GATHER_H
17#define OPS_SCATTER_GATHER_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
Kevin Cheng77d0f762020-11-24 10:26:32 -080026template <DType Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070027class OpGather : public GraphNode
28{
29public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~OpGather();
32
33 virtual int checkTensorAttributes();
34 virtual int eval();
35
Kevin Cheng77d0f762020-11-24 10:26:32 -080036 using EigenType = typename GetEigenType<Dtype>::type;
37 using TValue = Eigen::Tensor<EigenType, 3>;
38 using TIndex = Eigen::Tensor<int32_t, 2>;
39 using TOutput = Eigen::Tensor<EigenType, 3>;
Eric Kunzee5e26762020-10-13 16:11:07 -070040
41protected:
Kevin Cheng77d0f762020-11-24 10:26:32 -080042 int32_t N, W, K, C;
43 TosaReference::TensorTemplate<TValue>* values;
44 TosaReference::TensorTemplate<TIndex>* indices;
45 TosaReference::TensorTemplate<TOutput>* output;
46};
47
48template <DType Dtype>
49class OpScatter : public GraphNode
50{
51public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000052 OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Kevin Cheng77d0f762020-11-24 10:26:32 -080053 virtual ~OpScatter();
54
55 virtual int checkTensorAttributes();
56 virtual int eval();
57
58 using EigenType = typename GetEigenType<Dtype>::type;
59 using TValue = Eigen::Tensor<EigenType, 3>;
60 using TIndex = Eigen::Tensor<int32_t, 2>;
61 using TOutput = Eigen::Tensor<EigenType, 3>;
62
63protected:
64 int32_t N, W, K, C;
65 TosaReference::TensorTemplate<TValue>* values_in;
66 TosaReference::TensorTemplate<TIndex>* indices;
67 TosaReference::TensorTemplate<TValue>* input;
68 TosaReference::TensorTemplate<TOutput>* values_out;
Eric Kunzee5e26762020-10-13 16:11:07 -070069};
70
71}; // namespace TosaReference
72
73#endif