kaldi 源码分析(十) - gmm-init-mono.c分析

简介: 一直没有搞明白 hmm-gmm 之间是通过什么联系起来的,花了些时间查代码,看到最直观联系的就是 gmm-init-mono 工具。gmm-init-mono 基础类通过上述看到,主要的配置都是 在 topo 文件中, 这里需要将一些常...

一直没有搞明白 hmm-gmm 之间是通过什么联系起来的,花了些时间查代码,看到最直观联系的就是 gmm-init-mono 工具。


img_28424dbf260083d0390bb38a8f84c847.jpe
gmm-init-mono 基础类

通过上述看到,主要的配置都是 在 topo 文件中, 这里需要将一些常见的名称理解下来,这里直接贴出英文内容:

名称 解释
phone a phone index (1, 2, 3 ...)
HMM-state a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) (HmmTopology 中 HmmState 位置)
pdf-id a number output by the Compute function of ContextDependency (it indexes pdf's, either forward or self-loop). Zero-based. (HmmState 中 pdf 中的 index)
transition-state the states for which we estimate transition probabilities for transitions out of them. In some topologies, will map one-to-one with pdf-ids. One-based, since it appears on FSTs. (状态转换描述)
transition-index identifier of a transition (or final-prob) in the HMM. Indexes the "transitions" vector in HmmTopology::HmmState. (状态转换 index) [if it is out of range, equal to transitions.size(), it refers to the final-prob.] Zero-based.
transition-id identifier of a unique parameter of the TransitionModel. Associated with a (transition-state, transition-index) pair.One-based, since it appears on FSTs. (状态转换 id)

从 train_mono.sh 中获取 gmm-init-mono 命令详细内容

  $cmd JOB=1 $dir/log/init.log \
    gmm-init-mono $shared_phones_opt "--train-feats=$feats subset-feats --n=10 ark:- ark:-|" $lang/topo $feat_dim \
    $dir/0.mdl $dir/tree || exit 1;
# 实际执行的内容如下:
$ gmm-init-mono --shared-phones=$lang/phones/sets.int "--train-feats=ark,s,cs:apply-cmvn --norm-vars=true --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- | subset-feats --n=10 ark:- ark:-|" $lang/topo $feat_dim \
    $dir/0.mdl $dir/tree

从上述命令来看 --train-feats 指定了 gmm-init-mono 初始化使用的特征向量数据,其中通过 apply-cmvn 将 feats 进行归一化,然后通过 subset-feats 来取出 10 个特征向量作为参数
下面是具体代码的简单分析:

    // 读入一定量的特征,进行统计获取 gmm 模型的 means 及 variances 数据
    if (train_feats != "") {
      double count = 0.0;
      Vector<double> var_stats(dim);
      Vector<double> mean_stats(dim);
      SequentialDoubleMatrixReader feat_reader(train_feats);
      for (; !feat_reader.Done(); feat_reader.Next()) {
        const Matrix<double> &mat = feat_reader.Value();
        for (int32 i = 0; i < mat.NumRows(); i++) {
          count += 1.0;
          var_stats.AddVec2(1.0, mat.Row(i));
          mean_stats.AddVec(1.0, mat.Row(i));
        }
      }
      if (count == 0) { KALDI_ERR << "no features were seen."; }
      var_stats.Scale(1.0/count);
      // 计算均值
      mean_stats.Scale(1.0/count);
      var_stats.AddVec2(-1.0, mean_stats);
      if (var_stats.Min() <= 0.0)
        KALDI_ERR << "bad variance";
      var_stats.InvertElements();
      glob_inv_var.CopyFromVec(var_stats);
      glob_mean.CopyFromVec(mean_stats);
    }

    HmmTopology topo;
    bool binary_in;
    Input ki(topo_filename, &binary_in);
    topo.Read(ki.Stream(), binary_in);

    const std::vector<int32> &phones = topo.GetPhones();

    // 根据 topo 中的配置来获取每个 phone 音素 pdf 类数量
    std::vector<int32> phone2num_pdf_classes (1+phones.back());
    for (size_t i = 0; i < phones.size(); i++)
      phone2num_pdf_classes[phones[i]] = topo.NumPdfClasses(phones[i]);

    // 根据每个 phone 音素对应 pdf 数量来创建 ContextDependency (决策树)对象
    // Now the tree [not really a tree at this point]:
    ContextDependency *ctx_dep = NULL;
    if (shared_phones_rxfilename == "") {  // No sharing of phones: standard approach.
      ctx_dep = MonophoneContextDependency(phones, phone2num_pdf_classes);
    } else {
      std::vector<std::vector<int32> > shared_phones;
      ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones);
      // ReadSharedPhonesList crashes on error.
      ctx_dep = MonophoneContextDependencyShared(shared_phones, phone2num_pdf_classes);
    }

    // 获取所有 pdfs 数量 = phones * 每个 phone 含有的 pdfclass 数量
    int32 num_pdfs = ctx_dep->NumPdfs();

    // 根据特征统计出的结果,创建 DiagGmm 初始化模型
    AmDiagGmm am_gmm;
    DiagGmm gmm;
    gmm.Resize(1, dim);
    {  // Initialize the gmm.
      Matrix<BaseFloat> inv_var(1, dim);
      inv_var.Row(0).CopyFromVec(glob_inv_var);
      Matrix<BaseFloat> mu(1, dim);
      mu.Row(0).CopyFromVec(glob_mean);
      Vector<BaseFloat> weights(1);
      weights.Set(1.0);
      gmm.SetInvVarsAndMeans(inv_var, mu);
      gmm.SetWeights(weights);
      gmm.ComputeGconsts();
    }

    // 将每个 pdf 都初始化为上述创建的 gmm ,并与pdf对应起来
    for (int i = 0; i < num_pdfs; i++)
      am_gmm.AddPdf(gmm);

    // 添加 perturb_factor 因子
    if (perturb_factor != 0.0) {
      for (int i = 0; i < num_pdfs; i++)
        am_gmm.GetPdf(i).Perturb(perturb_factor);
    }

    // 将 ContextDependency 与 topo 合并为一个模型文件保存下来
    // Now the transition model:
    TransitionModel trans_model(*ctx_dep, topo);

    {
      Output ko(model_filename, binary);
      trans_model.Write(ko.Stream(), binary);
      am_gmm.Write(ko.Stream(), binary);
    }

    // 将ContextDependency存为决策树文件
    // Now write the tree.
    ctx_dep->Write(Output(tree_filename, binary).Stream(),
                   binary);
目录
相关文章
|
1月前
|
机器学习/深度学习 XML 计算机视觉
OpenCV(Open Source Computer Vision Library)是一个开源的计算机视觉和机器学习库,它提供了大量的函数和工具,用于处理图像和视频数据。
OpenCV(Open Source Computer Vision Library)是一个开源的计算机视觉和机器学习库,它提供了大量的函数和工具,用于处理图像和视频数据。
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
【2月更文挑战第22天】本文介绍基于Python的tensorflow库,将tensorflow与keras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++等其他语言中将其打开的方法~
119 1
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
|
机器学习/深度学习 算法 Serverless
【李宏毅机器学习CP4】(task2)回归+Python Basics with Numpy
第一部分:回归栗子 ps:CP3的部分在上一篇笔记中【李宏毅机器学习】CP1-3笔记了。 1.问题描述 现在假设有10个x_data和y
167 0
【李宏毅机器学习CP4】(task2)回归+Python Basics with Numpy
|
并行计算 编译器 Linux
TVM 从入门到精通 | 安装 TVM (Part 1)
TVM 从入门到精通 | 安装 TVM (Part 1)
410 0
|
机器学习/深度学习 存储 并行计算
Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型
Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型
224 0
|
机器学习/深度学习 存储 PyTorch
一个快速构造GAN的教程:如何用pytorch构造DCGAN(上)
一个快速构造GAN的教程:如何用pytorch构造DCGAN
145 0
一个快速构造GAN的教程:如何用pytorch构造DCGAN(上)
|
机器学习/深度学习 存储 并行计算
一个快速构造GAN的教程:如何用pytorch构造DCGAN(下)
一个快速构造GAN的教程:如何用pytorch构造DCGAN
138 0
一个快速构造GAN的教程:如何用pytorch构造DCGAN(下)
|
Python
YOLOv5的Tricks | 【Trick13】YOLOv5的detect.py脚本的解析与简化
YOLOv5的Tricks | 【Trick13】YOLOv5的detect.py脚本的解析与简化
1387 0
YOLOv5的Tricks | 【Trick13】YOLOv5的detect.py脚本的解析与简化
|
存储 数据可视化 计算机视觉
目标检测的Tricks | 【Trick10】工具类文件调用(coco评价指标包、日志工具、Tensorboard工具...)
目标检测的Tricks | 【Trick10】工具类文件调用(coco评价指标包、日志工具、Tensorboard工具...)
669 0
目标检测的Tricks | 【Trick10】工具类文件调用(coco评价指标包、日志工具、Tensorboard工具...)
|
PyTorch 算法框架/工具
pytorch中meter.ClassErrorMeter()使用方法
pytorch中meter.ClassErrorMeter()使用方法
157 0