ESPnet Speech Enhancement Demonstration¶
This notebook provides a demonstration of speech enhancement and separation using ESPnet2-SE.
Author: Chenda Li ([@LiChenda](https://github.com/LiChenda)), Wangyou Zhang ([@Emrys365](https://github.com/Emrys365))
Install¶
[ ]:
%pip install -q espnet==0.10.1
%pip install -q espnet_model_zoo
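If the installation succeeded, this optional check should import ESPnet without errors and print the installed version (a minimal sanity check, not part of the original demo):
[ ]:
# Optional sanity check: confirm ESPnet is importable after installation.
import espnet
print(espnet.__version__)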
Speech Enhancement¶
Single-Channel Enhancement, the CHiME example¶
[ ]:
# Download one utterance from real noisy speech of CHiME4
!gdown --id 1SmrN5NFSg6JuQSs2sfy3ehD8OIcqK6wS -O /content/M05_440C0213_PED_REAL.wav
import os
import soundfile
from IPython.display import display, Audio
mixwav_mc, sr = soundfile.read("/content/M05_440C0213_PED_REAL.wav")
# mixwav_mc.shape: (num_samples, num_channels)
# use the 5th channel (index 4) as the single-channel input
mixwav_sc = mixwav_mc[:, 4]
display(Audio(mixwav_mc.T, rate=sr))
Download and load the pretrained Conv-TasNet¶
[ ]:
!gdown --id 17DMWdw84wF3fz3t7ia1zssdzhkpVQGZm -O /content/chime_tasnet_singlechannel.zip
!unzip /content/chime_tasnet_singlechannel.zip -d /content/enh_model_sc
[ ]:
# Load the model
# If you encounter the error "No module named 'espnet2'", please re-run the first cell. This might be a Colab bug.
import sys
import soundfile
from espnet2.bin.enh_inference import SeparateSpeech
separate_speech = {}
# For models downloaded from Google Drive, you can use the following script:
enh_model_sc = SeparateSpeech(
    enh_train_config="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/config.yaml",
    enh_model_file="/content/enh_model_sc/exp/enh_train_enh_conv_tasnet_raw/5epoch.pth",
    # for segment-wise process on long speech
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=4,
    normalize_output_wav=True,
    device="cuda:0",
)
Enhance the single-channel real noisy speech in CHiME4¶
[ ]:
# Enhance the single-channel noisy speech and play the result
wave = enh_model_sc(mixwav_sc[None, ...], sr)
print("Input real noisy speech", flush=True)
display(Audio(mixwav_sc, rate=sr))
print("Enhanced speech", flush=True)
display(Audio(wave[0].squeeze(), rate=sr))
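If you would like to keep the enhanced audio, you can write it back to a WAV file with soundfile (a minimal sketch; the output path below is just an example):
[ ]:
# Optional: save the enhanced waveform to disk (example path).
soundfile.write("/content/M05_440C0213_PED_REAL_enhanced.wav", wave[0].squeeze(), sr)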
Enhance your own recordings¶
[ ]:
from google.colab import files
from IPython.display import display, Audio
import soundfile
uploaded = files.upload()
for file_name in uploaded.keys():
    speech, rate = soundfile.read(file_name)
    assert rate == sr, "mismatch in sampling rate"
    wave = enh_model_sc(speech[None, ...], sr)
    print(f"Your input speech {file_name}", flush=True)
    display(Audio(speech, rate=sr))
    print(f"Enhanced speech for {file_name}", flush=True)
    display(Audio(wave[0].squeeze(), rate=sr))
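The assertion above requires your recording to match the 16 kHz sampling rate of the CHiME4 model. If your file uses a different rate, one option is to resample it first, for example with librosa (not installed by the cells above; the file name below is a placeholder and a mono recording is assumed):
[ ]:
# Optional: resample a (mono) recording to the model's sampling rate before enhancement.
# Assumes librosa is available, e.g. via `%pip install -q librosa`.
import librosa
speech, rate = soundfile.read("my_recording.wav")  # placeholder file name
if rate != sr:
    speech = librosa.resample(speech, orig_sr=rate, target_sr=sr)
    rate = sr
wave = enh_model_sc(speech[None, ...], sr)
display(Audio(wave[0].squeeze(), rate=sr))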
Multi-Channel Enhancement¶
Download and load the pretrained MVDR neural beamformer¶
[ ]:
# Download the pretrained enhancement model
!gdown --id 1FohDfBlOa7ipc9v2luY-QIFQ_GJ1iW_i -O /content/mvdr_beamformer_16k_se_raw_valid.zip
!unzip /content/mvdr_beamformer_16k_se_raw_valid.zip -d /content/enh_model_mc
[ ]:
# Load the model
# If you encounter the error "No module named 'espnet2'", please re-run the first cell. This might be a Colab bug.
import sys
import soundfile
from espnet2.bin.enh_inference import SeparateSpeech
separate_speech = {}
# For models downloaded from Google Drive, you can use the following script:
enh_model_mc = SeparateSpeech(
    enh_train_config="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/config.yaml",
    enh_model_file="/content/enh_model_mc/exp/enh_train_enh_beamformer_mvdr_raw/11epoch.pth",
    # for segment-wise process on long speech
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=4,
    normalize_output_wav=True,
    device="cuda:0",
)
Enhance the multi-channel real noisy speech in CHiME4¶
[ ]:
wave = enh_model_mc(mixwav_mc[None, ...], sr)
print("Input real noisy speech", flush=True)
display(Audio(mixwav_mc.T, rate=sr))
print("Enhanced speech", flush=True)
display(Audio(wave[0].squeeze(), rate=sr))
Speech Separation¶
Model Selection¶
Please select a model shown in espnet_model_zoo.
In this demonstration, we will show different speech separation models trained on wsj0_2mix.
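If you would like to browse the registered models yourself, the ModelDownloader in espnet_model_zoo can also be queried by model name (a hedged sketch; the query() call follows the espnet_model_zoo README, so double-check it against your installed version):
[ ]:
# Optional: list the model names registered in espnet_model_zoo.
from espnet_model_zoo.downloader import ModelDownloader
d = ModelDownloader()
print(d.query("name"))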
[ ]:
#@title Choose Speech Separation model { run: "auto" }
fs = 8000 #@param {type:"integer"}
tag = "Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave" #@param ["Chenda Li/wsj0_2mix_enh_train_enh_conv_tasnet_raw_valid.si_snr.ave", "Chenda Li/wsj0_2mix_enh_train_enh_rnn_tf_raw_valid.si_snr.ave", "https://zenodo.org/record/4688000/files/enh_train_enh_dprnn_tasnet_raw_valid.si_snr.ave.zip"]
[ ]:
# The ModelDownloader can fetch models either by a name registered in espnet_model_zoo or by a direct Zenodo URL:
import sys
import soundfile
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.enh_inference import SeparateSpeech
d = ModelDownloader()
cfg = d.download_and_unpack(tag)
separate_speech = SeparateSpeech(
    enh_train_config=cfg["train_config"],
    enh_model_file=cfg["model_file"],
    # for segment-wise process on long speech
    segment_size=2.4,  # segment length in seconds
    hop_size=0.8,  # hop length in seconds
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=None,
    normalize_output_wav=True,
    device="cuda:0",
)
Separate Speech Mixture¶
Separate the example in the wsj0_2mix test set¶
[ ]:
!gdown --id 1ZCUkd_Lb7pO2rpPr4FqYdtJBZ7JMiInx -O /content/447c020t_1.2106_422a0112_-1.2106.wav
import os
import soundfile
from IPython.display import display, Audio
mixwav, sr = soundfile.read("447c020t_1.2106_422a0112_-1.2106.wav")
waves_wsj = separate_speech(mixwav[None, ...], fs=sr)
print("Input mixture", flush=True)
display(Audio(mixwav, rate=sr))
print(f"========= Separated speech with model {tag} =========", flush=True)
print("Separated spk1", flush=True)
display(Audio(waves_wsj[0].squeeze(), rate=sr))
print("Separated spk2", flush=True)
display(Audio(waves_wsj[1].squeeze(), rate=sr))
Separate your own recordings¶
[ ]:
from google.colab import files
from IPython.display import display, Audio
import soundfile
uploaded = files.upload()
for file_name in uploaded.keys():
    mixwav_yours, rate = soundfile.read(file_name)
    assert rate == sr, "mismatch in sampling rate"
    waves_yours = separate_speech(mixwav_yours[None, ...], fs=sr)
    print("Input mixture", flush=True)
    display(Audio(mixwav_yours, rate=sr))
    print(f"========= Separated speech with model {tag} =========", flush=True)
    print("Separated spk1", flush=True)
    display(Audio(waves_yours[0].squeeze(), rate=sr))
    print("Separated spk2", flush=True)
    display(Audio(waves_yours[1].squeeze(), rate=sr))
Show spectrograms of separated speech¶
[ ]:
import matplotlib.pyplot as plt
import torch
from torch_complex.tensor import ComplexTensor
from espnet.asr.asr_utils import plot_spectrogram
from espnet2.layers.stft import Stft
stft = Stft(
    n_fft=512,
    win_length=None,
    hop_length=128,
    window="hann",
)
ilens = torch.LongTensor([len(mixwav)])
# specs: (T, F)
spec_mix = ComplexTensor(
    *torch.unbind(
        stft(torch.as_tensor(mixwav).unsqueeze(0), ilens)[0].squeeze(),
        dim=-1
    )
)
spec_sep1 = ComplexTensor(
    *torch.unbind(
        stft(torch.as_tensor(waves_wsj[0]), ilens)[0].squeeze(),
        dim=-1
    )
)
spec_sep2 = ComplexTensor(
    *torch.unbind(
        stft(torch.as_tensor(waves_wsj[1]), ilens)[0].squeeze(),
        dim=-1
    )
)
# freqs = torch.linspace(0, sr / 2, spec_mix.shape[1])
# frames = torch.linspace(0, len(mixwav) / sr, spec_mix.shape[0])
samples = torch.linspace(0, len(mixwav) / sr, len(mixwav))
plt.figure(figsize=(24, 12))
plt.subplot(3, 2, 1)
plt.title('Mixture Spectrogram')
plot_spectrogram(
    plt, abs(spec_mix).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
)
plt.subplot(3, 2, 2)
plt.title('Mixture Waveform')
plt.plot(samples, mixwav)
plt.xlim(0, len(mixwav) / sr)
plt.subplot(3, 2, 3)
plt.title('Separated Spectrogram (spk1)')
plot_spectrogram(
    plt, abs(spec_sep1).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
)
plt.subplot(3, 2, 4)
plt.title('Separated Waveform (spk1)')
plt.plot(samples, waves_wsj[0].squeeze())
plt.xlim(0, len(mixwav) / sr)
plt.subplot(3, 2, 5)
plt.title('Separated Spectrogram (spk2)')
plot_spectrogram(
    plt, abs(spec_sep2).transpose(-1, -2).numpy(), fs=sr,
    mode='db', frame_shift=None,
    bottom=False, labelbottom=False
)
plt.subplot(3, 2, 6)
plt.title('Separated Waveform (spk2)')
plt.plot(samples, waves_wsj[1].squeeze())
plt.xlim(0, len(mixwav) / sr)
plt.xlabel("Time (s)")
plt.show()
Evaluate separated speech with a pretrained ASR model¶
The ground truths are:
text_1: SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR
text_2: THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK
(The speech recognition may take a while.)
[ ]:
import espnet_model_zoo
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text
wsj_8k_model_url = "https://zenodo.org/record/4012264/files/asr_train_asr_transformer_raw_char_1gpu_valid.acc.ave.zip?download=1"
d = ModelDownloader()
speech2text = Speech2Text(
    **d.download_and_unpack(wsj_8k_model_url),
    device="cuda:0",
)
text_est = [None, None]
text_est[0], *_ = speech2text(waves_wsj[0].squeeze())[0]
text_est[1], *_ = speech2text(waves_wsj[1].squeeze())[0]
text_m, *_ = speech2text(mixwav)[0]
print("Mix Speech to Text: ", text_m)
print("Separated Speech 1 to Text: ", text_est[0])
print("Separated Speech 2 to Text: ", text_est[1])
[ ]:
import difflib
from itertools import permutations
import editdistance
import numpy as np
colors = dict(
    red=lambda text: f"\033[38;2;255;0;0m{text}\033[0m" if text else "",
    green=lambda text: f"\033[38;2;0;255;0m{text}\033[0m" if text else "",
    yellow=lambda text: f"\033[38;2;225;225;0m{text}\033[0m" if text else "",
    white=lambda text: f"\033[38;2;255;255;255m{text}\033[0m" if text else "",
    black=lambda text: f"\033[38;2;0;0;0m{text}\033[0m" if text else "",
)
def diff_strings(ref, est):
    """Reference: https://stackoverflow.com/a/64404008/7384873"""
    ref_str, est_str, err_str = [], [], []
    matcher = difflib.SequenceMatcher(None, ref, est)
    for opcode, a0, a1, b0, b1 in matcher.get_opcodes():
        if opcode == "equal":
            txt = ref[a0:a1]
            ref_str.append(txt)
            est_str.append(txt)
            err_str.append(" " * (a1 - a0))
        elif opcode == "insert":
            ref_str.append("*" * (b1 - b0))
            est_str.append(colors["green"](est[b0:b1]))
            err_str.append(colors["black"]("I" * (b1 - b0)))
        elif opcode == "delete":
            ref_str.append(ref[a0:a1])
            est_str.append(colors["red"]("*" * (a1 - a0)))
            err_str.append(colors["black"]("D" * (a1 - a0)))
        elif opcode == "replace":
            diff = a1 - a0 - b1 + b0
            if diff >= 0:
                txt_ref = ref[a0:a1]
                txt_est = colors["yellow"](est[b0:b1]) + colors["red"]("*" * diff)
                txt_err = "S" * (b1 - b0) + "D" * diff
            elif diff < 0:
                txt_ref = ref[a0:a1] + "*" * -diff
                txt_est = colors["yellow"](est[b0:b1]) + colors["green"]("*" * -diff)
                txt_err = "S" * (b1 - b0) + "I" * -diff
            ref_str.append(txt_ref)
            est_str.append(txt_est)
            err_str.append(colors["black"](txt_err))
    return "".join(ref_str), "".join(est_str), "".join(err_str)
text_ref = [
    "SOME CRITICS INCLUDING HIGH REAGAN ADMINISTRATION OFFICIALS ARE RAISING THE ALARM THAT THE FED'S POLICY IS TOO TIGHT AND COULD CAUSE A RECESSION NEXT YEAR",
    "THE UNITED STATES UNDERTOOK TO DEFEND WESTERN EUROPE AGAINST SOVIET ATTACK",
]

print("=====================", flush=True)
perms = list(permutations(range(2)))
string_edit = [
    [
        editdistance.eval(text_ref[m], text_est[n])
        for m, n in enumerate(p)
    ]
    for p in perms
]
dist = [sum(edist) for edist in string_edit]
perm_idx = np.argmin(dist)
perm = perms[perm_idx]

for i, p in enumerate(perm):
    print("\n--------------- Text %d ---------------" % (i + 1), flush=True)
    ref, est, err = diff_strings(text_ref[i], text_est[p])
    print("REF: " + ref + "\n" + "HYP: " + est + "\n" + "ERR: " + err, flush=True)
    print("Edit Distance = {}\n".format(string_edit[perm_idx][i]), flush=True)