ESPnet Speech Enhancement Demonstration¶
This notebook provides a demonstration of speech enhancement and separation using ESPnet2-SE.
Presenters:
Shinji Watanabe (shinjiw@cmu.edu)
Chenda Li (lichenda1996@sjtu.edu.cn)
Jing Shi (shijing2014@ia.ac.cn)
Wangyou Zhang (wyz-97@sjtu.edu.cn)
Yen-Ju Lu (neil.lu@citi.sinica.edu.tw)
This notebook is created by: Chenda Li ([@LiChenda](https://github.com/LiChenda)) and Wangyou Zhang ([@Emrys365](https://github.com/Emrys365))
Contents¶
Tutorials on the Basic Usage
Install
Speech Enhancement with Pretrained Models
We support various interfaces, e.g. the Python API, the Hugging Face API, and portable speech enhancement scripts for other tasks.
2.1 Single-channel Enhancement (CHiME-4)
2.2 Enhance Your Own Recordings
2.3 Multi-channel Enhancement (CHiME-4)
Speech Separation with Pretrained Models
3.1 Model Selection
3.2 Separate Speech Mixture
Evaluate Separated Speech with the Pretrained ASR Model
Tutorials for Adding New Recipe and Contributing to ESPnet-SE Project
Creating a New Recipe
Implementing a New Speech Enhancement/Separation Model
(1) Tutorials on the Basic Usage¶
Install¶
[ ]:
%pip install -q espnet==0.10.1
%pip install -q espnet_model_zoo
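If the installation succeeded, the following quick check (a minimal sketch; the printed version depends on your environment) should run without errors:

# Sanity check: confirm that espnet and espnet_model_zoo are importable.
import espnet
import espnet_model_zoo

print("espnet version:", espnet.__version__)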
Speech Enhancement with Pretrained Models¶
Single-Channel Enhancement, the CHiME example¶
[ ]:
# Download one utterance from real noisy speech of CHiME4
!gdown --id 1SmrN5NFSg6JuQSs2sfy3ehD8OIcqK6wS -O /content/M05_440C0213_PED_REAL.wav
import os
import soundfile
from IPython.display import display, Audio
mixwav_mc, sr = soundfile.read("/content/M05_440C0213_PED_REAL.wav")
# mixwav_mc.shape: (num_samples, num_channels)
mixwav_sc = mixwav_mc[:,4]
display(Audio(mixwav_mc.T, rate=sr))
Download and load the pretrained Conv-TasNet¶
[ ]:
!gdown --id 17DMWdw84wF3fz3t7ia1zssdzhkpVQGZm -O /content/chime_tasnet_singlechannel.zip
!unzip /content/chime_tasnet_singlechannel.zip -d /content/enh_model_sc
[ ]:
# Load the model
# If you encounter the error "No module named 'espnet2'", please re-run the first cell (this may be a Colab bug).
import sys
import soundfile
from espnet2.bin.enh_inference import SeparateSpeech

separate_speech = {}
# For models downloaded from Google Drive, you can use the following script:
enh_model_sc = SeparateSpeech(
    train_config="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/config.yaml",
    model_file="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/5epoch.pth",
    # for segment-wise processing of long speech
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=4,
    normalize_output_wav=True,
    device="cuda:0",
)
Enhance the single-channel real noisy speech in CHiME4¶
[ ]:
# play the enhanced single-channel speech
wave = enh_model_sc(mixwav_sc[None, ...], sr)
print("Input real noisy speech", flush=True)
display(Audio(mixwav_sc, rate=sr))
print("Enhanced speech", flush=True)
display(Audio(wave[0].squeeze(), rate=sr))
Enhance your own recordings¶
[ ]:
from google.colab import files
from IPython.display import display, Audio
import soundfile
uploaded = files.upload()
for file_name in uploaded.keys():
    speech, rate = soundfile.read(file_name)
    assert rate == sr, "mismatch in sampling rate"
    wave = enh_model_sc(speech[None, ...], sr)
    print(f"Your input speech {file_name}", flush=True)
    display(Audio(speech, rate=sr))
    print(f"Enhanced speech for {file_name}", flush=True)
    display(Audio(wave[0].squeeze(), rate=sr))
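The model above was trained on 16 kHz data, so the assertion fails if your recording uses a different sampling rate. A minimal resampling sketch (using scipy.signal.resample_poly, which is not otherwise used in this notebook) could be applied before enhancement:

# Sketch: resample an uploaded recording to the model's sampling rate before enhancement.
from math import gcd
from scipy.signal import resample_poly

speech, rate = soundfile.read(file_name)
if rate != sr:
    # resample_poly takes integer up/down factors, e.g. 44100 Hz -> 16000 Hz is up=160, down=441
    g = gcd(sr, rate)
    speech = resample_poly(speech, sr // g, rate // g)
    rate = sr
wave = enh_model_sc(speech[None, ...], sr)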
Multi-Channel Enhancement¶
Download and load the pretrained MVDR neural beamformer¶
[ ]:
# Download the pretrained enhancement model
!gdown --id 1FohDfBlOa7ipc9v2luY-QIFQ_GJ1iW_i -O /content/mvdr_beamformer_16k_se_raw_valid.zip
!unzip /content/mvdr_beamformer_16k_se_raw_valid.zip -d /content/enh_model_mc
[ ]:
# Load the model
# If you encounter the error "No module named 'espnet2'", please re-run the first cell (this may be a Colab bug).
import sys
import soundfile
from espnet2.bin.enh_inference import SeparateSpeech

separate_speech = {}
# For models downloaded from Google Drive, you can use the following script:
enh_model_mc = SeparateSpeech(
    train_config="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/config.yaml",
    model_file="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/11epoch.pth",
    # for segment-wise processing of long speech
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=4,
    normalize_output_wav=True,
    device="cuda:0",
)
Enhance the multi-channel real noisy speech in CHiME4¶
[ ]:
wave = enh_model_mc(mixwav_mc[None, ...], sr)
print("Input real noisy speech", flush=True)
display(Audio(mixwav_mc.T, rate=sr))
print("Enhanced speech", flush=True)
display(Audio(wave[0].squeeze(), rate=sr))
Portable speech enhancement scripts for other tasks¶
For an ESPnet ASR or TTS dataset like the one below:
data
`-- et05_real_isolated_6ch_track
|-- spk2utt
|-- text
|-- utt2spk
|-- utt2uniq
`-- wav.scp
Run the following script to create an enhanced dataset:
scripts/utils/enhance_dataset.sh \
--spk_num 1 \
--gpu_inference true \
--inference_nj 4 \
--fs 16k \
--id_prefix "" \
dump/raw/et05_real_isolated_6ch_track \
data/et05_real_isolated_6ch_track_enh \
exp/enh_train_enh_beamformer_mvdr_raw/valid.loss.best.pth
The above script will generate a new directory data/et05_real_isolated_6ch_track_enh:
data
`-- et05_real_isolated_6ch_track_enh
|-- spk2utt
|-- text
|-- utt2spk
|-- utt2uniq
|-- wav.scp
`-- wavs/
where wav.scp contains paths to the enhanced audio files (stored in wavs/).
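Each line of a Kaldi-style wav.scp maps an utterance ID to an audio path, so the enhanced directory can be consumed by downstream ASR/TTS recipes as-is. As a small sketch (assuming the enhancement script above has been run), a single entry can be inspected like this:

# Sketch: read the first wav.scp entry of the enhanced dataset and load its audio.
import soundfile

with open("data/et05_real_isolated_6ch_track_enh/wav.scp") as f:
    utt_id, wav_path = f.readline().strip().split(maxsplit=1)
enhanced, fs = soundfile.read(wav_path)
print(utt_id, enhanced.shape, fs)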
Speech Separation¶
Model Selection¶
In this demonstration, we will show different speech separation models on wsj0_2mix.
The pretrained models can be downloaded from a direct URL, or from Zenodo and Hugging Face with a model ID.
[ ]:
#@title Choose Speech Separation model { run: "auto" }
fs = 8000 #@param {type:"integer"}
tag = "espnet/Chenda_Li_wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave" #@param ["Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave", "Chenda Li/wsj0_2mix_enh_train_enh_rnn_tf_raw_valid.si_snr.ave", "https://zenodo.org/record/4688000/files/enh_train_enh_dprnn_tasnet_raw_valid.si_snr.ave.zip", "espnet/Chenda_Li_wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave"]
[ ]:
# For models hosted on Zenodo or Hugging Face, you can use the following Python script:
import sys
import soundfile
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.enh_inference import SeparateSpeech

d = ModelDownloader()
cfg = d.download_and_unpack(tag)
separate_speech = SeparateSpeech(
    enh_train_config=cfg["train_config"],
    enh_model_file=cfg["model_file"],
    # for segment-wise processing of long speech
    segment_size=2.4,
    hop_size=0.8,
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=None,
    normalize_output_wav=True,
    device="cuda:0",
)
Separate Speech Mixture¶
Separate the example in wsj0_2mix testing set¶
[ ]:
!gdown --id 1ZCUkd_Lb7pO2rpPr4FqYdtJBZ7JMiInx -O /content/447c020t_1.2106_422a0112_-1.2106.wav
import os
import soundfile
from IPython.display import display, Audio
mixwav, sr = soundfile.read("447c020t_1.2106_422a0112_-1.2106.wav")
waves_wsj = separate_speech(mixwav[None, ...], fs=sr)
print("Input mixture", flush=True)
display(Audio(mixwav, rate=sr))
print(f"========= Separated speech with model {tag} =========", flush=True)
print("Separated spk1", flush=True)
display(Audio(waves_wsj[0].squeeze(), rate=sr))
print("Separated spk2", flush=True)
display(Audio(waves_wsj[1].squeeze(), rate=sr))
Separate your own recordings¶
[ ]:
from google.colab import files
from IPython.display import display, Audio
import soundfile
uploaded = files.upload()
for file_name in uploaded.keys():
    mixwav_yours, rate = soundfile.read(file_name)
    assert rate == sr, "mismatch in sampling rate"
    waves_yours = separate_speech(mixwav_yours[None, ...], fs=sr)
    print("Input mixture", flush=True)
    display(Audio(mixwav_yours, rate=sr))
    print(f"========= Separated speech with model {tag} =========", flush=True)
    print("Separated spk1", flush=True)
    display(Audio(waves_yours[0].squeeze(), rate=sr))
    print("Separated spk2", flush=True)
    display(Audio(waves_yours[1].squeeze(), rate=sr))
Show spectrograms of separated speech¶
[ ]:
import matplotlib.pyplot as plt
import torch
from torch_complex.tensor import ComplexTensor
from espnet.asr.asr_utils import plot_spectrogram
from espnet2.layers.stft import Stft
stft = Stft(
    n_fft=512,
    win_length=None,
    hop_length=128,
    window="hann",
)
ilens = torch.LongTensor([len(mixwav)])
# specs: (T, F)
spec_mix = ComplexTensor(
    *torch.unbind(
        stft(torch.as_tensor(mixwav).unsqueeze(0), ilens)[0].squeeze(),
        dim=-1
    )
)
spec_sep1 = ComplexTensor(
    *torch.unbind(
        stft(torch.as_tensor(waves_wsj[0]), ilens)[0].squeeze(),
        dim=-1
    )
)
spec_sep2 = ComplexTensor(
    *torch.unbind(
        stft(torch.as_tensor(waves_wsj[1]), ilens)[0].squeeze(),
        dim=-1
    )
)
# freqs = torch.linspace(0, sr / 2, spec_mix.shape[1])
# frames = torch.linspace(0, len(mixwav) / sr, spec_mix.shape[0])
samples = torch.linspace(0, len(mixwav) / sr, len(mixwav))
plt.figure(figsize=(24, 12))
plt.subplot(3, 2, 1)
plt.title('Mixture Spectrogram')
plot_spectrogram(
    plt, abs(spec_mix).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
)
plt.subplot(3, 2, 2)
plt.title('Mixture Waveform')
plt.plot(samples, mixwav)
plt.xlim(0, len(mixwav) / sr)
plt.subplot(3, 2, 3)
plt.title('Separated Spectrogram (spk1)')
plot_spectrogram(
    plt, abs(spec_sep1).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
)
plt.subplot(3, 2, 4)
plt.title('Separated Waveform (spk1)')
plt.plot(samples, waves_wsj[0].squeeze())
plt.xlim(0, len(mixwav) / sr)
plt.subplot(3, 2, 5)
plt.title('Separated Spectrogram (spk2)')
plot_spectrogram(
    plt, abs(spec_sep2).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
)
plt.subplot(3, 2, 6)
plt.title('Separated Waveform (spk2)')
plt.plot(samples, waves_wsj[1].squeeze())
plt.xlim(0, len(mixwav) / sr)
plt.xlabel("Time (s)")
plt.show()
Evaluate separated speech with the pretrained ASR model¶
The ground truths are:
text_1: SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR
text_2: THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK
(The speech recognition step may take a while.)
[ ]:
%pip install -q https://github.com/kpu/kenlm/archive/master.zip  # the ASR model needs kenlm
[ ]:
import espnet_model_zoo
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text
wsj_8k_model_url = "https://zenodo.org/record/4012264/files/asr_train_asr_transformer_raw_char_1gpu_valid.acc.ave.zip?download=1"

d = ModelDownloader()
speech2text = Speech2Text(
    **d.download_and_unpack(wsj_8k_model_url),
    device="cuda:0",
)
text_est = [None, None]
text_est[0], *_ = speech2text(waves_wsj[0].squeeze())[0]
text_est[1], *_ = speech2text(waves_wsj[1].squeeze())[0]
text_m, *_ = speech2text(mixwav)[0]
print("Mix Speech to Text: ", text_m)
print("Separated Speech 1 to Text: ", text_est[0])
print("Separated Speech 2 to Text: ", text_est[1])
[ ]:
import difflib
from itertools import permutations
import editdistance
import numpy as np
colors = dict(
    red=lambda text: f"\033[38;2;255;0;0m{text}\033[0m" if text else "",
    green=lambda text: f"\033[38;2;0;255;0m{text}\033[0m" if text else "",
    yellow=lambda text: f"\033[38;2;225;225;0m{text}\033[0m" if text else "",
    white=lambda text: f"\033[38;2;255;255;255m{text}\033[0m" if text else "",
    black=lambda text: f"\033[38;2;0;0;0m{text}\033[0m" if text else "",
)
def diff_strings(ref, est):
    """Reference: https://stackoverflow.com/a/64404008/7384873"""
    ref_str, est_str, err_str = [], [], []
    matcher = difflib.SequenceMatcher(None, ref, est)
    for opcode, a0, a1, b0, b1 in matcher.get_opcodes():
        if opcode == "equal":
            txt = ref[a0:a1]
            ref_str.append(txt)
            est_str.append(txt)
            err_str.append(" " * (a1 - a0))
        elif opcode == "insert":
            ref_str.append("*" * (b1 - b0))
            est_str.append(colors["green"](est[b0:b1]))
            err_str.append(colors["black"]("I" * (b1 - b0)))
        elif opcode == "delete":
            ref_str.append(ref[a0:a1])
            est_str.append(colors["red"]("*" * (a1 - a0)))
            err_str.append(colors["black"]("D" * (a1 - a0)))
        elif opcode == "replace":
            diff = a1 - a0 - b1 + b0
            if diff >= 0:
                txt_ref = ref[a0:a1]
                txt_est = colors["yellow"](est[b0:b1]) + colors["red"]("*" * diff)
                txt_err = "S" * (b1 - b0) + "D" * diff
            elif diff < 0:
                txt_ref = ref[a0:a1] + "*" * -diff
                txt_est = colors["yellow"](est[b0:b1]) + colors["green"]("*" * -diff)
                txt_err = "S" * (b1 - b0) + "I" * -diff
            ref_str.append(txt_ref)
            est_str.append(txt_est)
            err_str.append(colors["black"](txt_err))
    return "".join(ref_str), "".join(est_str), "".join(err_str)


text_ref = [
    "SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR",
    "THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK",
]
print("=====================" , flush=True)
perms = list(permutations(range(2)))
string_edit = [
    [
        editdistance.eval(text_ref[m], text_est[n])
        for m, n in enumerate(p)
    ]
    for p in perms
]
dist = [sum(edist) for edist in string_edit]
perm_idx = np.argmin(dist)
perm = perms[perm_idx]
for i, p in enumerate(perm):
    print("\n--------------- Text %d ---------------" % (i + 1), flush=True)
    ref, est, err = diff_strings(text_ref[i], text_est[p])
    print("REF: " + ref + "\n" + "HYP: " + est + "\n" + "ERR: " + err, flush=True)
    print("Edit Distance = {}\n".format(string_edit[perm_idx][i]), flush=True)
(2) Tutorials on Contributing to the ESPnet-SE Project¶
If you would like to contribute to the ESPnet-SE project, or if you would like to make modifications based on the current speech enhancement/separation functionality, the following tutorials provide detailed information on how to create new recipes or new models in ESPnet-SE.
Creating a New Recipe¶
Step 1 Create recipe directory¶
First, run the following command to create the directory for the new recipe from our template:
egs2/TEMPLATE/enh1/setup.sh egs2/<your-recipe-name>/enh1
For the following steps, we assume the operations are done under the directory egs2/<your-recipe-name>/enh1/.
Step 2 Write scripts for data preparation¶
Prepare local/data.sh, which will be used in stage 1 of enh.sh. It can take some arguments as input, see egs2/wsj0_2mix/enh1/local/data.sh for reference.
The script local/data.sh should finally generate Kaldi-style data directories under <recipe-dir>/data/. Each subset directory should contain at least 4 files:
<recipe-dir>/data/<subset-name>/
├── spk{1,2,3...}.scp (clean speech references)
├── spk2utt
├── utt2spk
└── wav.scp (noisy speech)
Optionally, it can also contain noise{}.scp and dereverb{}.scp, which point to the corresponding noise and dereverberated references, respectively. {} can be 1, 2, …, depending on the number of noise types (dereverberated signals) in the input signal in wav.scp.
Make sure to sort the scp and other related files as in Kaldi. Also, remember to run . ./path.sh in local/data.sh before sorting, because it forces sorting to be byte-wise, i.e. export LC_ALL=C.
Remember to check your new scripts with shellcheck, otherwise they may fail the tests in ci/test_shell.sh.
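For illustration only, here is a minimal Python sketch of the kind of files local/data.sh must produce (a real recipe would generate them in shell; all utterance IDs and paths below are hypothetical):

# Sketch: write byte-wise sorted, Kaldi-style files for a hypothetical 2-speaker subset.
import os

entries = {
    "utt001": {"mix": "/path/to/mix/utt001.wav",
               "spk1": "/path/to/s1/utt001.wav",
               "spk2": "/path/to/s2/utt001.wav"},
}
os.makedirs("data/train", exist_ok=True)
# Byte-wise ordering, i.e. what `export LC_ALL=C` enforces for the shell `sort`.
utt_ids = sorted(entries, key=lambda u: u.encode())
with open("data/train/wav.scp", "w") as wav_scp, \
     open("data/train/spk1.scp", "w") as spk1_scp, \
     open("data/train/spk2.scp", "w") as spk2_scp, \
     open("data/train/utt2spk", "w") as utt2spk:
    for utt in utt_ids:
        wav_scp.write(f"{utt} {entries[utt]['mix']}\n")
        spk1_scp.write(f"{utt} {entries[utt]['spk1']}\n")
        spk2_scp.write(f"{utt} {entries[utt]['spk2']}\n")
        utt2spk.write(f"{utt} {utt}\n")  # mapping each utterance to itself is a common fallback
# spk2utt can then be generated from utt2spk with Kaldi's utils/utt2spk_to_spk2utt.pl.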
Step 3 Prepare training configuration¶
Prepare training configuration files (e.g. train.yaml) under conf/.
If you have multiple configuration files, it is recommended to put them under conf/tuning/, and create a symbolic link conf/tuning/train.yaml pointing to the config file with the best performance.
Step 4 Prepare run.sh¶
Write run.sh to provide a template entry script, so that users can easily run your recipe by ./run.sh. See egs2/wsj0_2mix/enh1/run.sh for reference.
If your recipe provides references for noise and/or dereverberation, you can add the arguments --use_noise_ref true and/or --use_dereverb_ref true in run.sh.
Implementing a New Speech Enhancement/Separation Model¶
The current ESPnet-SE tool adopts an encoder-separator-decoder architecture for all models. For example, for time-frequency masking models, the encoder and decoder would be stft_encoder.py and stft_decoder.py, respectively, and the separator can be any of dprnn_separator.py, rnn_separator.py, tcn_separator.py, and transformer_separator.py. For TasNet, the encoder and decoder are conv_encoder.py and conv_decoder.py, respectively, and the separator is tcn_separator.py.
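To make the data flow concrete, here is a self-contained toy sketch in plain PyTorch (these are not the actual ESPnet classes or their signatures; it only mirrors the encoder-separator-decoder pattern for a 2-speaker time-domain model):

# Toy encoder-separator-decoder sketch; modules and shapes are illustrative only.
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Waveform -> feature sequence (role of conv_encoder.py / stft_encoder.py)."""
    def __init__(self, feat_dim=64, kernel=16, stride=8):
        super().__init__()
        self.conv = nn.Conv1d(1, feat_dim, kernel, stride=stride, bias=False)

    def forward(self, wav):  # wav: (Batch, samples)
        return torch.relu(self.conv(wav.unsqueeze(1)))  # (Batch, feat_dim, frames)

class ToySeparator(nn.Module):
    """Features -> one masked feature stream per speaker (role of tcn_separator.py etc.)."""
    def __init__(self, feat_dim=64, num_spk=2):
        super().__init__()
        self._num_spk = num_spk
        self.net = nn.Conv1d(feat_dim, feat_dim * num_spk, 1)

    @property
    def num_spk(self):  # separators in ESPnet-SE must expose this property
        return self._num_spk

    def forward(self, feats):  # feats: (Batch, feat_dim, frames)
        masks = torch.sigmoid(self.net(feats)).chunk(self.num_spk, dim=1)
        return [feats * m for m in masks]

class ToyDecoder(nn.Module):
    """Per-speaker features -> waveform (role of conv_decoder.py / stft_decoder.py)."""
    def __init__(self, feat_dim=64, kernel=16, stride=8):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(feat_dim, 1, kernel, stride=stride, bias=False)

    def forward(self, feats):
        return self.deconv(feats).squeeze(1)  # (Batch, samples)

encoder, separator, decoder = ToyEncoder(), ToySeparator(), ToyDecoder()
mix = torch.randn(2, 16000)  # (Batch, samples)
separated = [decoder(f) for f in separator(encoder(mix))]
print(len(separated), separated[0].shape)  # 2 streams of shape (Batch, samples)

In a real model, each of these pieces would instead be implemented under espnet2/enh/ as described in the steps below.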
Step 1 Create model scripts¶
For encoder, separator, and decoder models, create new scripts under espnet2/enh/encoder/, espnet2/enh/separator/, and espnet2/enh/decoder/, respectively.
For a separator model, please make sure it implements the num_spk property. See espnet2/enh/separator/rnn_separator.py for reference.
Remember to format your new scripts to match the styles in black and flake8, otherwise they may fail the tests in ci/test_python.sh.
Step 3 [Optional] Create new loss functions¶
If you want to use a new loss function for your model, you can add it to espnet2/enh/espnet_model.py, such as:
@staticmethod
def new_loss(ref, inf):
    """Your new loss

    Args:
        ref: (Batch, samples)
        inf: (Batch, samples)
    Returns:
        loss: (Batch,)
    """
    ...
    return loss
Then add your loss name to ALL_LOSS_TYPES, and handle the loss calculation in _compute_loss.
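For instance, a time-domain L1 loss matching the signature above could look like the sketch below (the name new_loss_l1 is hypothetical, and the actual integration should follow the existing entries in _compute_loss):

@staticmethod
def new_loss_l1(ref, inf):
    """Sketch of a time-domain L1 loss (hypothetical example).

    Args:
        ref: (Batch, samples)
        inf: (Batch, samples)
    Returns:
        loss: (Batch,)
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)
    return (ref - inf).abs().mean(dim=1)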
Step 4 Create unit tests for the new model¶
Finally, it would be nice to make some unit tests for your new model under test/espnet2/enh/encoder, test/espnet2/enh/decoder, or test/espnet2/enh/separator.
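A test typically instantiates the model with a few parameter combinations and checks the forward pass. Here is a minimal pytest sketch (the module MySeparator, its constructor arguments, and the exact forward signature are hypothetical; follow the existing separator tests for the real interface):

# Sketch of a separator unit test (names and signatures are hypothetical).
import pytest
import torch

from espnet2.enh.separator.my_separator import MySeparator  # hypothetical new model

@pytest.mark.parametrize("num_spk", [1, 2])
def test_my_separator_forward(num_spk):
    separator = MySeparator(input_dim=17, num_spk=num_spk)
    assert separator.num_spk == num_spk
    feats = torch.rand(2, 10, 17)  # (Batch, frames, feature_dim)
    feats_lens = torch.tensor([10, 8], dtype=torch.long)
    # Check that the forward pass produces one output stream per speaker.
    separated, lens, others = separator(feats, feats_lens)
    assert len(separated) == num_spk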