自学内容网 自学内容网

语音实战(一)中文语音识别

一、下载文件


from datasets import load_dataset

name = 'mozilla-foundation/common_voice_16_0'
load_dataset(name, 'zh-CN', split='train').save_to_disk('dataset/' + name)
from transformers import Wav2Vec2BertForCTC

name = 'lansinuote/Chinese_Speech_to_Text_CTC'
Wav2Vec2BertForCTC.from_pretrained(name).save_pretrained('model/' + name)

这里是使用 Hugging Face 的 datasetstransformers 库来加载和保存中文语音识别相关的数据集和预训练模型

①Mozilla Foundation 提供的 Common Voice 数据集,包含了多种语言的语音数据,旨在促进语音识别技术的发展

②av2Vec2BertForCTC 是结合了 Wav2Vec2 和 BERT 的一个模型,适用于语音到文本(Speech-to-Text)任务。这里使用的是 CTC(Connectionist Temporal Classification)损失函数,常用于语音识别任务。

二、语音识别任务

2.1预处理工具

import torch
import random

from transformers import Wav2Vec2CTCTokenizer, SeamlessM4TFeatureExtractor, Wav2Vec2BertProcessor

#文字编码工具
#使用processor文件夹下的vocab.json构建tokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('./processor',
                                                 bos_token='[CLS]',
                                                 eos_token='[SEP]',
                                                 unk_token='[UNK]',
                                                 pad_token='[PAD]')

#声音信号编码工具
feature_extractor = SeamlessM4TFeatureExtractor(sampling_rate=16000,
                                                padding_value=0.0)

#组合上面两个工具
processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor,
                                  tokenizer=tokenizer)

del tokenizer
del feature_extractor

processor
  • 使用 Wav2Vec2CTCTokenizer 创建了一个用于文本标记化的工具。
  • 使用 SeamlessM4TFeatureExtractor 创建了一个用于音频信号特征提取的工具。
  • 将这两个工具结合成 Wav2Vec2BertProcessor,用于后续的语音识别任务处理。
  • 通过删除中间对象来优化内存使用。

2.2 预处理

#测试
data = processor(text=['测试文字1', '测试测试文字2'],
                 audio=[torch.randn(8000).numpy(),
                        torch.randn(16000).numpy()],
                 sampling_rate=16000,
                 padding=True,
                 return_tensors='pt')

#其实分开用更方便一点
data = processor.tokenizer(['测试文字1', '测试测试文字2'],
                           padding=True,
                           truncation=True,
                           max_length=35 + 2,
                           return_tensors='pt')

data = processor.feature_extractor(
    [torch.randn(8000).numpy(),
     torch.randn(16000).numpy()],
    sampling_rate=16000,
    padding=True,
    truncation=True,
    max_length=900,
    padding_value=0.0,
    return_tensors='pt')

for k, v in data.items():
    print(k, v.shape, v.dtype, v)

使用 processor 预处理文本和音频数据:

  • processor:结合了文本标记化器和音频特征提取器的工具,可以同时处理文本和音频数据。
  • tokenizerfeature_extractor:可以单独使用,分别处理文本和音频。
  • return_tensors='pt':输出的是 PyTorch 张量,方便后续在模型中使用。
  • 打印输出:帮助验证每个预处理后的数据的形状和类型。

2.3 处理 Mozilla Common Voice 数据集

from datasets import load_from_disk, Audio

dataset = load_from_disk('dataset/mozilla-foundation/common_voice_16_0')

dataset = dataset.remove_columns([
    'accent', 'age', 'client_id', 'down_votes', 'gender', 'locale', 'segment',
    'up_votes', 'path', 'variant'
])
dataset = dataset.rename_columns({'sentence': 'text'})
dataset = dataset.cast_column('audio', Audio(sampling_rate=16000))


def f(data):
    lens_audio = len(data['audio']['array']) / 16000
    lens_text = len(data['text'])
    return 1 <= lens_audio <= 9 and 2 <= lens_text <= 35


dataset = dataset.filter(f)

dataset, dataset[3]
  1. 加载存储在磁盘上的 Mozilla Common Voice 数据集。
  2. 删除多余的列(如 accentage 等)。
  3. 重命名列,使其更有意义(sentence 重命名为 text)。
  4. 将音频数据转换为 Audio 格式并设置采样率为 16000 Hz。
  5. 根据音频时长(1 到 9 秒)和文本长度(2 到 35 个字符)进行过滤,去掉不符合条件的数据。
  6. 最后,返回处理后的数据集和其中的第 4 个样本(即 dataset[3])。

2.4 播放音频并显示对应的文本

def show(data):
    from IPython.display import Audio, display
    display(Audio(data=data, rate=16000))


show(dataset[3]['audio']['array'])
dataset[3]['text']

  • 这段代码可以在 Jupyter Notebook 中播放指定样本的音频,并显示它对应的文本。
    • show() 函数用于播放音频数据。
    • dataset[3]['text'] 返回该样本的文本内容。

2.5 音频和文本数据的训练或推理

def f(data):
    text = [i['text'] for i in data]
    text = processor.tokenizer(text,
                               padding=True,
                               truncation=True,
                               max_length=35 + 2,
                               return_tensors='pt').to('cuda')

    audio = [i['audio']['array'] for i in data]
    audio = processor.feature_extractor(audio,
                                        sampling_rate=16000,
                                        padding=True,
                                        truncation=True,
                                        max_length=900,
                                        padding_value=0.0,
                                        return_tensors='pt').to('cuda')

    return text.input_ids, audio.input_features, audio.attention_mask


loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4,
                                     collate_fn=f,
                                     drop_last=True,
                                     shuffle=True)

len(loader), next(iter(loader))
  • 该代码定义了如何通过自定义的 collate_fn 函数将批次中的音频和文本数据进行处理,并将其转移到 GPU。
  • DataLoader 将批量加载数据,使用 collate_fn=f 来处理每个批次,将其转换为模型可以接受的格式。
  • len(loader) 返回批次数,next(iter(loader)) 返回第一个批次的数据。

2.6 预测字符序列

class Wav2Vec2BertForCTC(torch.nn.Module):

    def __init__(self):
        super().__init__()

        from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig
        config = Wav2Vec2BertConfig.from_pretrained(
            'model/lansinuote/Chinese_Speech_to_Text_CTC')

        self.wav2vec2_bert = Wav2Vec2BertModel(config)
        self.dropout = torch.nn.Dropout(0.1)
        self.lm_head = torch.nn.Linear(1024, processor.tokenizer.vocab_size)

        from transformers import Wav2Vec2BertForCTC
        parameters = Wav2Vec2BertForCTC.from_pretrained(
            'model/lansinuote/Chinese_Speech_to_Text_CTC')
        self.wav2vec2_bert.load_state_dict(
            parameters.wav2vec2_bert.state_dict())
        #丢弃部分参数,验证训练过程是有效的
        #self.lm_head.load_state_dict(parameters.lm_head.state_dict())
        del parameters

        self.train()
        self.to('cuda')

    def forward(self, input_features, attention_mask):
        last_hidden_state = self.wav2vec2_bert(
            input_features, attention_mask=attention_mask).last_hidden_state

        last_hidden_state = self.dropout(last_hidden_state)

        return self.lm_head(last_hidden_state)


model = Wav2Vec2BertForCTC()

with torch.no_grad():
    input_features = torch.randn(4, 377, 160).to('cuda')
    attention_mask = torch.ones(4, 377).long().to('cuda')
    print(model(input_features, attention_mask).shape)
  • 这个模型结合了 Wav2Vec2(用于音频特征提取)和 BERT(用于上下文感知的表示),并通过 CTC 损失函数进行训练,旨在解决语音到文本的任务。
  • Wav2Vec2BertForCTC 模型首先对音频数据进行特征提取,然后将特征通过 BERT 进行上下文建模,最后通过线性层映射到词汇表空间。模型的输出将与实际标签进行比较,计算 CTC 损失。
  • 该模型加载了预训练的权重,并允许在已有权重的基础上进行微调。

2.7 训练过程

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = torch.nn.CTCLoss(blank=processor.tokenizer.pad_token_id,
                             reduction='mean',
                             zero_infinity=False)

for epoch in range(1):
    for i, (input_ids, input_features, attention_mask) in enumerate(loader):
        logits = model(input_features, attention_mask)

        log_probs = logits.log_softmax(dim=2).transpose(0, 1)
        input_lengths = (attention_mask.sum(1) / 2).ceil().long()
        input_ids_mask = input_ids != processor.tokenizer.pad_token_id
        
        loss = criterion(log_probs=log_probs,
                         targets=input_ids[input_ids_mask],
                         input_lengths=input_lengths,
                         target_lengths=input_ids_mask.sum(-1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 500 == 0:
            print(epoch, i, loss.item())
            print(processor.tokenizer.decode(input_ids[0]))
            print(processor.tokenizer.decode(logits[0].argmax(1)))

这段代码实现了一个简单的训练过程,用于训练 Wav2Vec2BertForCTC 模型进行语音到文本的任务。训练过程中使用 CTC 损失 来优化模型的参数。关键步骤包括:

  • 使用 AdamW 优化器更新参数。
  • 使用 CTC 损失计算模型的预测与真实文本标签之间的差距。
  • 每 500 次迭代输出训练过程中的损失值及预测结果。

原文地址:https://blog.csdn.net/qq_51292909/article/details/145038187

免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!