blob: aeb9f1d4e4001f3b72f97a4815967db3706f19ad [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Lya4d748b2023-03-28 22:06:56 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OPS_REDUCTION_H
17#define OPS_REDUCTION_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
Tai Lya4d748b2023-03-28 22:06:56 +000026template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070027class ReduceNode : public GraphNode
28{
29public:
Kevin Chengacb550f2021-06-29 15:32:19 -070030 ReduceNode(SubgraphTraverser* sgt_, const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070031 virtual ~ReduceNode();
32 virtual int checkTensorAttributes();
33 virtual int eval() = 0;
34
35 using InEigenType = typename GetEigenType<Dtype>::type;
36 using OutEigenType = typename GetEigenType<Dtype>::type;
37 using TIn = Eigen::Tensor<InEigenType, Rank>;
38 using TOut = Eigen::Tensor<OutEigenType, Rank>;
39
40protected:
41 Eigen::array<int, 1> dims;
42 TosaReference::TensorTemplate<TIn>* in;
43 TosaReference::TensorTemplate<TOut>* out;
44 TosaAxisAttribute* attribute;
45};
46
Tai Lya4d748b2023-03-28 22:06:56 +000047template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070048class OpReduceAll : public ReduceNode<Rank, Dtype>
49{
50public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000051 OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070052 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070053 {}
54 virtual int eval();
55};
56
Tai Lya4d748b2023-03-28 22:06:56 +000057template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070058class OpReduceAny : public ReduceNode<Rank, Dtype>
59{
60public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000061 OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070062 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070063 {}
64 virtual int eval();
65};
66
Tai Lya4d748b2023-03-28 22:06:56 +000067template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070068class OpReduceMax : public ReduceNode<Rank, Dtype>
69{
70public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000071 OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070072 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MAX, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070073 {}
74 virtual int eval();
75};
76
Tai Lya4d748b2023-03-28 22:06:56 +000077template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070078class OpReduceMin : public ReduceNode<Rank, Dtype>
79{
80public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000081 OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070082 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MIN, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070083 {}
84 virtual int eval();
85};
86
Tai Lya4d748b2023-03-28 22:06:56 +000087template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070088class OpReduceProduct : public ReduceNode<Rank, Dtype>
89{
90public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000091 OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070092 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070093 {}
94 virtual int eval();
95};
96
Tai Lya4d748b2023-03-28 22:06:56 +000097template <int Rank, TOSA_REF_TYPE Dtype>
98class OpReduceProductDouble : public ReduceNode<Rank, Dtype>
99{
100public:
101 OpReduceProductDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
102 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_)
103 {}
104 virtual int eval();
105};
106
107template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700108class OpReduceSum : public ReduceNode<Rank, Dtype>
109{
110public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000111 OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700112 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700113 {}
114 virtual int eval();
115};
116
Tai Lya4d748b2023-03-28 22:06:56 +0000117template <int Rank, TOSA_REF_TYPE Dtype>
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100118class OpReduceSumInt : public ReduceNode<Rank, Dtype>
119{
120public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000121 OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100122 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
123 {}
124 virtual int eval();
125};
126
Tai Lya4d748b2023-03-28 22:06:56 +0000127template <int Rank, TOSA_REF_TYPE Dtype>
128class OpReduceSumDouble : public ReduceNode<Rank, Dtype>
129{
130public:
131 OpReduceSumDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
132 : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
133 {}
134 virtual int eval();
135};
136
Eric Kunzee5e26762020-10-13 16:11:07 -0700137}; // namespace TosaReference
138
139#endif