kaldi 源码分析(十) - gmm-init-mono.c分析

简介: 一直没有搞明白 hmm-gmm 之间是通过什么联系起来的,花了些时间查代码,看到最直观联系的就是 gmm-init-mono 工具。gmm-init-mono 基础类通过上述看到,主要的配置都是 在 topo 文件中, 这里需要将一些常...

一直没有搞明白 hmm-gmm 之间是通过什么联系起来的,花了些时间查代码,看到最直观联系的就是 gmm-init-mono 工具。


img_28424dbf260083d0390bb38a8f84c847.jpe
gmm-init-mono 基础类

通过上述看到,主要的配置都是 在 topo 文件中, 这里需要将一些常见的名称理解下来,这里直接贴出英文内容:

名称 解释
phone a phone index (1, 2, 3 ...)
HMM-state a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) (HmmTopology 中 HmmState 位置)
pdf-id a number output by the Compute function of ContextDependency (it indexes pdf's, either forward or self-loop). Zero-based. (HmmState 中 pdf 中的 index)
transition-state the states for which we estimate transition probabilities for transitions out of them. In some topologies, will map one-to-one with pdf-ids. One-based, since it appears on FSTs. (状态转换描述)
transition-index identifier of a transition (or final-prob) in the HMM. Indexes the "transitions" vector in HmmTopology::HmmState. (状态转换 index) [if it is out of range, equal to transitions.size(), it refers to the final-prob.] Zero-based.
transition-id identifier of a unique parameter of the TransitionModel. Associated with a (transition-state, transition-index) pair.One-based, since it appears on FSTs. (状态转换 id)

从 train_mono.sh 中获取 gmm-init-mono 命令详细内容

  $cmd JOB=1 $dir/log/init.log \
    gmm-init-mono $shared_phones_opt "--train-feats=$feats subset-feats --n=10 ark:- ark:-|" $lang/topo $feat_dim \
    $dir/0.mdl $dir/tree || exit 1;
# 实际执行的内容如下:
$ gmm-init-mono --shared-phones=$lang/phones/sets.int "--train-feats=ark,s,cs:apply-cmvn --norm-vars=true --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- | subset-feats --n=10 ark:- ark:-|" $lang/topo $feat_dim \
    $dir/0.mdl $dir/tree

从上述命令来看 --train-feats 指定了 gmm-init-mono 初始化使用的特征向量数据,其中通过 apply-cmvn 将 feats 进行归一化,然后通过 subset-feats 来取出 10 个特征向量作为参数
下面是具体代码的简单分析:

    // 读入一定量的特征,进行统计获取 gmm 模型的 means 及 variances 数据
    if (train_feats != "") {
      double count = 0.0;
      Vector<double> var_stats(dim);
      Vector<double> mean_stats(dim);
      SequentialDoubleMatrixReader feat_reader(train_feats);
      for (; !feat_reader.Done(); feat_reader.Next()) {
        const Matrix<double> &mat = feat_reader.Value();
        for (int32 i = 0; i < mat.NumRows(); i++) {
          count += 1.0;
          var_stats.AddVec2(1.0, mat.Row(i));
          mean_stats.AddVec(1.0, mat.Row(i));
        }
      }
      if (count == 0) { KALDI_ERR << "no features were seen."; }
      var_stats.Scale(1.0/count);
      // 计算均值
      mean_stats.Scale(1.0/count);
      var_stats.AddVec2(-1.0, mean_stats);
      if (var_stats.Min() <= 0.0)
        KALDI_ERR << "bad variance";
      var_stats.InvertElements();
      glob_inv_var.CopyFromVec(var_stats);
      glob_mean.CopyFromVec(mean_stats);
    }

    HmmTopology topo;
    bool binary_in;
    Input ki(topo_filename, &binary_in);
    topo.Read(ki.Stream(), binary_in);

    const std::vector<int32> &phones = topo.GetPhones();

    // 根据 topo 中的配置来获取每个 phone 音素 pdf 类数量
    std::vector<int32> phone2num_pdf_classes (1+phones.back());
    for (size_t i = 0; i < phones.size(); i++)
      phone2num_pdf_classes[phones[i]] = topo.NumPdfClasses(phones[i]);

    // 根据每个 phone 音素对应 pdf 数量来创建 ContextDependency (决策树)对象
    // Now the tree [not really a tree at this point]:
    ContextDependency *ctx_dep = NULL;
    if (shared_phones_rxfilename == "") {  // No sharing of phones: standard approach.
      ctx_dep = MonophoneContextDependency(phones, phone2num_pdf_classes);
    } else {
      std::vector<std::vector<int32> > shared_phones;
      ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones);
      // ReadSharedPhonesList crashes on error.
      ctx_dep = MonophoneContextDependencyShared(shared_phones, phone2num_pdf_classes);
    }

    // 获取所有 pdfs 数量 = phones * 每个 phone 含有的 pdfclass 数量
    int32 num_pdfs = ctx_dep->NumPdfs();

    // 根据特征统计出的结果,创建 DiagGmm 初始化模型
    AmDiagGmm am_gmm;
    DiagGmm gmm;
    gmm.Resize(1, dim);
    {  // Initialize the gmm.
      Matrix<BaseFloat> inv_var(1, dim);
      inv_var.Row(0).CopyFromVec(glob_inv_var);
      Matrix<BaseFloat> mu(1, dim);
      mu.Row(0).CopyFromVec(glob_mean);
      Vector<BaseFloat> weights(1);
      weights.Set(1.0);
      gmm.SetInvVarsAndMeans(inv_var, mu);
      gmm.SetWeights(weights);
      gmm.ComputeGconsts();
    }

    // 将每个 pdf 都初始化为上述创建的 gmm ,并与pdf对应起来
    for (int i = 0; i < num_pdfs; i++)
      am_gmm.AddPdf(gmm);

    // 添加 perturb_factor 因子
    if (perturb_factor != 0.0) {
      for (int i = 0; i < num_pdfs; i++)
        am_gmm.GetPdf(i).Perturb(perturb_factor);
    }

    // 将 ContextDependency 与 topo 合并为一个模型文件保存下来
    // Now the transition model:
    TransitionModel trans_model(*ctx_dep, topo);

    {
      Output ko(model_filename, binary);
      trans_model.Write(ko.Stream(), binary);
      am_gmm.Write(ko.Stream(), binary);
    }

    // 将ContextDependency存为决策树文件
    // Now write the tree.
    ctx_dep->Write(Output(tree_filename, binary).Stream(),
                   binary);
目录
相关文章
|
3月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图 REV1
Bert Pytorch 源码分析:五、模型架构简图 REV1
|
3月前
|
固态存储 Java Serverless
CV目标检测 Task03: 化劲儿-损失函数设计 打卡笔记
CV目标检测 Task03: 化劲儿-损失函数设计 打卡笔记
32 0
|
8月前
|
机器学习/深度学习 编解码 决策智能
计算机视觉实战(十一)Scale Invariant Feature Transform(SIFT)(附完整代码)
计算机视觉实战(十一)Scale Invariant Feature Transform(SIFT)(附完整代码)
|
9月前
|
存储 网络协议 数据安全/隐私保护
mjpg-streamer框架分析
mjpg-streamer框架分析
46 0
|
11月前
|
机器学习/深度学习 并行计算 算法
手把手教你使用LabVIEW OpenCV dnn实现物体识别(Object Detection)含源码
今天和大家一起分享如何使用LabVIEW调用pb模型实现物体识别
77 0
|
11月前
|
数据挖掘 计算机视觉
即插即用 | SA模块携Shuffle Attention带你CV全任务涨点(文末获取论文与源码)(二)
即插即用 | SA模块携Shuffle Attention带你CV全任务涨点(文末获取论文与源码)(二)
81 0
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
即插即用 | SA模块携Shuffle Attention带你CV全任务涨点(文末获取论文与源码)(一)
即插即用 | SA模块携Shuffle Attention带你CV全任务涨点(文末获取论文与源码)(一)
491 0
|
12月前
|
并行计算 编译器 Linux
TVM 从入门到精通 | 安装 TVM (Part 1)
TVM 从入门到精通 | 安装 TVM (Part 1)
314 0
|
12月前
|
机器学习/深度学习 存储 并行计算
Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型
Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型
168 0
|
PyTorch 算法框架/工具
pytorch中meter.ClassErrorMeter()使用方法
pytorch中meter.ClassErrorMeter()使用方法
140 0