Javascript类型推断(4) - 隐藏层的更新

简介: # Javascript类型推断(4) - 隐藏层的更新 熟悉了整个流程之后,我们可以关注更多的细节。 前面讲训练过程时,没有讲enhance_data的细节。这一部分的主要功能是更新隐藏层。它的调用点在: ```python def train(): train_reader = create_reader(files['train']['file'], is_trainin

Javascript类型推断(4) - 隐藏层的更新

熟悉了整个流程之后,我们可以关注更多的细节。

前面讲训练过程时,没有讲enhance_data的细节。这一部分的主要功能是更新隐藏层。它的调用点在:

def train():
    train_reader = create_reader(files['train']['file'], is_training=True)
    step = 0
    pp = C.logging.ProgressPrinter(freq=10, tag='Training')
    for epoch in range(num_epochs):
        trainer = create_trainer()
        epoch_end = (epoch+1) * epoch_size
        print('epoch_end=%d' % epoch_end)
        while step < epoch_end:
            data = train_reader.next_minibatch(minibatch_size, input_map={
                x: train_reader.streams.source,
                y: train_reader.streams.slot_labels
            })

            # Enhance data
            enhance_data(data, enc)
            # Train model
            trainer.train_minibatch(data)
            pp.update_with_trainer(trainer, with_metric=True)
            step += data[y].num_samples

在讲解enhance_data之前,我们先讲解一下train.txt生成的ctf的结构,以便大家了解在训练中用到的数据的真正含义。

ctf文件分析

前面我们专注于主线,现在我们重温一下txt2ctf的结果。

ctf是CNTK Text Format的缩写,是CNTK处理文本的格式。
ctf格式是这样子的:

0    |S0 15:1    |S1 0:1
0    |S0 3:1    |S1 0:1
0    |S0 25:1    |S1 1:1
0    |S0 2:1    |S1 0:1
0    |S0 0:1    |S1 0:1
0    |S0 1:1    |S1 0:1
0    |S0 3:1    |S1 0:1
0    |S0 16:1    |S1 4:1
0    |S0 2:1    |S1 0:1
0    |S0 17:1    |S1 0:1
0    |S0 1:1    |S1 0:1
0    |S0 18:1    |S1 6:1
0    |S0 6:1    |S1 0:1
0    |S0 19:1    |S1 3:1
0    |S0 4:1    |S1 0:1
0    |S0 16:1    |S1 4:1
0    |S0 5:1    |S1 0:1
0    |S0 1:1    |S1 0:1
0    |S0 20:1    |S1 0:1

我们先来看第一行:

0    |S0 15:1    |S1 0:1

行首的0是第1个文件,1就是第2个文件。我们只有2个文件,所以取值先为0,后面的是1。
S0是source_wl,15是第16行。
我们看下source_wl的内容:

0
;
=
let
(
)
.
{
}
Test
value
v
this
new
,
<s>
s
"s"
console
log
</s>
class
print
TestNumber
mul
a
public
:
number
constructor
extends
return
*
false
[
_UNKNOWN_

查表可知,对应的是。

而“|S1 0:1”,就是target_wl的第1行,这是个O。

我们再以第2行为例:

0    |S0 3:1    |S1 0:1

左边S0 3:1查source_wl对应的是let。
右边的S1 0:1,在target_wl中对应的是O。
let是个关键字,没有类型信息。

数据的读取

reader

有了上面的格式的基本知识,我们再看reader的代码就清晰了:

def create_reader(path, is_training):
    return C.io.MinibatchSource(C.io.CTFDeserializer(path, C.io.StreamDefs(
            source        = C.io.StreamDef(field='S0', shape=vocab_size, is_sparse=True), 
            slot_labels    = C.io.StreamDef(field='S1', shape=num_labels, is_sparse=True)
    )), randomize=is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)

source读的是S0,就是token信息。slot_labels读的是S1,就是类型信息。
shape形状分别是源字典和类型字典的大小:

vocab_size = len(source_dict)
num_labels = len(target_dict)

source_dict和target_dict读取文件的过程在:

files = {
    'train': { 'file': 'data/train.ctf', 'location': 0 },
    'valid': { 'file': 'data/valid.ctf', 'location': 0 },
    'test': { 'file': 'data/test.ctf', 'location': 0 },
    'source': { 'file': 'data/source_wl', 'location': 1 },
    'target': { 'file': 'data/target_wl', 'location': 1 }
}

source_wl = [line.rstrip('\n') for line in open(files['source']['file'])]
target_wl = [line.rstrip('\n') for line in open(files['target']['file'])]
source_dict = {source_wl[i]:i for i in range(len(source_wl))}
target_dict = {target_wl[i]:i for i in range(len(target_wl))}

inputs

到了inputs的时候,就已经没有S0, S1和0,1这样的标签了,而是变成数组:

    inputs = C.ops.argmax(x).eval({x: data[x]})

下面这两重循环,i循环对应到第一个ts工程,j循环是对应到每一个token:

    for i in range(len(inputs)):
        ts = []
        table = {}
        counts = {}
        for j in range(len(inputs[i])):

还以两个ts工程为例,inputs第一级是两个元素的数组,对应了ctf中第一列的0和1:

inputs[0]= [15.  3. 25.  2.  0.  1.  3. 16.  2. 17.  1.  3.  0.  2. 25. 35.  0.  1.
  3.  0.  2.  0. 35.  0.  1.  3.  0.  2.  0. 32.  0.  1. 18.  6. 19.  4.
 16.  5.  1. 20.]
inputs[1]= [15. 21.  9.  7. 26. 10. 27. 28.  1. 29.  4. 11.  5.  7. 12.  6. 10.  2.
 11.  1.  8. 22.  4.  5.  7. 18.  6. 19.  4. 12.  6. 10.  5.  1.  8. 35.
  4.  0.  5.  7. 31. 12.  6. 10. 35.  0.  1.  8.  8.  3. 35.  2. 13.  9.
  4.  0.  5.  1. 35.  6. 22.  4.  5.  1. 21. 35. 30.  9.  7. 35.  4.  0.
  5.  7. 31. 12.  6. 10. 35.  0.  1.  8.  8.  3.  0.  2. 13. 35.  4.  0.
  5.  1.  0.  6. 22.  4.  5.  1. 35. 35.  4.  0. 14.  0.  5.  7. 31.  0.
 35.  0.  1.  8. 35. 35.  7. 35. 27. 35.  1. 35. 27. 28.  1.  8. 35. 35.
  4. 35.  5.  7. 18.  6. 19.  4. 35.  6. 35.  5.  1. 18.  6. 19.  4. 35.
  6. 35.  5.  1.  8.  3.  0.  2. 33.  1.  3.  0.  2.  0.  1.  3.  0.  2.
 17.  1.  3.  0.  2. 34.  0. 14.  0. 14.  0. 14.  0. 35.  1. 20.]

比如inputs0是15,inputs0是3, inputs0是25,对应ctf中的值如下:

0    |S0 15:1    |S1 0:1
0    |S0 3:1    |S1 0:1
0    |S0 25:1    |S1 1:1

同样inputs[1]的前几位15. 21. 9. 7. 26.对应train.ctf中的

1    |S0 15:1    |S1 0:1
1    |S0 21:1    |S1 0:1
1    |S0 9:1    |S1 2:1
1    |S0 7:1    |S1 0:1
1    |S0 26:1    |S1 0:1

enhance_data - 更新隐藏层

下面我们就把上面的信息串联一下,看下enhance_data的实现:

def enhance_data(data, enc):
    guesses = enc.eval({x: data[x]})
    inputs = C.ops.argmax(x).eval({x: data[x]})
    tables = []
    for i in range(len(inputs)):
        ts = []
        table = {}
        counts = {}
        # 第一次遍历token
        for j in range(len(inputs[i])):
            inp = int(inputs[i][j])
            if inp not in table:
                table[inp] = guesses[i][j]
                counts[inp] = 1
            else:
                table[inp] += guesses[i][j]
                counts[inp] += 1
        for inp in table:
            table[inp] /= counts[inp]

为每一个token保存之前计算的guesses值。
然后将这些值汇总一下:

        # 第二次遍历token,生成每个ts文件的guess表
        for j in range(len(inputs[i])):
            inp = int(inputs[i][j])
            ts.append(table[inp])
        # 汇总成每个工程的总guess表
        tables.append(np.array(np.float32(ts)))

最后,我们更新下隐藏层的数据:

    s = C.io.MinibatchSourceFromData(dict(t=(tables, C.layers.typing.Sequence[C.layers.typing.tensor])))
    mems = s.next_minibatch(minibatch_size)
    print('mems=',mems)
    data[t] = mems[s.streams['t']]

其中,data[t]中的t,定义于:

t = C.sequence.input_variable(hidden_dim, name="t")

一些经验和技巧

修改epoch_size

我们运行lexer.py之后,会输出统计参数,例:

Train projects: 580
Validation projects: 72
Test projects: 73
Train files: 51431
Validation files: 4704
Test files: 3916
Producing vocabularies
Size of source vocab: 42631
Size of target vocab: 12386
Writing train/valid/test files
Overall tokens: 19286557 train, 1663181 valid and 1806291 test

我们可以据此修改infer.py中的参数:

epoch_size = 19286557

如果遇到node OOM问题怎么办?

适当增加一下max-old-space-size的大小,单位为M。
例:

node --max-old-space-size=16384 GetTypes.js
目录
相关文章
|
算法 JavaScript 前端开发
Javascript类型推断(3) - 算法模型解析
# Javascript类型推断(3) - 算法模型解析 ## 构建训练模型 上一节我们介绍了生成训练集,测试集,验证集的方法,以及生成词表的方法。 这5个文件构成了训练的基本素材: ```python files = { 'train': { 'file': 'data/train.ctf', 'location': 0 }, 'valid': { 'file':
1042 0
|
JavaScript 开发工具 git
Javascript类型推断(1) - 获取token和类型
Javascript类型推断(1) - 获取token和类型 ## js类型推断的三种思路 第一种思路是用传统的编译类的方法,推断是没啥好办法,但是可以用来验证。 第二种思路是利用对象的属性或方法的调用来推断,JSNice就是这样做的。 第三种思路比较先进,充分利用到越来越流行的Typescript,通过学习Typescript生成的javascript进行监督学习。这种思路是Vi
837 0
|
并行计算 异构计算 Python
Javascript类型推断(2) - 开始训练吧
# Javascript类型推断(2) - 开始训练吧 ## 准备训练数据 下面我们将上一节获取的类型数据信息进行预处理,转化为可以训练的数据。 代码在GetTypes.js中,会创建三个相关目录: ```ts let root = "data/Repos-cleaned"; let outputDirGold = "data/outputs-gold/"; let
657 0
|
1月前
|
JavaScript 前端开发
JavaScript中的原型 保姆级文章一文搞懂
本文详细解析了JavaScript中的原型概念,从构造函数、原型对象、`__proto__`属性、`constructor`属性到原型链,层层递进地解释了JavaScript如何通过原型实现继承机制。适合初学者深入理解JS面向对象编程的核心原理。
27 1
JavaScript中的原型 保姆级文章一文搞懂
|
5月前
|
JavaScript Java 测试技术
基于springboot+vue.js+uniapp的客户关系管理系统附带文章源码部署视频讲解等
基于springboot+vue.js+uniapp的客户关系管理系统附带文章源码部署视频讲解等
107 2
|
1月前
JS+CSS3文章内容背景黑白切换源码
JS+CSS3文章内容背景黑白切换源码是一款基于JS+CSS3制作的简单网页文章文字内容背景颜色黑白切换效果。
20 0
|
5月前
|
JavaScript Java 测试技术
基于springboot+vue.js+uniapp的小区物流配送系统附带文章源码部署视频讲解等
基于springboot+vue.js+uniapp的小区物流配送系统附带文章源码部署视频讲解等
155 4
|
5月前
|
JavaScript Java 测试技术
基于springboot+vue.js+uniapp的宠物援助平台附带文章源码部署视频讲解等
基于springboot+vue.js+uniapp的宠物援助平台附带文章源码部署视频讲解等
90 4
|
5月前
|
JavaScript Java 测试技术
基于springboot+vue.js+uniapp的宠物交易平台附带文章源码部署视频讲解等
基于springboot+vue.js+uniapp的宠物交易平台附带文章源码部署视频讲解等
82 4
|
5月前
|
JavaScript Java 测试技术
基于springboot+vue.js+uniapp的大学生入伍人员管理系统附带文章源码部署视频讲解等
基于springboot+vue.js+uniapp的大学生入伍人员管理系统附带文章源码部署视频讲解等
100 4