
ONNX export code walkthrough

I. Outline

  1. torch.jit.trace code walkthrough
  2. ONNX internals: checking whether a node is an aten operator
  3. aten operator implementation
  4. torch.autograd.Function operator implementation
  5. Custom operator implementation
  6. Finding unimplemented nodes
  7. Finding all unimplemented aten operators at once

II. Implementation

  1. torch.jit.trace code walkthrough
    1. torch.jit.script(): compiles a module or function into a runnable script. The converted script can be called like a normal Python function, and it can also be saved to disk and executed without a Python dependency (e.g., loaded from C++ via libtorch).
    2. torch.jit.trace: records the execution path taken for the given example inputs. Because only that one path is captured, the input tensors used at inference time must have the same shapes and dtypes as the ones used during tracing; the sketch below contrasts the two.
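A minimal sketch of the difference (the function f here is illustrative): torch.jit.script compiles the Python source and preserves data-dependent control flow, while a trace of the same function would bake in whichever branch the example input happened to take.

import torch

def f(x):
    # data-dependent branch: script keeps it as a prim::If node,
    # trace would record only the branch taken by the example input
    if x.sum() > 0:
        return x * 2
    return x - 1

scripted = torch.jit.script(f)
print(scripted.graph)                 # graph contains a prim::If node
print(scripted(torch.ones(2)))        # tensor([2., 2.])
print(scripted(-torch.ones(2)))       # tensor([-2., -2.])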

2. Checking whether a node is an aten operator

import torch

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

Operator trace output: the printed graph contains an aten::elu node, i.e. ELU is traced to an aten operator.
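To check programmatically instead of reading the printed graph by eye, you can scan the traced graph for aten:: node kinds (a small sketch):

import torch

traced = torch.jit.trace(torch.nn.ELU(), torch.ones(1))
# collect the kinds of all aten:: nodes in the traced graph
aten_kinds = {n.kind() for n in traced.graph.nodes() if n.kind().startswith("aten::")}
print(aten_kinds)  # e.g. {'aten::elu'}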
3. aten operator implementation
  1. Check the torch interface definition in torch/nn/functional.pyi
  2. Check the ONNX operator naming in https://github.com/onnx/onnx/blob/main/docs/Operators.md
  3. Check how symbolic registration functions are written in symbolic_opset9.py
The example below walks through these three steps for aten::relu.

import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15
# Define a custom symbolic function for aten::relu.
# This symbolic function is correct, so find_mismatch should report no mismatches.

# def relu(input: Tensor) -> Tensor: ...   interface definition from torch/nn/functional.pyi
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)             # maps directly to the ONNX Relu operator

torch.onnx.register_custom_op_symbolic(     # register the symbolic function
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
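To double-check the registration end to end, you can export the model and inspect the resulting node list. This is a sketch that continues the example above and assumes the onnx package is installed:

import io
import onnx

buf = io.BytesIO()
torch.onnx.export(Model(), (torch.randn(2, 3),), buf, opset_version=opset_version)
model_proto = onnx.load_from_string(buf.getvalue())
print([n.op_type for n in model_proto.graph.node])  # expect alternating Gemm / Relu nodes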
4. torch.autograd.Function operator implementation
    If the operator is implemented as a subclass of torch.autograd.Function, it can be exported by adding a static symbolic method to the class, as in the following example.
import torch

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))


import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15

myrelu = MyRelu.apply        # key step: call through .apply so the symbolic method is picked up during tracing
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.Linear(4, 5),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return myrelu(self.layers(x))

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
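A quick sanity check (illustrative) that the eager forward agrees with a plain ReLU, since clamping at zero is exactly what torch.relu computes:

x = torch.randn(4)
assert torch.equal(MyRelu.apply(x), torch.relu(x))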
5. Custom operator implementation
    1. Compose the symbolic function from existing ONNX operators
    2. Implement a custom C++ operator, following "Extending TorchScript with Custom C++ Operators"
6. Finding unimplemented nodes

import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15
# Define a custom symbolic function for aten::relu.
# This symbolic function is deliberately incorrect, so find_mismatch will flag the aten::relu nodes.
def incorrect_relu_symbolic_function(g, self):
    return self
torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    incorrect_relu_symbolic_function,
    opset_version=opset_version,
)
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)
graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)

===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 ✓
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 ✓
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 ✓
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
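Reading the tree: find_mismatch bisects the exported graph, marking each partition with X if its outputs mismatch and ✓ if they match. The two leaf subgraphs '01' and '110' each isolate one aten::relu node, which is exactly the operator whose symbolic function was registered incorrectly.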

After the fix. Method 1: the aten symbolic function implementation

import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15
# Define a custom symbolic function for aten::relu.
# This symbolic function is correct, so find_mismatch reports no mismatches.

# def relu(input: Tensor) -> Tensor: ...   interface definition from torch/nn/functional.pyi
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)             # maps directly to the ONNX Relu operator

torch.onnx.register_custom_op_symbolic(     # register the symbolic function
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)

Method 2: the TorchScript / custom C++ operator style, registering the symbolic with @parse_args


import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15


from torch.onnx import register_custom_op_symbolic        # register an additional symbolic function for a TorchScript operator
from torch.onnx.symbolic_helper import parse_args
# The @parse_args decorator: symbolic functions for TorchScript operators must
# annotate the type of every input argument. "v" marks a value from the Torch
# library (typically a tensor), "i" an int, "f" a float, and "none" an empty
# argument. The full list of type codes is in torch/onnx/symbolic_helper.py.
@parse_args("v")      # aten::relu takes a single tensor input
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)


torch.onnx.register_custom_op_symbolic(     #注册
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
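To see why the argument descriptors matter, here is a hedged sketch (the function name is illustrative) for an operator that mixes tensor and scalar arguments: aten::leaky_relu takes a tensor plus a float slope, so its descriptors are "v" and "f".

from torch.onnx.symbolic_helper import parse_args

@parse_args("v", "f")  # input arrives as a graph value, negative_slope as a Python float
def leaky_relu_symbolic(g, input, negative_slope):
    # the parsed float becomes the alpha attribute of the ONNX LeakyRelu node
    return g.op("LeakyRelu", input, alpha_f=negative_slope)

torch.onnx.register_custom_op_symbolic(
    "aten::leaky_relu", leaky_relu_symbolic, opset_version=opset_version
)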

7. Finding all unimplemented aten operators at once
import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)


torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    Model(), (torch.randn(2, 3),), opset_version=opset_version
)

print(set(unconvertible_ops))
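For this model every node already has a symbolic function, so the printed set should be empty; for a model containing unsupported operators, each aten op that lacks a symbolic at the chosen opset would be listed here.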

Original article: https://blog.csdn.net/weixin_40777649/article/details/142858767
