Chapter 6.2-Preparing the dataset
Chapter 6 -Fine-tuning for classification
6.2-Preparing the dataset
-
如下图所示,分类微调 LLM 的三阶段过程
-
数据集准备。
-
模型设置。
-
微调和评估模型。
-
-
本节准备用于分类微调的数据集。我们使用一个包含垃圾邮件和非垃圾邮件文本的数据集,对大语言模型(LLM)进行微调,以对其进行分类。首先,我们下载并解压数据集。
import urllib.request import zipfile import os from pathlib import Path url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip" zip_path = "sms_spam_collection.zip" extracted_path = "sms_spam_collection" data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path): if data_file_path.exists(): print(f"{data_file_path} already exists. Skipping download and extraction.") return # Downloading the file with urllib.request.urlopen(url) as response: with open(zip_path, "wb") as out_file: out_file.write(response.read()) # Unzipping the file with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extracted_path) # Add .tsv file extension original_file_path = Path(extracted_path) / "SMSSpamCollection" os.rename(original_file_path, data_file_path) print(f"File downloaded and saved as {data_file_path}") download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
导入csv文件
import pandas as pd download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) df = pd.read_csv(data_file_path, sep='\t', header=None, names=["Label", "Text"]) df
当我们检查类分布时,我们看到数据包含“ham”(即“not spam”)的频率比“spam”高得多
print(df["Label"].value_counts()) """输出""" Label ham 4825 spam 747 Name: count, dtype: int64
处于快速微调大模型考虑,对数据集进行下采样(处理类平衡的方法之一),让每个类别包含出747个实例
def create_balanced_dataset(df): # Count the instances of "spam" num_spam = df[df["Label"] == "spam"].shape[0] # Randomly sample "ham" instances to match the number of "spam" instances ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123) # Combine ham "subset" with "spam" balanced_df = pd.concat([ham_subset , df[df["Label"] == "spam"]] , ignore_index=True ) return balanced_df balanced_df = create_balanced_dataset(df) print(balanced_df["Label"].value_counts()) """输出""" Label ham 747 spam 747 Name: count, dtype: int64
接下来,我们将字符串类标签“ham”和“spam”更改为整数类标签0和1:
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1}) balanced_df
现在让我们定义一个函数,将数据集随机划分为训练、验证和测试子集,70%用于训练,10%用于验证,20%用于测试
def random_split(df, train_frac, validation_frac): # Shuffle the entire DataFrame df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Calculate split indices train_end = int(len(df) * train_frac) validation_end = train_end + int(len(df) * validation_frac) # Split the DataFrame train_df = df[:train_end] validation_df = df[train_end:validation_end] test_df = df[validation_end:] return train_df, validation_df, test_df train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1) # Test size is implied to be 0.2 as the remainder train_df.to_csv("train.csv", index=None) validation_df.to_csv("validation.csv", index=None) test_df.to_csv("test.csv", index=None)
我们已经下载了数据集,对其进行类别平衡并拆分为训练验证测试集。
6.3-Creating data loaders
原文地址:https://blog.csdn.net/hbkybkzw/article/details/145268094
免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!