Javascript类型推断(4) - 隐藏层的更新
熟悉了整个流程之后,我们可以关注更多的细节。
前面讲训练过程时,没有讲enhance_data的细节。这一部分的主要功能是更新隐藏层。它的调用点在:
def train():
train_reader = create_reader(files['train']['file'], is_training=True)
step = 0
pp = C.logging.ProgressPrinter(freq=10, tag='Training')
for epoch in range(num_epochs):
trainer = create_trainer()
epoch_end = (epoch+1) * epoch_size
print('epoch_end=%d' % epoch_end)
while step < epoch_end:
data = train_reader.next_minibatch(minibatch_size, input_map={
x: train_reader.streams.source,
y: train_reader.streams.slot_labels
})
# Enhance data
enhance_data(data, enc)
# Train model
trainer.train_minibatch(data)
pp.update_with_trainer(trainer, with_metric=True)
step += data[y].num_samples
在讲解enhance_data之前,我们先讲解一下train.txt生成的ctf的结构,以便大家了解在训练中用到的数据的真正含义。
ctf文件分析
前面我们专注于主线,现在我们重温一下txt2ctf的结果。
ctf是CNTK Text Format的缩写,是CNTK处理文本的格式。
ctf格式是这样子的:
0 |S0 15:1 |S1 0:1
0 |S0 3:1 |S1 0:1
0 |S0 25:1 |S1 1:1
0 |S0 2:1 |S1 0:1
0 |S0 0:1 |S1 0:1
0 |S0 1:1 |S1 0:1
0 |S0 3:1 |S1 0:1
0 |S0 16:1 |S1 4:1
0 |S0 2:1 |S1 0:1
0 |S0 17:1 |S1 0:1
0 |S0 1:1 |S1 0:1
0 |S0 18:1 |S1 6:1
0 |S0 6:1 |S1 0:1
0 |S0 19:1 |S1 3:1
0 |S0 4:1 |S1 0:1
0 |S0 16:1 |S1 4:1
0 |S0 5:1 |S1 0:1
0 |S0 1:1 |S1 0:1
0 |S0 20:1 |S1 0:1
我们先来看第一行:
0 |S0 15:1 |S1 0:1
行首的0是第1个文件,1就是第2个文件。我们只有2个文件,所以取值先为0,后面的是1。
S0是source_wl,15是第16行。
我们看下source_wl的内容:
0
;
=
let
(
)
.
{
}
Test
value
v
this
new
,
<s>
s
"s"
console
log
</s>
class
print
TestNumber
mul
a
public
:
number
constructor
extends
return
*
false
[
_UNKNOWN_
查表可知,对应的是。
而“|S1 0:1”,就是target_wl的第1行,这是个O。
我们再以第2行为例:
0 |S0 3:1 |S1 0:1
左边S0 3:1查source_wl对应的是let。
右边的S1 0:1,在target_wl中对应的是O。
let是个关键字,没有类型信息。
数据的读取
reader
有了上面的格式的基本知识,我们再看reader的代码就清晰了:
def create_reader(path, is_training):
return C.io.MinibatchSource(C.io.CTFDeserializer(path, C.io.StreamDefs(
source = C.io.StreamDef(field='S0', shape=vocab_size, is_sparse=True),
slot_labels = C.io.StreamDef(field='S1', shape=num_labels, is_sparse=True)
)), randomize=is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)
source读的是S0,就是token信息。slot_labels读的是S1,就是类型信息。
shape形状分别是源字典和类型字典的大小:
vocab_size = len(source_dict)
num_labels = len(target_dict)
source_dict和target_dict读取文件的过程在:
files = {
'train': { 'file': 'data/train.ctf', 'location': 0 },
'valid': { 'file': 'data/valid.ctf', 'location': 0 },
'test': { 'file': 'data/test.ctf', 'location': 0 },
'source': { 'file': 'data/source_wl', 'location': 1 },
'target': { 'file': 'data/target_wl', 'location': 1 }
}
source_wl = [line.rstrip('\n') for line in open(files['source']['file'])]
target_wl = [line.rstrip('\n') for line in open(files['target']['file'])]
source_dict = {source_wl[i]:i for i in range(len(source_wl))}
target_dict = {target_wl[i]:i for i in range(len(target_wl))}
inputs
到了inputs的时候,就已经没有S0, S1和0,1这样的标签了,而是变成数组:
inputs = C.ops.argmax(x).eval({x: data[x]})
下面这两重循环,i循环对应到第一个ts工程,j循环是对应到每一个token:
for i in range(len(inputs)):
ts = []
table = {}
counts = {}
for j in range(len(inputs[i])):
还以两个ts工程为例,inputs第一级是两个元素的数组,对应了ctf中第一列的0和1:
inputs[0]= [15. 3. 25. 2. 0. 1. 3. 16. 2. 17. 1. 3. 0. 2. 25. 35. 0. 1.
3. 0. 2. 0. 35. 0. 1. 3. 0. 2. 0. 32. 0. 1. 18. 6. 19. 4.
16. 5. 1. 20.]
inputs[1]= [15. 21. 9. 7. 26. 10. 27. 28. 1. 29. 4. 11. 5. 7. 12. 6. 10. 2.
11. 1. 8. 22. 4. 5. 7. 18. 6. 19. 4. 12. 6. 10. 5. 1. 8. 35.
4. 0. 5. 7. 31. 12. 6. 10. 35. 0. 1. 8. 8. 3. 35. 2. 13. 9.
4. 0. 5. 1. 35. 6. 22. 4. 5. 1. 21. 35. 30. 9. 7. 35. 4. 0.
5. 7. 31. 12. 6. 10. 35. 0. 1. 8. 8. 3. 0. 2. 13. 35. 4. 0.
5. 1. 0. 6. 22. 4. 5. 1. 35. 35. 4. 0. 14. 0. 5. 7. 31. 0.
35. 0. 1. 8. 35. 35. 7. 35. 27. 35. 1. 35. 27. 28. 1. 8. 35. 35.
4. 35. 5. 7. 18. 6. 19. 4. 35. 6. 35. 5. 1. 18. 6. 19. 4. 35.
6. 35. 5. 1. 8. 3. 0. 2. 33. 1. 3. 0. 2. 0. 1. 3. 0. 2.
17. 1. 3. 0. 2. 34. 0. 14. 0. 14. 0. 14. 0. 35. 1. 20.]
比如inputs0是15,inputs0是3, inputs0是25,对应ctf中的值如下:
0 |S0 15:1 |S1 0:1
0 |S0 3:1 |S1 0:1
0 |S0 25:1 |S1 1:1
同样inputs[1]的前几位15. 21. 9. 7. 26.对应train.ctf中的
1 |S0 15:1 |S1 0:1
1 |S0 21:1 |S1 0:1
1 |S0 9:1 |S1 2:1
1 |S0 7:1 |S1 0:1
1 |S0 26:1 |S1 0:1
enhance_data - 更新隐藏层
下面我们就把上面的信息串联一下,看下enhance_data的实现:
def enhance_data(data, enc):
guesses = enc.eval({x: data[x]})
inputs = C.ops.argmax(x).eval({x: data[x]})
tables = []
for i in range(len(inputs)):
ts = []
table = {}
counts = {}
# 第一次遍历token
for j in range(len(inputs[i])):
inp = int(inputs[i][j])
if inp not in table:
table[inp] = guesses[i][j]
counts[inp] = 1
else:
table[inp] += guesses[i][j]
counts[inp] += 1
for inp in table:
table[inp] /= counts[inp]
为每一个token保存之前计算的guesses值。
然后将这些值汇总一下:
# 第二次遍历token,生成每个ts文件的guess表
for j in range(len(inputs[i])):
inp = int(inputs[i][j])
ts.append(table[inp])
# 汇总成每个工程的总guess表
tables.append(np.array(np.float32(ts)))
最后,我们更新下隐藏层的数据:
s = C.io.MinibatchSourceFromData(dict(t=(tables, C.layers.typing.Sequence[C.layers.typing.tensor])))
mems = s.next_minibatch(minibatch_size)
print('mems=',mems)
data[t] = mems[s.streams['t']]
其中,data[t]中的t,定义于:
t = C.sequence.input_variable(hidden_dim, name="t")
一些经验和技巧
修改epoch_size
我们运行lexer.py之后,会输出统计参数,例:
Train projects: 580
Validation projects: 72
Test projects: 73
Train files: 51431
Validation files: 4704
Test files: 3916
Producing vocabularies
Size of source vocab: 42631
Size of target vocab: 12386
Writing train/valid/test files
Overall tokens: 19286557 train, 1663181 valid and 1806291 test
我们可以据此修改infer.py中的参数:
epoch_size = 19286557
如果遇到node OOM问题怎么办?
适当增加一下max-old-space-size的大小,单位为M。
例:
node --max-old-space-size=16384 GetTypes.js