blob: c23291d06a225b5e9df2cfcf556257757cd6d45a [file] [log] [blame]
Simon Obute51f67772021-09-03 15:50:13 +01001//
Mike Kelly7cbe7812023-07-25 17:37:33 +01002// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
Simon Obute51f67772021-09-03 15:50:13 +01003// SPDX-License-Identifier: MIT
4//
5
Simon Obute51f67772021-09-03 15:50:13 +01006#include <armnn/backends/ITensorHandleFactory.hpp>
7#include <armnnUtils/Transpose.hpp>
8#include "RefChannelShuffleWorkload.hpp"
9#include "RefWorkloadUtils.hpp"
10#include "Profiling.hpp"
11#include "Decoders.hpp"
12#include "Encoders.hpp"
13
14namespace armnn
15{
// Synchronous entry point: runs the channel shuffle on the tensors bound
// in this workload's queue descriptor (m_Data).
void RefChannelShuffleWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}
20
Matthew Sloyan2d213a72022-06-30 17:13:04 +010021void RefChannelShuffleWorkload::ExecuteAsync(ExecutionData& executionData)
Simon Obute51f67772021-09-03 15:50:13 +010022{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010023 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
24 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Simon Obute51f67772021-09-03 15:50:13 +010025}
26
27// Reference implementation for channel shuffle taken from
28// https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/master/nn/common/operations/ChannelShuffle.cpp
29void RefChannelShuffleWorkload::Execute(std::vector<ITensorHandle*> inputs,
30 std::vector<ITensorHandle*> outputs) const
31{
Mike Kelly7cbe7812023-07-25 17:37:33 +010032 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefChannelShuffleWorkload_Execute");
Simon Obute51f67772021-09-03 15:50:13 +010033
34 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
35 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
36 std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo, inputs[0]->Map());
37 Decoder<float>& decoder = *decoderPtr;
38
39 std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
40 Encoder<float>& encoder = *encoderPtr;
41
42 auto getNumberOfElements = [](const TensorShape& tensorShape,uint32_t startAxis, uint32_t lastAxis)
43 {
44 uint32_t count = 1;
45 for (uint32_t i = startAxis; i < lastAxis; i++)
46 {
47 count *= tensorShape[i];
48 }
49 return count;
50 };
51 const TensorShape tensorShape = GetTensorInfo(inputs[0]).GetShape();
52 uint32_t channelsAxis = m_Data.m_Parameters.m_Axis; // channelsAxis to perform channel shuffle on
53
54 const uint32_t numGroups = m_Data.m_Parameters.m_NumGroups;
55 const uint32_t groupSize = tensorShape[channelsAxis] / numGroups;
56
57 uint32_t outerSize = getNumberOfElements(tensorShape, 0, channelsAxis);
58 uint32_t innerSize = getNumberOfElements(tensorShape, channelsAxis + 1, tensorShape.GetNumDimensions());
59
60 for (uint32_t outer = 0; outer < outerSize; ++outer)
61 {
62 for (uint32_t inner = 0; inner < innerSize; ++inner)
63 {
64 uint32_t decoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner;
65 decoder += decoderStep1;
66 uint32_t encoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner;
67 encoder += encoderStep1;
68 for (uint32_t i = 0; i < groupSize; i++)
69 {
70 for (uint32_t j = 0; j < numGroups; j++, encoder += innerSize, encoderStep1 += innerSize)
71 {
72 decoder += innerSize * (i + j * groupSize);
73 float decoded = decoder.Get();
74 encoder.Set(decoded);
75 decoder -= innerSize * (i + j * groupSize);
76 }
77 }
78 decoder -= decoderStep1;
79 encoder -= encoderStep1;
80 }
81 }
82}
83}