自学内容网 自学内容网

PyTorch Lightning Callback介绍

PyTorch Lightning Callback 介绍

在 PyTorch 中,callbacks(回调函数)不是原生支持的核心功能,但在深度学习中非常常见,尤其是用来监控训练过程、调整超参数或执行特定的任务。许多高级深度学习框架(如 PyTorch Lightning 和 FastAI)都基于 PyTorch,并内置了 callback 支持。

PyTorch Lightning 提供了一个易于扩展的回调机制,允许用户在训练过程中插入自定义逻辑。回调类继承自 pytorch_lightning.callbacks.Callback,可以覆盖以下方法:

常用方法
  • on_fit_start: 在训练(fit)开始时调用。
  • on_fit_end: 在训练(fit)结束时调用。
  • on_train_epoch_start: 在每个训练 epoch 开始时调用。
  • on_train_epoch_end: 在每个训练 epoch 结束时调用。
  • on_validation_epoch_start: 在每个验证 epoch 开始时调用。
  • on_validation_epoch_end: 在每个验证 epoch 结束时调用。
  • on_test_epoch_start: 在测试 epoch 开始时调用。
  • on_test_epoch_end: 在测试 epoch 结束时调用。
  • on_train_batch_end: 在每个训练 batch 结束时调用。
  • on_validation_batch_end: 在每个验证 batch 结束时调用。
  • on_test_batch_end: 在每个测试 batch 结束时调用。

示例: 自定义 Callback

以下示例实现了一个打印日志的回调:

from pytorch_lightning.callbacks import Callback

class PrintCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Training ended!")

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Validation ended!")

使用时将回调传递给 Trainer

from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[PrintCallback()])

基于 Hydra 配置实例化 Callback

Hydra 是一个灵活的配置管理工具,常用于深度学习项目中动态管理超参数。通过结合 Hydra 和 PyTorch Lightning,可以动态配置并实例化 Callback。

步骤:

1. 安装 Hydra

pip install hydra-core --upgrade

2. 定义 Hydra 配置文件: 创建一个 YAML 配置文件(如 config.yaml)来管理 Callback 的配置:

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val_loss"
    save_top_k: 1
    mode: "min"

  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: "val_loss"
    patience: 5
    mode: "min"

3. 在代码中动态实例化: 使用 hydra.utils.instantiate 方法实例化回调对象:

import hydra
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from omegaconf import OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg):
    # Instantiate callbacks from config
    callbacks = [instantiate(cfg.callbacks[key]) for key in cfg.callbacks]

    # Example: Define a simple PyTorch Lightning model
    from pytorch_lightning import LightningModule
    import torch.nn.functional as F

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.mse_loss(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

    # Instantiate trainer
    trainer = Trainer(callbacks=callbacks, max_epochs=10)

    # Simulated data loader
    from torch.utils.data import DataLoader, TensorDataset
    import torch

    x = torch.rand(100, 10)
    y = torch.rand(100, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    model = SimpleModel()
    trainer.fit(model, train_loader)

if __name__ == "__main__":
    main()
解释:如何通过配置文件动态管理 Callback
  1. 配置文件中,_target_ 指定回调类的完整路径。
  2. 使用 hydra.utils.instantiate 根据配置动态实例化对象。
  3. 将实例化后的回调传递给 Trainer
优势
  1. 动态配置:通过 YAML 文件可以快速更改回调逻辑而无需修改代码。
  2. 模块化管理:方便管理多个回调类,清晰直观。
  3. 灵活性:支持自定义 Callback 和 Lightning 内置回调的结合使用。

此方法适用于多种场景,比如动态调整模型保存路径、早停策略等。


原文地址:https://blog.csdn.net/qq_27390023/article/details/144748341

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!