自学内容网 自学内容网

解决Transformer训练中的AttributeError: ‘AdamW‘ object has no attribute ‘train‘问题

摘要

 本文分享了在使用transformers库进行BERT模型训练时遇到的AttributeError: ‘AdamW’ object has no attribute 'train’错误的解决过程。通过查找相关信息,发现问题源于accelerate库版本过低,并通过将库升级至0.34.2版本成功解决报错。本文详细介绍了问题排查、版本更新的步骤,以及如何忽略更新中的警告提示,以帮助读者快速解决类似问题。

报错信息描述

在使用 transformers 库的 Trainer 训练 BERT 模型时,遇到了以下报错信息:

File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 3477, in training_step
    self.optimizer.train()
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/optimizer.py", line 128, in train
    return self.optimizer.train()
AttributeError: 'AdamW' object has no attribute 'train'
Traceback (most recent call last):
  File "/home/jie/gitee/pku_industry/general/pipeline.py", line 202, in <module>
    run("optical_communication_laser")
  File "/home/jie/gitee/pku_industry/general/pipeline.py", line 97, in run
    bert_cls.train(5)
  File "/home/jie/gitee/pku_industry/general/bert_train.py", line 108, in train
    self.trainer.train()
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 3477, in training_step
    self.optimizer.train()
  File "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/optimizer.py", line 128, in train
    return self.optimizer.train()
AttributeError: 'AdamW' object has no attribute 'train'

因为之前这段代码是可以正常运行的,所以我怀疑问题可能与某些库的版本更新有关。

Bug修复过程

我使用搜索引擎查询了报错信息,尝试找到解决方案。友情提醒,如果使用百度搜索,可能很难找到有用的信息,因为对于这类专业性较强的问题,百度的表现还有待提升。百度加油!

百度搜索
点击查看:GitHub 上的相关 issue

在一个 GitHub issue 中,有人提到了需要将 accelerate 库更新到 0.34.2 版本,解决这个问题。

GitHub Issue

检查当前库的版本

使用以下命令查看当前安装的 accelerate 库版本:

pip show accelerate

显示版本

发现当前版本低于 0.34.2,所以需要进行更新。

更新 accelerate

使用以下命令将 accelerate 库更新到 0.34.2

pip install accelerate==0.34.2

在更新过程中,可能会出现一些警告信息,不过这些警告可以忽略。

忽略警告

验证更新结果

更新完成后,重新运行代码,问题已经解决,程序可以正常执行了。


这个过程表明,部分依赖库的更新可能会引入不兼容的改动。定期检查并更新项目的依赖项,可以避免遇到类似问题。希望这篇博客能帮助大家解决类似的报错问题。


原文地址:https://blog.csdn.net/sjxgghg/article/details/142973352

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!