Self-training for ASR with HuggingFace
Automatic speech recognition models can be fused with language models to boost performance. This opens a great opportunity to leverage self-training techniques and significantly improve results in low-resource settings.
- 1. Train model on labeled data
- 2. LM boosted inference
- 3. Tune LM hyperparameters on dev data
- 4. Generate pseudo-labeled data
- 5. Filter pseudo-labeled data
- 6. Re-train model using mixture of labeled and pseudo-labeled data
- Concluding thoughts
- References
Art by dalle-mini
The performance of automatic speech recognition (ASR) models has improved substantially in the past few years. One of the driving factors is the availability of large-scale pre-trained speech representation models. Recently the XLS-R model was released, pre-trained on a large cross-lingual dataset covering 128 different languages. When fine-tuned on a medium-sized labeled dataset, it delivers good speech recognition performance. Unfortunately, performance on low-resource languages can still be mediocre. To improve performance in the low-data regime, self-training is often employed. Noisy Student [ref] is a popular self-training method, and iterative pseudo-labeling is particularly well suited for ASR: the audio model can be fused with a language model to improve recognition quality. The language model not only provides a means to improve the quality of pseudo-labels, but its fusion scores can also be used to filter the pseudo-labeled data and select the most confident samples. In this notebook I employ the simple filtering heuristic introduced in Improved Noisy Student Training for Automatic Speech Recognition.
The procedure can be summarized as follows:
- Train a Wav2Vec2-XLS-R model on the available labeled data
- Train a language model
- Tune hyperparameters for filtering on dev data
- Generate pseudo-labels for unlabeled data
- Filter generated labels using LM scores
- Re-train the model using a mixture of labeled and pseudo-labeled data
Steps 3-6 can be repeated multiple times, relaxing the filtering threshold at each iteration; a rough sketch of the loop is shown below.
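In pseudo-code the loop looks roughly like this (the helper functions and variables are just placeholders for the steps above, not an actual API):
# Hypothetical sketch of the iterative self-training loop described above.
# train, train_language_model, pseudo_label and the data variables are
# placeholders, not functions defined in this notebook.
model = train(labeled_data)                      # 1. supervised training
lm = train_language_model(labeled_transcripts)   # 2. train KenLM on the transcripts
for i in range(num_iterations):
    threshold = thresholds[i]                    # relaxed at every iteration
    # 3. tune LM weights / filtering threshold on dev data
    # 4. generate LM-boosted pseudo-labels for the untranscribed audio
    candidates = pseudo_label(model, lm, unlabeled_audio)
    # 5. keep only the samples the LM fusion score is confident about
    pseudo_labeled = [s for s in candidates if s["score"] > threshold]
    # 6. re-train on the mixture of labeled and pseudo-labeled data
    model = train(labeled_data + pseudo_labeled)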
In this notebook I'll explore the application of the Noisy Student training procedure to speech recognition for the Armenian language. The Mozilla Common Voice dataset provides a limited amount of high-quality audio-transcript pairs. The VoxLingua107 dataset has a fair amount of untranscribed audio for 107 languages, including 66 hours of Armenian speech.
I'm using the pyctcdecode library to integrate an LM into the ASR system. For a detailed introduction to training a KenLM model and using it with HuggingFace, refer to this excellent blog post.
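For context, here is a minimal sketch of how the LM-boosted decoder is assembled, assuming a KenLM model has already been trained and saved; the path "5gram.bin" and the alpha/beta values are placeholders, and tokenizer and feature_extractor are created further down in this notebook:
# Sketch: wrap the CTC vocabulary and a KenLM model into a pyctcdecode decoder,
# then into a Wav2Vec2ProcessorWithLM. "5gram.bin", alpha and beta are placeholders.
from pyctcdecode import build_ctcdecoder

vocab = tokenizer.get_vocab()
sorted_labels = [tok for tok, _ in sorted(vocab.items(), key=lambda item: item[1])]
decoder = build_ctcdecoder(
    labels=sorted_labels,
    kenlm_model_path="5gram.bin",  # path to the trained KenLM model
    alpha=0.5,                     # LM weight, tuned on dev data
    beta=1.5,                      # word insertion bonus, tuned on dev data
)
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    decoder=decoder,
)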
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    Wav2Vec2ProcessorWithLM,
    Wav2Vec2Processor,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2ForCTC,
    Trainer,
    TrainingArguments,
)
from datasets import (
    load_dataset,
    load_metric,
    Audio,
    concatenate_datasets,
    DatasetDict,
    Dataset,
    load_from_disk,
)
import bitsandbytes as bnb
I'm using WandB for experiment tracking. Let's set some environment variables to prepare the experiment.
import wandb
%env WANDB_ENTITY = arampacha
wandb_entity = os.environ["WANDB_ENTITY"]
%env WANDB_PROJECT = xlsr-hy
wandb_project = os.environ["WANDB_PROJECT"]
%env WANDB_LOG_MODEL = false
%env WANDB_WATCH = false
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # labels are optional, so the same collator also works for unlabeled audio
        if "labels" in features[0].keys():
            label_features = [{"input_ids": feature["labels"]} for feature in features]
            with self.processor.as_target_processor():
                labels_batch = self.processor.pad(
                    label_features,
                    padding=self.padding,
                    return_tensors="pt",
                )
            # replace padding with -100 so padded positions are ignored by the loss
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
            batch["labels"] = labels
        return batch
First, we need a model trained on labeled data. The Common Voice dataset has 738 training samples for Armenian, and I haven't found any additional labeled data publicly available. The model pre-trained on Common Voice is available here. It achieves a 22.6 word error rate (WER) with LM-boosted decoding.
Notice that large ASR models like those of the Wav2Vec2-XLS-R family can learn grammar to some extent. At some point during training the model can start to overfit to the training-set vocabulary: the validation loss stops decreasing while the validation WER still improves. In my experience, validation loss is a better indicator of performance after LM fusion.
model_dir = "wav2vec2-xls-r-300m-hy-ns"
lang_id = "hy-AM"
repo_name = "wav2vec2-xls-r-300m-hy-ns"
@torch.no_grad()
def predict(model, dataset, bs=32, device="cpu"):
    model.eval()
    model.to(device)
    loader = DataLoader(
        dataset, batch_size=bs, collate_fn=data_collator, shuffle=False, drop_last=False, num_workers=4
    )
    all_logits = []
    for batch in tqdm(loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = model(**batch).logits.cpu()
        all_logits.append(logits)
    # pad per-batch logits along the time dimension so they can be concatenated
    lens = [logits.shape[1] for logits in all_logits]
    max_len = max(lens)
    all_logits = [F.pad(logits, (0, 0, 0, max_len - l), value=-100.) for logits, l in zip(all_logits, lens)]
    return torch.cat(all_logits)
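Once the model is fine-tuned and the test set is prepared (both happen below), the logits returned by predict can be decoded with LM fusion. A rough usage sketch, assuming the processor_with_lm from the earlier sketch; whether per-sample LM scores are returned depends on the installed transformers version:
# Sketch: LM-boosted decoding of the logits produced by `predict`.
logits = predict(model, common_voice_test, bs=32, device="cuda")
decoded = processor_with_lm.batch_decode(logits.numpy())
print(decoded.text[:3])  # LM-boosted transcriptions
# decoded.lm_score (if available in your transformers version) gives per-sample
# fusion scores that can later be used to rank pseudo-labels by confidence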
common_voice_train = load_dataset("mozilla-foundation/common_voice_8_0", lang_id, split="train+validation", use_auth_token=True)
common_voice_test = load_dataset("mozilla-foundation/common_voice_8_0", lang_id, split="test", use_auth_token=True)
common_voice_train = common_voice_train.remove_columns(["accent", "age", "gender", "client_id", "down_votes", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "gender", "client_id", "down_votes", "locale", "segment", "up_votes"])
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
print(len(common_voice_train), len(common_voice_test))
def extract_all_chars(batch):
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}
import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'«»\(\)։՝՞՛՚]'
def normalize_text(batch):
    text = batch["sentence"]
    text = re.sub(chars_to_remove_regex, '', text.lower()) + " "
    return {"sentence": text}
common_voice_train = common_voice_train.map(normalize_text)
common_voice_test = common_voice_test.map(normalize_text)
Now the transcripts should be ready. Let's check out some samples and verify that the train and test split vocabularies match.
test_transcripts = common_voice_test["sentence"]
test_transcripts[:5]
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
all_chars_train = sorted(vocab_train["vocab"][0])
all_chars_test = sorted(vocab_test["vocab"][0])
print("".join(all_chars_train))
print("".join(all_chars_test))
The Armenian alphabet consists of 39 distinct characters. The vocabulary also includes whitespace and two special tokens, [PAD] and [UNK].
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_train["vocab"][0]))}
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
if not os.path.isfile(f"{model_dir}/vocab.json"):
    import json
    with open(f"{model_dir}/vocab.json", "w") as f:
        json.dump(vocab_dict, f)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_dir, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
def prepare_dataset(batch):
    audio = batch["audio"]
    # batched output is "un-batched"
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["length"] = len(batch["input_values"])
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)
To measure the model's performance, the word error rate (WER) and character error rate (CER) metrics are used. Both metrics are available through the datasets library.
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")
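As a quick sanity check: WER counts word-level substitutions, deletions, and insertions relative to the number of reference words. A toy example (illustration only):
# toy example: one deleted word out of three reference words -> WER = 1/3
print(wer_metric.compute(predictions=["hello world"], references=["hello there world"]))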
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
We start from the Wav2Vec2-XLS-R-300m pretrained checkpoint by Meta AI.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    attention_dropout=0.0,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.75,
    mask_feature_prob=0.25,
    mask_feature_length=64,
    layerdrop=0.05,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
model.freeze_feature_encoder();
For training I'm using the bitsandbytes 8-bit optimizer. It's designed to reduce memory usage and makes it possible to fit a larger batch size. A tri-stage LR schedule was used for fine-tuning by the authors of the original paper. I train for a total of 1600 steps, which is enough for the validation loss to reach its minimum on a dataset of this size.
from torch.optim.lr_scheduler import LambdaLR
def get_tri_stage_schedule(
    optimizer, num_training_steps, ratios=[0.1, 0.4, 0.5], num_warmup_steps=None, num_hold_steps=None, start_ratio=0.01, end_ratio=0.05
):
    assert (num_warmup_steps is None) == (num_hold_steps is None)
    if num_warmup_steps is None:
        num_warmup_steps = int(ratios[0] * num_training_steps)
        num_hold_steps = int(ratios[1] * num_training_steps)
    start_decay_step = num_warmup_steps + num_hold_steps
    a_w, b_w = (1 - start_ratio) / num_warmup_steps, start_ratio
    num_decay_steps = num_training_steps - start_decay_step
    a_d, b_d = (end_ratio - 1) / num_decay_steps, 1.

    def lr_lambda(current_step):
        # linear warmup -> constant hold -> linear decay (floored at end_ratio)
        if current_step < num_warmup_steps:
            return a_w * float(current_step) + b_w
        if current_step < start_decay_step:
            return 1.
        return max(end_ratio, a_d * float(current_step - start_decay_step) + b_d)

    return LambdaLR(optimizer, lr_lambda)
num_training_steps = 1600
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=8e-5, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.)
scheduler = get_tri_stage_schedule(optimizer, num_training_steps)
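To double-check the warmup/hold/decay shape, the schedule can be plotted with a throwaway optimizer (optional, and it leaves the scheduler attached to the real optimizer untouched):
# Optional sanity check: visualize the tri-stage schedule using a dummy
# optimizer so the real optimizer/scheduler state is untouched.
_opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=8e-5)
_sched = get_tri_stage_schedule(_opt, num_training_steps)
lrs = []
for _ in range(num_training_steps):
    lrs.append(_sched.get_last_lr()[0])
    _opt.step()
    _sched.step()
plt.plot(lrs)
plt.xlabel("step")
plt.ylabel("learning rate")
plt.show()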
training_args = TrainingArguments(
    output_dir=model_dir,
    group_by_length=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=4,
    dataloader_num_workers=8,
    evaluation_strategy="steps",
    max_steps=num_training_steps,
    gradient_checkpointing=True,
    fp16=True,
    save_steps=200,
    eval_steps=200,
    logging_steps=200,
    learning_rate=8e-5,
    save_total_limit=4,
    push_to_hub=False,
    run_name="xlsr-300m-hy-demo-1",
    report_to="wandb",
    load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
    optimizers=(optimizer, scheduler),
)
output = trainer.train()
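Since this fine-tuned checkpoint will be reused to generate pseudo-labels, it's worth saving the model and the processor explicitly; a minimal sketch:
# Save the fine-tuned model and the processor so they can be reloaded
# for LM-boosted inference and pseudo-labeling later.
trainer.save_model(model_dir)
processor.save_pretrained(model_dir)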