import soundfile as sf
import torch
from scipy.io import wavfile
import os
import numpy as np
audio_file_save = "/home*/wav/1.wav"
clean, s = sf.read("/home/*/clean_trainset_wav/p287_424.wav")
clean = torch.from_numpy(clean).float()  # match dtypes: the stft window defaults to float32
window = torch.hann_window(400)  # analysis/synthesis window (NOLA holds for hop 100)
x1 = torch.stft(clean, n_fft=512, hop_length=100, win_length=400,
                window=window, return_complex=True)  # (257, T) complex
x1 = torch.view_as_real(x1).unsqueeze(dim=0)  # (1, 257, T, 2)
x1 = x1.permute(0, 3, 2, 1)  # (1, 2, T, 257)
# conduct sqrt power-compression
x1_mag = torch.norm(x1, dim=1) ** 0.5  # (1, T, 257) compressed magnitude
x1_phase = torch.atan2(x1[:, -1, ...], x1[:, 0, ...])  # (1, T, 257) phase
x1 = torch.stack((x1_mag * torch.cos(x1_phase), x1_mag * torch.sin(x1_phase)), dim=1)  # (1, 2, T, 257)
# pad = torch.nn.ZeroPad2d((0, 1, 0, 0))
# x1 = pad(x1)
out1 = x1.permute(0, 3, 2, 1).contiguous()  # (1, 257, T, 2)
out1 = torch.view_as_complex(out1)  # istft requires a complex-valued spectrogram
out1 = torch.istft(out1, 512, hop_length=100, win_length=400, window=window)  # (1, L)
out1 = out1.squeeze(0).numpy()  # note: the magnitude is still sqrt-compressed here
wavfile.write(audio_file_save, 16000, out1.astype(np.float32))
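As a quick sanity check (a minimal sketch reusing the tensors above), squaring the compressed magnitude undoes the 0.5-power compression, and the ISTFT then recovers the original waveform up to numerical error:

# Optional round-trip check: undo the sqrt compression, then invert the STFT.
mag = torch.norm(x1, dim=1) ** 2  # (|X| ** 0.5) ** 2 == |X|
phase = torch.atan2(x1[:, -1, ...], x1[:, 0, ...])
spec = torch.polar(mag, phase).permute(0, 2, 1)  # (1, 257, T) complex spectrogram
recon = torch.istft(spec, 512, hop_length=100, win_length=400, window=window)
n = min(recon.shape[-1], clean.shape[-1])
print("max abs error:", (recon.squeeze(0)[:n] - clean[:n]).abs().max().item())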
2. Test code
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from math import sqrt, log10, ceil
from pystoi import stoi
from pesq import pesq
from pesq.cypesq import PesqError
from pypesq import pesq as pesq_mos
import argparse
from mir_eval.separation import bss_eval_sources
from pit_criterion import cal_loss
from collections import OrderedDict
import numpy as np
import torch
from data_whamr import TestDataset, TestCollate
from D2Net.mc_power_compression_D2Net import Net
from torch.utils.data import DataLoader
# from D2Net.mc_D2NET import Net
# from D2Net.mc_bss_D2NET import Net
import torch.nn as nn
def remove_pad(inputs, inputs_lengths):
"""
Args:
inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
inputs_lengths: torch.Tensor, [B]
Returns:
results: a list containing B items, each item is [C, T], T varies
"""
results = []
dim = inputs.dim()
if dim == 3:
C = inputs.size(1)
for input, length in zip(inputs, inputs_lengths):
if dim == 3: # [B, C, T]
results.append(input[:, :length].view(C, -1).cpu().numpy())
elif dim == 2: # [B, T]
results.append(input[:length].view(-1).cpu().numpy())
return results
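# Illustrative shapes (hypothetical values): for inputs of shape [2, 1, 64000]
# with inputs_lengths = tensor([64000, 48000]), remove_pad returns two numpy
# arrays of shapes (1, 64000) and (1, 48000).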
parser = argparse.ArgumentParser('Evaluate separation performance of D2Net (adapted from the FaSNet + TAC evaluation script)')
# parser.add_argument('--test_dir', type=str, default="/home/weiWB/dataset/FaSNet_TAC/or_0.25_spk1_snr/MC_Libri_adhoc/test/2mic/", help='path to test/2mic/samples')
parser.add_argument('--test_dir', type=str,
default="/home/wangLS/dataset/mc_whamr/mc_whamr_mix_5/mc_whamr_tt_5R_list",
help='path to test/2mic/samples')
parser.add_argument('--model_path', type=str, default='/home/weiWB/code/FaSNet-TAC-PyTorch20/D2Net/exp/tmp/mc__power_compression_D2Net_whamr5/temp_best.pth.tar',
help='Path to model file created by training')
parser.add_argument('--cal_sdr', type=int, default=0,
help='Whether calculate SDR, add this option because calculation of SDR is very slow')
parser.add_argument('--use_cuda', type=int, default=1, help='Whether use GPU to separate speech')
# General config
# Task related
parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate')
# Network architecture
parser.add_argument('--mic', default=2, type=int, help='number of microphones')
class ISTFT(nn.Module):
def __init__(self, n_fft=512, hop_length=100, window_length=400):
super(ISTFT, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.window_length = window_length
        self.pad = torch.nn.ZeroPad2d((0, 1, 0, 0))  # restore the dropped 257th frequency bin

    def forward(self, x):
        out = self.pad(x)  # (B, 2, T, 256) -> (B, 2, T, 257)
        out = out.permute(0, 3, 2, 1).contiguous()  # (B, 257, T, 2)
        out = torch.view_as_complex(out)  # istft requires a complex-valued input
        out = torch.istft(out, n_fft=self.n_fft, hop_length=self.hop_length,
                          win_length=self.window_length).unsqueeze(1)  # (B, 1, L)
        return out
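# Shape sketch (assuming the network emits a 256-bin compressed spectrogram):
#   istft = ISTFT()
#   wav = istft(torch.randn(1, 2, 641, 256))  # -> waveform of shape (1, 1, 64000)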
def evaluate(args):
eps = 1e-8
total_SISNRi = 0
total_SISNR = 0
total_wb_pesq = 0
total_nb_pesq_mos = 0
total_nb_pesq = 0
total_stoi = 0
total_SDRi = 0
total_sdr = 0
total_cnt = 0
# load D2Net model
model = Net()
if args.use_cuda:
model = torch.nn.DataParallel(model)
model.cuda()
# model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
model_info = torch.load(args.model_path)
try:
model.load_state_dict(model_info['model_state_dict'])
except KeyError:
state_dict = OrderedDict()
for k, v in model_info['model_state_dict'].items():
name = k.replace("module.", "") # remove 'module.'
state_dict[name] = v
model.load_state_dict(state_dict)
print(model)
model.eval()
# whamr data_loader
test_data = TestDataset(args.test_dir)
data_loader = DataLoader(test_data,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=TestCollate()
)
with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()  # tensor of shape [1, 2, 64000]
                mixture_lengths = mixture_lengths.cuda()  # tensor([64000, 64000])
                padded_source = padded_source.cuda()  # tensor of shape [1, 2, 64000]
# D2Net
estimate_source = model(padded_mixture)
# mc_power_compression_D2Net
padded_source = torch.split(padded_source, 1, dim=1)
padded_source = padded_source[0] # B,1,L
istft = ISTFT()
estimate_source = istft(estimate_source)
loss, max_snr, estimate_source, reorder_estimate_source = \
cal_loss(padded_source, estimate_source, mixture_lengths)
M, _, T = padded_mixture.shape # B,C,T
mixture_ref = torch.chunk(padded_mixture, args.mic, dim=1)[0] # [B, 1, T]
mixture_ref = mixture_ref.view(M, T) # [ 1, T]
mixture = remove_pad(mixture_ref, mixture_lengths) # ndarray(T,)
source = remove_pad(padded_source, mixture_lengths) # ndarray(C,T)
estimate_source = remove_pad(reorder_estimate_source, mixture_lengths) # ndarray(C,T)
# for each utterance mix:ndarray(T,);src_ref:(C,T);src_est:(C,T)
for mix, src_ref, src_est in zip(mixture, source, estimate_source):
print("Utt", total_cnt + 1)
# Compute SDRi
if args.cal_sdr:
avg_SDRi = cal_SDRi(src_ref, src_est, mix)
total_SDRi += avg_SDRi
# print("\tSDRi={0:.2f}".format(avg_SDRi))
src_ref = np.reshape(src_ref, [-1]) # ndarray(128000,)
src_est = np.reshape(src_est, [-1]) # ndarray(128000,)
# 2-D numpy to 1-D numpy compute pesq
# pesq
                avg_wb_pesq = pesq(args.sample_rate, src_ref, src_est, on_error=PesqError.RETURN_VALUES)  # pesq_batch can compute PESQ for multiple inputs/outputs at once
# pystoi
avg_stoi = stoi(src_ref, src_est, args.sample_rate)
total_wb_pesq += avg_wb_pesq
total_stoi += avg_stoi
total_cnt += 1
if args.cal_sdr:
print("Average SDR improvement: {0:.2f}".format(total_SDRi / (total_cnt + eps)))
print("Average wb_pesq improvement: {0:.2f}".format(total_wb_pesq / (total_cnt + eps)))
print("Average stoi improvement: {0:.2f}".format(total_stoi / (total_cnt + eps)))
def cal_SDRi(src_ref, src_est, mix):
"""Calculate Source-to-Distortion Ratio improvement (SDRi).
NOTE: bss_eval_sources is very very slow.
Args:
src_ref: numpy.ndarray, [C, T]
src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
mix: numpy.ndarray, [T]
Returns:
average_SDRi
"""
src_anchor = np.stack([mix, mix], axis=0)
sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
avg_SDRi = ((sdr[0] - sdr0[0]) + (sdr[1] - sdr0[1])) / 2.0
# print("SDRi1: {0:.2f}, SDRi2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[1]))
return avg_SDRi
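# SDRi uses the unprocessed mixture (stacked once per source) as the baseline:
#   SDRi = mean over sources of (SDR(ref, est) - SDR(ref, mix))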
def cal_SISNRi(src_ref, src_est, mix):
"""Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
Args:
src_ref: numpy.ndarray, [C, T]
src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
mix: numpy.ndarray, [T]
Returns:
average_SISNRi
"""
sisnr1 = cal_SISNR(src_ref[0], src_est[0])
sisnr2 = cal_SISNR(src_ref[1], src_est[1])
sisnr1b = cal_SISNR(src_ref[0], mix)
sisnr2b = cal_SISNR(src_ref[1], mix)
# print("SISNR base1 {0:.2f} SISNR base2 {1:.2f}, avg {2:.2f}".format(
# sisnr1b, sisnr2b, (sisnr1b+sisnr2b)/2))
# print("SISNRi1: {0:.2f}, SISNRi2: {1:.2f}".format(sisnr1, sisnr2))
avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2.0
return avg_SISNRi
def cal_SISNR(ref_sig, out_sig, eps=1e-8):
"""Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
Args:
ref_sig: numpy.ndarray, [T]
out_sig: numpy.ndarray, [T]
Returns:
SISNR
"""
assert len(ref_sig) == len(out_sig)
ref_sig = ref_sig - np.mean(ref_sig)
out_sig = out_sig - np.mean(out_sig)
ref_energy = np.sum(ref_sig ** 2) + eps
proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
noise = out_sig - proj
ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    sisnr = 10 * np.log10(ratio + eps)
return sisnr
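# Scale-invariance check (illustrative): scaling the estimate leaves SI-SNR unchanged.
#   x = np.random.randn(16000)
#   y = x + 0.01 * np.random.randn(16000)
#   cal_SISNR(x, y) == cal_SISNR(x, 0.5 * y)  # equal up to floating-point error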
def cal_si_snr(source, estimate_source):
"""Calculate SI-SNR.
Arguments:
---------
source: [T, B, C],
Where B is batch size, T is the length of the sources, C is the number of sources
the ordering is made so that this loss is compatible with the class PitWrapper.
estimate_source: [T, B, C]
The estimated source.
Example:
---------
import numpy as np
x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
xhat = x[:, (1, 0)]
x = x.unsqueeze(-1).repeat(1, 1, 2)
xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
si_snr = -cal_si_snr(x, xhat)
print(si_snr)
tensor([[[ 25.2142, 144.1789],
[130.9283, 25.2142]]])
"""
EPS = 1e-8
# assert source.size() == estimate_source.size()
# device = estimate_source.device.type
# source_lengths = torch.tensor(
# [estimate_source.shape[0]] * estimate_source.shape[-2], device=device
# )
source_lengths = torch.tensor(
[estimate_source.shape[0]] * estimate_source.shape[-2]
)
source = torch.tensor(source)
estimate_source = torch.tensor(estimate_source)
mask = get_mask(source, source_lengths)
estimate_source *= mask
num_samples = (
source_lengths.contiguous().reshape(1, -1, 1).float()
) # [1, B, 1]
mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
mean_estimate = (
torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
)
zero_mean_target = source - mean_target
zero_mean_estimate = estimate_source - mean_estimate
# mask padding position along T
zero_mean_target *= mask
zero_mean_estimate *= mask
# Step 2. SI-SNR with PIT
# reshape to use broadcast
s_target = zero_mean_target # [T, B, C]
s_estimate = zero_mean_estimate # [T, B, C]
# s_target = s / ||s||^2
dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True) # [1, B, C]
s_target_energy = (
torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS
) # [1, B, C]
proj = dot * s_target / s_target_energy # [T, B, C]
# e_noise = s' - s_target
e_noise = s_estimate - proj # [T, B, C]
# SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
torch.sum(e_noise ** 2, dim=0) + EPS
)
si_snr = 10 * torch.log10(si_snr_beforelog + EPS) # [B, C]
return -si_snr.unsqueeze(0)
def get_mask(source, source_lengths):
"""
Arguments
---------
source : [T, B, C]
source_lengths : [B]
Returns
-------
mask : [T, B, 1]
Example:
---------
source = torch.randn(4, 3, 2)
source_lengths = torch.Tensor([2, 1, 4]).int()
mask = get_mask(source, source_lengths)
print(mask)
tensor([[[1.],
[1.],
[1.]],
[[1.],
[0.],
[1.]],
[[0.],
[0.],
[1.]],
[[0.],
[0.],
[1.]]])
"""
mask = source.new_ones(source.size()[:-1]).unsqueeze(-1).transpose(1, -2)
B = source.size(-2)
for i in range(B):
mask[source_lengths[i]:, i] = 0
return mask.transpose(-2, 1)
def SDR(est, egs, mix):
    '''
    Calculate the SDR improvement over the unprocessed mixture.
    est: network output audio
    egs: ground truth (clean reference)
    mix: unprocessed mixture, used as the baseline
    '''
length = est.shape[0]
sdr, _, _, _ = bss_eval_sources(egs[:length], est[:length])
mix_sdr, _, _, _ = bss_eval_sources(egs[:length], mix[:length])
return float(sdr - mix_sdr)
# ssnr
def calc_ssnr(signal, noise, frame_size):
"""
Calculate segmental signal noise ratio.
If file is not noisy then SNR is about 100dB.
:param signal: (list)
:param noise: (list)
:param frame_size: (int) ssnr frame size
:return: (value) SSNR(dB)
"""
if len(signal) != len(noise):
raise Exception("ERROR: Signal noise size mismatch")
number_of_frame_size = ceil(len(signal) / frame_size)
    snr_sum = 0  # avoid shadowing the built-in sum
    nonzero_frame_number = 0
segmental_signal_power = [0] * number_of_frame_size
segmental_noise_power = [0] * number_of_frame_size
for i in range(number_of_frame_size):
if i == number_of_frame_size - 1:
segmental_signal_power[i] = calc_power(signal[frame_size * i:])
segmental_noise_power[i] = calc_power(noise[frame_size * i:])
else:
segmental_signal_power[i] = calc_power(signal[frame_size * i:frame_size * (i + 1)])
segmental_noise_power[i] = calc_power(noise[frame_size * i:frame_size * (i + 1)])
        if segmental_noise_power[i] == 0:
            segmental_noise_power[i] = pow(0.1, 10)  # floor of 1e-10 to avoid division by zero
        if segmental_signal_power[i] != 0:
            nonzero_frame_number += 1
            snr_sum += 10 * log10(segmental_signal_power[i] / segmental_noise_power[i])
    ssnr = snr_sum / nonzero_frame_number
    return ssnr
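# Example usage (hypothetical files; noise taken as the clean-minus-enhanced residual):
#   clean, fs = sf.read("clean.wav"); enhanced, _ = sf.read("enhanced.wav")
#   calc_ssnr(clean, clean - enhanced, frame_size=480)  # 30 ms frames at 16 kHz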
def calc_power(samples):
    """
    Calculate the mean power of the input samples.
    :param samples: (list)
    :return: (value)
    """
    total = 0  # avoid shadowing the built-ins input and sum
    for n in samples:
        total += pow(n, 2)
    return total / len(samples)
# SSNR
def SNRseg(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75):
    # frameLen is the frame length in seconds, overlap the fractional frame overlap
    eps = np.finfo(np.float64).eps
winlength = round(frameLen * fs) # window length in samples
skiprate = int(np.floor((1 - overlap) * frameLen * fs)) # window skip in samples
MIN_SNR = -10 # minimum SNR in dB
MAX_SNR = 35 # maximum SNR in dB
hannWin = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
clean_speech_framed = extractOverlappedWindows(clean_speech, winlength, winlength - skiprate, hannWin)
processed_speech_framed = extractOverlappedWindows(processed_speech, winlength, winlength - skiprate, hannWin)
signal_energy = np.power(clean_speech_framed, 2).sum(-1)
noise_energy = np.power(clean_speech_framed - processed_speech_framed, 2).sum(-1)
segmental_snr = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
segmental_snr[segmental_snr < MIN_SNR] = MIN_SNR
segmental_snr[segmental_snr > MAX_SNR] = MAX_SNR
segmental_snr = segmental_snr[:-1] # remove last frame -> not valid
return np.mean(segmental_snr)
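# Example usage (hypothetical arrays): 30 ms frames with 75% overlap at 16 kHz.
#   SNRseg(clean_speech, enhanced_speech, fs=16000)  # mean segmental SNR in dB, clamped to [-10, 35]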
def extractOverlappedWindows(x, nperseg, noverlap, window=None):
# source: https://github.com/scipy/scipy/blob/v1.2.1/scipy/signal/spectral.py
step = nperseg - noverlap
shape = x.shape[:-1] + ((x.shape[-1] - noverlap) // step, nperseg)
strides = x.strides[:-1] + (step * x.strides[-1], x.strides[-1])
result = np.lib.stride_tricks.as_strided(x, shape=shape,
strides=strides)
if window is not None:
result = window * result
return result
if __name__ == '__main__':
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 指定使用第几个显卡
args = parser.parse_args()
print(args)
evaluate(args)
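# Typical invocation (hypothetical script name; paths are placeholders):
#   python evaluate.py --test_dir /path/to/tt_list --model_path /path/to/temp_best.pth.tar --cal_sdr 1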