Diqing Zhong | 5e5a784 | 2021-08-16 17:24:09 +0200 | [diff] [blame] | 1 | # Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. |
| 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | # Description: |
| 17 | # Functions used to write to a raw format (.npz) file. |
| 18 | import numpy as np |
| 19 | |
| 20 | from .high_level_command_to_npu_op import get_region |
| 21 | from .nn_graph import PassPlacement |
| 22 | from .operation import Op |
| 23 | |
| 24 | |
def _find_custom_npu_op(sg):
    """Return the custom NPU operation of subgraph *sg*, or None if there is none.

    Passes are scanned in order; within each pass only the first op whose type
    is Op.CustomNpuOp is considered.
    """
    for ps in sg.passes:
        candidate = next((op for op in ps.ops if op.type == Op.CustomNpuOp), None)
        if candidate:
            return candidate
    return None


def _collect_tensor_info(tensors, arch):
    """Gather shape, element size, region and address for each tensor.

    Returns four parallel lists in the order (shapes, elem_sizes, regions, offsets).
    """
    shapes, elem_sizes, regions, offsets = [], [], [], []
    for tens in tensors:
        shapes.append(tens.shape)
        regions.append(get_region(tens.mem_type, arch))
        offsets.append(tens.address)
        elem_sizes.append(tens.element_size())
    return shapes, elem_sizes, regions, offsets


def write_rawdata_output(nng, arch, filename):
    """Write one raw-format (.npz) file per CPU-placed subgraph holding a custom NPU op.

    nng:      graph object whose ``subgraphs`` are inspected.
    arch:     architecture object, passed on to get_region().
    filename: prefix of the output files; each file is named
              ``<filename>_sg<index>_vela.npz``.

    The index in the file name is the position of the subgraph among the
    CPU-placed subgraphs. Subgraphs without a custom NPU op produce no file but
    still consume an index.
    """
    cpu_subgraphs = [sg for sg in nng.subgraphs if sg.placement == PassPlacement.Cpu]

    for sg_idx, sg in enumerate(cpu_subgraphs):
        custom_op = _find_custom_npu_op(sg)
        if not custom_op:
            continue

        # The first four inputs are fixed-purpose tensors; the rest are the real IFMs.
        cmd_stream_tensor, weight_tensor, scratch_tensor, scratch_fast_tensor = custom_op.inputs[:4]
        weight_region = get_region(weight_tensor.mem_type, arch)
        scratch_region = get_region(scratch_tensor.mem_type, arch)
        scratch_fast_region = get_region(scratch_fast_tensor.mem_type, arch)

        ifm_shapes, ifm_elem_sizes, ifm_regions, ifm_offsets = _collect_tensor_info(custom_op.inputs[4:], arch)
        ofm_shapes, ofm_elem_sizes, ofm_regions, ofm_offsets = _collect_tensor_info(custom_op.outputs, arch)

        # Insertion order of this dict is the order in which the arrays are stored.
        arrays = {
            "cmd_data": cmd_stream_tensor.values,
            "weight_data": weight_tensor.values,
            "weight_region": weight_region,
            "scratch_shape": scratch_tensor.shape,
            "scratch_region": scratch_region,
            "scratch_fast_shape": scratch_fast_tensor.shape,
            "scratch_fast_region": scratch_fast_region,
            "input_shape": ifm_shapes,
            "input_elem_size": ifm_elem_sizes,
            "input_region": ifm_regions,
            "input_offset": ifm_offsets,
            "output_shape": ofm_shapes,
            "output_elem_size": ofm_elem_sizes,
            "output_region": ofm_regions,
            "output_offset": ofm_offsets,
        }
        np.savez(f"{filename}_sg{sg_idx}_vela.npz", **arrays)