Simon Obute | 51f6777 | 2021-09-03 15:50:13 +0100 | [diff] [blame] | 1 | // |
Mike Kelly | 7cbe781 | 2023-07-25 17:37:33 +0100 | [diff] [blame] | 2 | // Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved. |
Simon Obute | 51f6777 | 2021-09-03 15:50:13 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
Simon Obute | 51f6777 | 2021-09-03 15:50:13 +0100 | [diff] [blame] | 6 | #include <armnn/backends/ITensorHandleFactory.hpp> |
| 7 | #include <armnnUtils/Transpose.hpp> |
| 8 | #include "RefChannelShuffleWorkload.hpp" |
| 9 | #include "RefWorkloadUtils.hpp" |
| 10 | #include "Profiling.hpp" |
| 11 | #include "Decoders.hpp" |
| 12 | #include "Encoders.hpp" |
| 13 | |
| 14 | namespace armnn |
| 15 | { |
| 16 | void RefChannelShuffleWorkload::Execute() const |
| 17 | { |
| 18 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 19 | } |
| 20 | |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 21 | void RefChannelShuffleWorkload::ExecuteAsync(ExecutionData& executionData) |
Simon Obute | 51f6777 | 2021-09-03 15:50:13 +0100 | [diff] [blame] | 22 | { |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 23 | WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); |
| 24 | Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); |
Simon Obute | 51f6777 | 2021-09-03 15:50:13 +0100 | [diff] [blame] | 25 | } |
| 26 | |
| 27 | // Reference implementation for channel shuffle taken from |
| 28 | // https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/master/nn/common/operations/ChannelShuffle.cpp |
| 29 | void RefChannelShuffleWorkload::Execute(std::vector<ITensorHandle*> inputs, |
| 30 | std::vector<ITensorHandle*> outputs) const |
| 31 | { |
Mike Kelly | 7cbe781 | 2023-07-25 17:37:33 +0100 | [diff] [blame] | 32 | ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefChannelShuffleWorkload_Execute"); |
Simon Obute | 51f6777 | 2021-09-03 15:50:13 +0100 | [diff] [blame] | 33 | |
| 34 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 35 | const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); |
| 36 | std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
| 37 | Decoder<float>& decoder = *decoderPtr; |
| 38 | |
| 39 | std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 40 | Encoder<float>& encoder = *encoderPtr; |
| 41 | |
| 42 | auto getNumberOfElements = [](const TensorShape& tensorShape,uint32_t startAxis, uint32_t lastAxis) |
| 43 | { |
| 44 | uint32_t count = 1; |
| 45 | for (uint32_t i = startAxis; i < lastAxis; i++) |
| 46 | { |
| 47 | count *= tensorShape[i]; |
| 48 | } |
| 49 | return count; |
| 50 | }; |
| 51 | const TensorShape tensorShape = GetTensorInfo(inputs[0]).GetShape(); |
| 52 | uint32_t channelsAxis = m_Data.m_Parameters.m_Axis; // channelsAxis to perform channel shuffle on |
| 53 | |
| 54 | const uint32_t numGroups = m_Data.m_Parameters.m_NumGroups; |
| 55 | const uint32_t groupSize = tensorShape[channelsAxis] / numGroups; |
| 56 | |
| 57 | uint32_t outerSize = getNumberOfElements(tensorShape, 0, channelsAxis); |
| 58 | uint32_t innerSize = getNumberOfElements(tensorShape, channelsAxis + 1, tensorShape.GetNumDimensions()); |
| 59 | |
| 60 | for (uint32_t outer = 0; outer < outerSize; ++outer) |
| 61 | { |
| 62 | for (uint32_t inner = 0; inner < innerSize; ++inner) |
| 63 | { |
| 64 | uint32_t decoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner; |
| 65 | decoder += decoderStep1; |
| 66 | uint32_t encoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner; |
| 67 | encoder += encoderStep1; |
| 68 | for (uint32_t i = 0; i < groupSize; i++) |
| 69 | { |
| 70 | for (uint32_t j = 0; j < numGroups; j++, encoder += innerSize, encoderStep1 += innerSize) |
| 71 | { |
| 72 | decoder += innerSize * (i + j * groupSize); |
| 73 | float decoded = decoder.Get(); |
| 74 | encoder.Set(decoded); |
| 75 | decoder -= innerSize * (i + j * groupSize); |
| 76 | } |
| 77 | } |
| 78 | decoder -= decoderStep1; |
| 79 | encoder -= encoderStep1; |
| 80 | } |
| 81 | } |
| 82 | } |
| 83 | } |