基于PaddleSpeech的婴儿啼哭识别(上)

简介: 基于PaddleSpeech的婴儿啼哭识别(上)

四、模型训练


1.选取预训练模型


选取cnn14作为 backbone,用于提取音频的特征:

from paddlespeech.cls.models import cnn14
backbone = cnn14(pretrained=True, extract_embedding=True)


2.构建分类模型


SoundClassifer接收cnn14作为backbone模型,并创建下游的分类网络:

import paddle.nn as nn
class SoundClassifier(nn.Layer):
    def __init__(self, backbone, num_class, dropout=0.1):
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.backbone.emb_size, num_class)
    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.backbone(x)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
model = SoundClassifier(backbone, num_class=len(train_ds.label_list))


3.finetune


# 定义优化器和 Loss
optimizer = paddle.optimizer.Adam(learning_rate=1e-4, parameters=model.parameters())
criterion = paddle.nn.loss.CrossEntropyLoss()
from paddleaudio.utils import logger
epochs = 20
steps_per_epoch = len(train_loader)
log_freq = 10
eval_freq = 10
for epoch in range(1, epochs + 1):
    model.train()
    avg_loss = 0
    num_corrects = 0
    num_samples = 0
    for batch_idx, batch in enumerate(train_loader):
        waveforms, labels = batch
        feats = feature_extractor(waveforms)
        feats = paddle.transpose(feats, [0, 2, 1])  # [B, N, T] -> [B, T, N]
        logits = model(feats)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        if isinstance(optimizer._learning_rate,
                      paddle.optimizer.lr.LRScheduler):
            optimizer._learning_rate.step()
        optimizer.clear_grad()
        # Calculate loss
        avg_loss += loss.numpy()[0]
        # Calculate metrics
        preds = paddle.argmax(logits, axis=1)
        num_corrects += (preds == labels).numpy().sum()
        num_samples += feats.shape[0]
        if (batch_idx + 1) % log_freq == 0:
            lr = optimizer.get_lr()
            avg_loss /= log_freq
            avg_acc = num_corrects / num_samples
            print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                epoch, epochs, batch_idx + 1, steps_per_epoch)
            print_msg += ' loss={:.4f}'.format(avg_loss)
            print_msg += ' acc={:.4f}'.format(avg_acc)
            print_msg += ' lr={:.6f}'.format(lr)
            logger.train(print_msg)
            avg_loss = 0
            num_corrects = 0
            num_samples = 0
[2022-08-24 02:20:49,381] [   TRAIN] - Epoch=17/20, Step=10/15 loss=1.3319 acc=0.4875 lr=0.000100
[2022-08-24 02:21:08,107] [   TRAIN] - Epoch=18/20, Step=10/15 loss=1.3222 acc=0.4719 lr=0.000100
[2022-08-24 02:21:08,107] [   TRAIN] - Epoch=18/20, Step=10/15 loss=1.3222 acc=0.4719 lr=0.000100
[2022-08-24 02:21:26,884] [   TRAIN] - Epoch=19/20, Step=10/15 loss=1.2539 acc=0.5125 lr=0.000100
[2022-08-24 02:21:26,884] [   TRAIN] - Epoch=19/20, Step=10/15 loss=1.2539 acc=0.5125 lr=0.000100
[2022-08-24 02:21:45,579] [   TRAIN] - Epoch=20/20, Step=10/15 loss=1.2021 acc=0.5281 lr=0.000100
[2022-08-24 02:21:45,579] [   TRAIN] - Epoch=20/20, Step=10/15 loss=1.2021 acc=0.5281 lr=0.000100


五、模型训练


top_k = 3
wav_file = 'test/test_0.wav'
n_fft = 1024
win_length = 1024
hop_length = 320
f_min=50.0
f_max=16000.0
waveform, sr = load(wav_file, sr=sr)
feature_extractor = LogMelSpectrogram(
    sr=sr, 
    n_fft=n_fft, 
    hop_length=hop_length, 
    win_length=win_length, 
    window='hann', 
    f_min=f_min, 
    f_max=f_max, 
    n_mels=64)
feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))
feats = paddle.transpose(feats, [0, 2, 1])  # [B, N, T] -> [B, T, N]
logits = model(feats)
probs = nn.functional.softmax(logits, axis=1).numpy()
sorted_indices = probs[0].argsort()
msg = f'[{wav_file}]\n'
for idx in sorted_indices[-1:-top_k-1:-1]:
    msg += f'{train_ds.label_list[idx]}: {probs[0][idx]:.5f}\n'
print(msg)    
[test/test_0.wav]
diaper: 0.50155
sleepy: 0.41397
hug: 0.05912


六、注意事项


  • 1.自定义数据集,格式可参考文档;
  • 2.统一音频尺寸(例如音频长度、采样频率)
  • 3.系统了解,可学习aistudio.baidu.com/aistudio/ed… 课程


目录
相关文章
|
6月前
|
机器学习/深度学习 算法
应用规则学习算法识别有毒的蘑菇
应用规则学习算法识别有毒的蘑菇
|
6月前
|
算法 开发工具 计算机视觉
条形码识别研究
条形码识别研究
153 0
|
计算机视觉 数据格式 Python
人脸口罩检测:使用YOLOv5检测公共场所是否佩戴口罩
人脸口罩检测:使用YOLOv5检测公共场所是否佩戴口罩
138 0
|
机器学习/深度学习 传感器 安全
【红绿灯识别】基于计算机视觉红绿灯识别附Matlab代码
【红绿灯识别】基于计算机视觉红绿灯识别附Matlab代码
|
C++ 计算机视觉 Python
C++/Yolov8人体特征识别 广场室内 人数统计
这篇博客针对<<C++/Yolov8人体特征识别 广场室内 人数统计>>编写代码,代码整洁,规则,易读。 学习与应用推荐首选。
385 0
|
数据采集 存储 搜索推荐
分析新闻评论数据并进行情绪识别
爬取新闻评论数据并进行情绪识别的目的是为了从网页中抓取用户对新闻事件或话题的评价内容,并从中识别和提取用户的情绪或态度,如积极、消极、中立等。爬取新闻评论数据并进行情绪识别有以下几个优势: 1)可以了解用户对新闻事件或话题的看法和感受,以及影响他们情绪的因素; 2)可以分析用户的情绪变化和趋势,以及与新闻事件或话题的相关性和影响力; 3)可以根据用户的情绪进行个性化的推荐或服务,如提供正能量的内容、提供帮助或建议等;
238 1
|
机器学习/深度学习 传感器 算法
【水果质量检测】基于机器视觉实现苹果疾病识别分类附matlab代码
【水果质量检测】基于机器视觉实现苹果疾病识别分类附matlab代码
|
存储 机器学习/深度学习 传感器
【发票识别】基于模板匹配实现发票识别研究附matlab代码
【发票识别】基于模板匹配实现发票识别研究附matlab代码
|
数据可视化 数据挖掘
基于PaddleClas2.3 的鲜花识别
基于PaddleClas2.3 的鲜花识别
197 0
基于PaddleClas2.3 的鲜花识别
|
XML JSON 算法
X光安检图像检测挑战赛3.0(上)
X光安检图像检测挑战赛3.0(上)
359 0
X光安检图像检测挑战赛3.0(上)