自学内容网 自学内容网

AIGC笔记--Classifer Guidance的代码理解

1--完整可debug代码

Simple_Classifer_Guide

2--核心代码

import torch as th
import torch.nn.functional as F

from tools.unet import EncoderUNetModel

# 核心函数
def cond_fn(x, t, y = None, classifier = None): 
    # x.shape: [batch_size, 3, 64, 64]
    # t.shape: [batch_size] 
    # y.shape: [batch_size]
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t) # [batch_size, 1000] # classifier本质上是EncoderUNetModel
        log_probs = F.log_softmax(logits, dim = -1) # [batch_size, 1000]
        selected = log_probs[range(len(logits)), y.view(-1)] # [batch_size]
        classifier_scale = 1.0 # set by ljf
        return th.autograd.grad(selected.sum(), x_in)[0] * classifier_scale # th.autograd.grad(y, x)计算y对x的导数
        # return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale # [batch_size, 3, 64, 64]

def condition_mean(cond_fn, p_mean_var, x, t, model_kwargs = None, classifier = None):
    # gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) # origin version
    gradient = cond_fn(x, t, y = model_kwargs, classifier = classifier) # [batch_size, 3, 64, 64] set by ljf
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() # [batch_size, 3, 64, 64]
    )
    return new_mean

def p_sample(x, t, cond_fn = None, model_kwargs = None, classifier = None):
        # origin version
        # out = self.p_mean_variance(
        #     model,
        #     x,
        #     t,
        #     clip_denoised=clip_denoised,
        #     denoised_fn=denoised_fn,
        #     model_kwargs=model_kwargs,
        # )

        # 这里使用random的方法模拟p_mean_variance的过程
        out = {}
        out["mean"] = th.rand(x.shape)
        out["variance"] = th.rand(x.shape)
        out["log_variance"] = th.rand(x.shape)
        out["pred_xstart"] = th.rand(x.shape)

        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = condition_mean(
                cond_fn, out, x, t, model_kwargs = model_kwargs, classifier = classifier
            )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise # [batch_size, 3, 64, 64]
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}


if __name__ == "__main__":
    batch_size = 2
    channel = 3
    height = 64
    width = 64
    x = th.rand(batch_size, channel, height, width) # 随机生成输入
    t = th.tensor([999]).repeat(batch_size) # 随机模拟time steps
    y = th.randint(low = 0, high = 999, size = (2,)) # 模拟类别
    classifier = EncoderUNetModel(
        image_size = 64,
        in_channels = 3,
        model_channels = 128,
        out_channels = 1000, # 1000个类别
        num_res_blocks = 3,
        attention_resolutions = [32,16,8])
    
    output = p_sample(x = x, t = t, model_kwargs = y, cond_fn = cond_fn, classifier = classifier) # 传入cond_fn函数和分类器classifier
    assert list(output['sample'].shape) == [batch_size, channel, height, width]
    print("output['sample'].shape: ", output['sample'].shape)
    print("All Done!")


原文地址:https://blog.csdn.net/weixin_43863869/article/details/140493373

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!