Chapter 6.3-Creating data loaders
Chapter 6 -Fine-tuning for classification
6.3-Creating data loaders
-
注意文本消息有不同的长度;如果我们想将多个训练示例组合在一个批次中,我们必须
- 将所有消息截断为数据集或批处理中最短消息的长度
- 将所有消息填充到数据集或批处理中最长消息的长度
第一个选项在计算上更便宜,但是如果较短的消息比平均或最长的消息小得多,它可能会导致不显著的信息丢失,这可能会降低模型性能。因此,我们选择第二个选项,它保留所有消息的全部内容。为了实现批处理,将所有消息填充到数据集中最长消息的长度,我们将paddingTokens添加到所有较短的消息中。为此,我们使用“<|endoftext|>"作为填充标记。
但是,我们可以将与“<|endoftext|>”对应的令牌ID添加到编码的文本消息中,而不是直接将字符串“<|endoftext|>附加到每个文本消息中,如图下图所示。
50256是填充令牌“<|endoftext|>”的令牌ID。我们可以通过使用我们之前使用的TikToker包中的GPT-2令牌器对“<|endoftext|>”进行编码来仔细检查令牌ID是否正确
tokenizer = tiktoken.get_encoding("gpt2") print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})) """输出""" [50256]
下面的SpamDataset类继承了pytorch中创建数据集的Dataset基类(子类只需要实现__init__、__getitem__,__len__三个方法就行),在类中标识训练集中最长的序列,并将填充标记添加到其他序列以匹配该序列长度
import torch from torch.utils.data import Dataset class SpamDataset(Dataset): def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256): self.data = pd.read_csv(csv_file) # Pre-tokenize texts self.encoded_texts = [ tokenizer.encode(text) for text in self.data["Text"] ] if max_length is None: self.max_length = self._longest_encoded_length() else: self.max_length = max_length # Truncate sequences if they are longer than max_length self.encoded_texts = [ encoded_text[:self.max_length] for encoded_text in self.encoded_texts ] # Pad sequences to the longest sequence self.encoded_texts = [ encoded_text + [pad_token_id] * (self.max_length - len(encoded_text)) for encoded_text in self.encoded_texts ] def __getitem__(self, index): encoded = self.encoded_texts[index] label = self.data.iloc[index]["Label"] return ( torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long) ) def __len__(self): return len(self.data) def _longest_encoded_length(self): max_length = 0 for encoded_text in self.encoded_texts: encoded_length = len(encoded_text) if encoded_length > max_length: max_length = encoded_length return max_length
train_dataset = SpamDataset( csv_file="train.csv", max_length=None, tokenizer=tokenizer ) print(train_dataset.max_length) """输出""" 120
代码输出120,表明最长的序列包含不超过120个标记,这是文本消息的常见长度。考虑到上下文长度限制,该模型最多可以处理1,024个标记的序列。如果我们的数据集中包含更长的文本,可以在创建时传递参数
max_length=1024
。接下来,我们将验证集和测试集填充至与最长训练序列相同的长度。需要注意的是,任何超过最长训练样本长度的验证集和测试集样本,都会在我们之前定义的
SpamDataset
中通过encoded_text[:self.max_length]
进行截断。这种截断是可选的;如果验证集和测试集中没有超过1,024个标记的序列,我们可以将max_length
设置为None
。val_dataset = SpamDataset( csv_file="validation.csv", max_length=train_dataset.max_length, tokenizer=tokenizer ) test_dataset = SpamDataset( csv_file="test.csv", max_length=train_dataset.max_length, tokenizer=tokenizer )
接下来,我们使用数据集实例化数据加载器,如下图,单个训练批次由表示为令牌 ID 的八个文本消息组成。每个文本消息由 120 个令牌 ID 组成。
from torch.utils.data import DataLoader num_workers = 0 batch_size = 8 torch.manual_seed(123) train_loader = DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, ) val_loader = DataLoader( dataset=val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False, ) test_loader = DataLoader( dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False, )
作为验证步骤,我们遍历数据加载器并确保每个批次包含8个训练示例,其中每个训练示例由120个令牌组成
print("Train loader:") for input_batch, target_batch in train_loader: pass print("Input batch dimensions:", input_batch.shape) print("Label batch dimensions", target_batch.shape) """输出""" Train loader: Input batch dimensions: torch.Size([8, 120]) Label batch dimensions torch.Size([8])
最后,让我们打印每个数据集中的批次总数
print(f"{len(train_loader)} training batches") print(f"{len(val_loader)} validation batches") print(f"{len(test_loader)} test batches") """输出""" 130 training batches 19 validation batches 38 test batches
到现在,我们已经准备好了数据。
原文地址:https://blog.csdn.net/hbkybkzw/article/details/145292077
免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!