自学内容网 自学内容网

PyTorch用map函数进行tokenize之后,为什么需要remove_columns?

遇到的问题

下面以GLUE数据集中的tokenize为例,下面是典型的处理方法,经过一个map函数,使用lambda表达式对数据进行tokenize。后面会跟上remove_columns这句话,这是为什么呢?

if(task_name=="sst2"):
        tokenized_datasets = raw_dataset.map(lambda sample: sst2_tokenization(sample, tokenizer))
        tokenized_datasets = tokenized_datasets.remove_columns(['sentence', "label", "idx"])
    elif(task_name=="cola"):
        tokenized_datasets = raw_dataset.map(lambda sample: cola_tokenization(sample, tokenizer))
        tokenized_datasets = tokenized_datasets.remove_columns(['sentence', "label", "idx"])
    elif(task_name=="mnli"):
        tokenized_datasets = raw_dataset.map(lambda sample: mnli_tokenization(sample, tokenizer))
        tokenized_datasets = tokenized_datasets.remove_columns(["hypothesis", 'premise', "label", "idx"])
    elif(task_name=="mrpc"):
        tokenized_datasets = raw_dataset.map(lambda sample: mrpc_tokenization(sample, tokenizer))
        tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", 'sentence2', "label", "idx"])

上面代码使用的tokenize函数,比如下面两个sst2_tokenization和cola_tokenization

def sst2_tokenization(example, tokenizer):
    prefix = "Classification task:sst2. Choose a label from list:[0, 1] for this context: "
    input_str = example["sentence"]
    input_str = prefix + input_str
    input_str += " Among them, 0 represents negative, 1 represents positive. Answer:"
    label_dict = {"-1":"-1", "0":"0", "1":"1"}
    label = example["label"]
    target_str = label_dict[str(label)]

    tokenized_data = tokenizer(input_str, return_tensors="pt", truncation=True, max_length=256, padding = "max_length")
    input_ids = tokenized_data.input_ids.squeeze(0)
    attention_mask = tokenized_data.attention_mask.squeeze(0)
    target_ids = tokenizer(target_str, return_tensors="pt", truncation=True, max_length=3, padding = "max_length").input_ids
    target_ids = target_ids.squeeze(0)

    output_batch = {
                "input_ids": input_ids,
                "labels": target_ids,
                "attention_mask": attention_mask
            }
    
    return output_batch

def cola_tokenization(example, tokenizer):
    prefix = "Classification task:cola. Choose a label from list:[0, 1] for this context: "
    input_str = example["sentence"]
    input_str = prefix + input_str
    input_str += " Among them, 0 represents unacceptable, 1 represents acceptable. Answer:"
    label_dict = {"-1":"-1", "0":"0", "1":"1"}
    label = example["label"]
    target_str = label_dict[str(label)]

    tokenized_data = tokenizer(input_str, return_tensors="pt", truncation=True, max_length=256, padding = "max_length")
    input_ids = tokenized_data.input_ids.squeeze(0)
    attention_mask = tokenized_data.attention_mask.squeeze(0)
    target_ids = tokenizer(target_str, return_tensors="pt", truncation=True, max_length=3, padding = "max_length").input_ids
    target_ids = target_ids.squeeze(0)

    output_batch = {
                "input_ids": input_ids,
                "labels": target_ids,
                "attention_mask": attention_mask
            }
return output_batch

参考的代码来自:https://github.com/mlwu22/TPT/blob/main/data_process.py

解答

这是我问GPT的问题:

经过cola_tokenization之后,不是结果只有 “input_ids”: input_ids, # 经过编码的输入句子
“labels”: target_ids, # 经过编码的目标标签
“attention_mask”: attention_mask # 注意力掩码吗,为什么还有’sentence’, “label”, "idx"这几个字段?

下面是回答:

你理解得没错,经过 cola_tokenization 之后,处理过的数据应该只包含 input_idslabelsattention_mask 这些字段。那么为什么在 tokenized_datasets 中还会存在 sentencelabelidx 这几个字段呢?这是因为 map 函数默认会保留原始数据集中的所有列,除非你显式地删除它们。

让我们一步一步来看:

1. 原始数据集的结构

raw_dataset 原始数据集包含以下字段:

  • sentence: 句子的文本内容
  • label: 标签,表示句子是否符合语法
  • idx: 每个样本的唯一索引,标识数据集中的位置

2. map 函数的默认行为

当你使用 map 函数时,它会应用你定义的函数(这里是 cola_tokenization)来处理每个样本。cola_tokenization 返回的是一个新的字典,包含 input_idslabelsattention_mask但是map 函数的默认行为是将这些新生成的字段添加到原始数据集的每个样本上,同时保留原有的字段,除非你告诉它删除这些字段。

因此,经过 cola_tokenization 处理后,数据集中的每个样本可能包含以下字段:

  • 原始字段:
    • sentence: 原始的句子文本
    • label: 标签
    • idx: 样本索引
  • 新生成的字段:
    • input_ids: 经过 tokenizer 处理的输入
    • attention_mask: 输入的注意力掩码
    • labels: 模型的目标标签

3. remove_columns 的作用

你要做的是删除这些原始字段,因为它们对模型训练没有用了。通过 remove_columns(['sentence', "label", "idx"]),你明确地删除了 sentencelabelidx 这几个原始字段,这样数据集中就只剩下模型需要的 input_idsattention_masklabels

总结

map 函数的默认行为是保留原始数据集的字段,所以即使你已经添加了 input_idsattention_masklabels,原来的 sentencelabelidx 依然存在。通过 remove_columns,你才能将这些原始字段删除,避免它们在训练中产生冗余信息。

后记

2024年9月26日写于上海,基于GPT4o大模型生成。


原文地址:https://blog.csdn.net/shizheng_Li/article/details/142571875

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!