
Computing model complexity (params, FLOPs) with thop

Installing thop

- Online installation with pip install thop failed

- Offline installation

GitHub repository:

pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model. - GitCode

Download the source code, then run the following from the repository root:

python setup.py install
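To confirm that the offline installation worked, here is a minimal standard-library sketch. It checks whether the thop module can be imported in the current Python environment. The helper name is_installed is hypothetical and is not part of thop.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if the top-level module `package` is importable here."""
    # find_spec returns None when no module with that name can be found
    return importlib.util.find_spec(package) is not None


if __name__ == "__main__":
    # "thop" is the import name used by the pytorch-OpCounter package
    print("thop installed:", is_installed("thop"))
```

Run it with the same interpreter that executed setup.py. A virtual environment or conda env mismatch is a common reason an installed package still cannot be found.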

Test script:

from options import config as c
import os
# set CUDA_VISIBLE_DEVICES before importing torch so that it takes effect
os.environ["CUDA_VISIBLE_DEVICES"] = c.os_environ
import torch
from modules.NET import Net
from utils.utils import load
from utils.yml import parse_yml, dict_to_nonedict
import numpy as np
from thop import profile
from modules.DCTGate_fast import DCT_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------noise set------------------------------
# read noise_config
yml_path = c.noise_opt_yml_path
option_yml = parse_yml(yml_path)
# convert to NoneDict, which returns None for missing keys
noise_opt = dict_to_nonedict(option_yml)

# -------------------MODEL load---------------------------
model_path = c.load_model_path
net = Net(noise_opt, device).to(device)
load(model_path, net)

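# -----------------MODEL input-----------------------
# dummy cover image: batch of 1, 3 channels, cropsize_val x cropsize_val
cover = torch.randn(1, 3, c.cropsize_val, c.cropsize_val).to(device)
# random secret message with entries drawn from {-0.5, 0.5}
secret = torch.Tensor(np.random.choice([-0.5, 0.5], (cover.shape[0], c.input_message_length))).to(device)

# 8x8 block DCT of the cover image, passed to Net as its third input
dct_trans = DCT_transform(image_size=c.cropsize_val, block_size=8).to(device)
cover_dct = dct_trans(cover)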

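# ------------cal: params, FLOPs-----------
# thop's profile() returns (MACs, params): MACs are multiply-accumulate operations,
# and FLOPs are commonly estimated as 2 * MACs
macs, params = profile(net, inputs=(cover, secret, cover_dct))
print(f'\nMACs: {macs}, params: {params}\n')
print('the MACs is {}G (about {}G FLOPs), the params is {}M\n'.format(
    round(macs / 10 ** 9, 2), round(2 * macs / 10 ** 9, 2), round(params / 10 ** 6, 2)))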
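thop's profile() attaches forward hooks to each layer and sums the multiply-accumulate operations (MACs) and parameters it finds. To sanity-check its numbers by hand, the pure-Python sketch below applies the standard formulas for a Conv2d and a Linear layer. The helper functions and toy layer sizes are hypothetical. thop's own conv counter has handled bias differently across versions, so treat this as an approximation rather than thop's exact code.

```python
# Hypothetical helpers: textbook parameter/MAC formulas, not thop's API.

def conv2d_params(c_in: int, c_out: int, k: int, groups: int = 1, bias: bool = True) -> int:
    # weight tensor is c_out x (c_in / groups) x k x k, plus one bias per output channel
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)


def conv2d_macs(c_in: int, c_out: int, k: int, h_out: int, w_out: int, groups: int = 1) -> int:
    # every output element needs (c_in / groups) * k * k multiply-accumulates (bias ignored)
    return c_out * h_out * w_out * (c_in // groups) * k * k


def linear_params(f_in: int, f_out: int, bias: bool = True) -> int:
    # weight matrix f_out x f_in, plus one bias per output feature
    return f_in * f_out + (f_out if bias else 0)


def linear_macs(f_in: int, f_out: int, batch: int = 1) -> int:
    # each of the batch * f_out outputs is a dot product of length f_in
    return batch * f_in * f_out


# Toy example: a 3x3 Conv2d(3 -> 16, padding=1) on one 32x32 image,
# plus an independent Linear(16 -> 10) on a single vector.
total_params = conv2d_params(3, 16, 3) + linear_params(16, 10)
total_macs = conv2d_macs(3, 16, 3, 32, 32) + linear_macs(16, 10)
approx_flops = 2 * total_macs  # common convention: 1 MAC = 1 multiply + 1 add

print(f"params: {total_params}, MACs: {total_macs}, FLOPs ~ {approx_flops}")
print("MACs: {}M, params: {}K".format(round(total_macs / 10**6, 2), round(total_params / 10**3, 2)))
```

For these sizes the sketch reports 618 parameters and 442528 MACs. Small layers like these are easier to check with M or K units than with the G units used in the test script.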




Original article: https://blog.csdn.net/qq_43855258/article/details/142442806
