Javascript类型推断(3) - 算法模型解析

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: # Javascript类型推断(3) - 算法模型解析 ## 构建训练模型 上一节我们介绍了生成训练集,测试集,验证集的方法,以及生成词表的方法。 这5个文件构成了训练的基本素材: ```python files = { 'train': { 'file': 'data/train.ctf', 'location': 0 }, 'valid': { 'file':

Javascript类型推断(3) - 算法模型解析

构建训练模型

上一节我们介绍了生成训练集,测试集,验证集的方法,以及生成词表的方法。
这5个文件构成了训练的基本素材:

files = {
    'train': { 'file': 'data/train.ctf', 'location': 0 },
    'valid': { 'file': 'data/valid.ctf', 'location': 0 },
    'test': { 'file': 'data/test.ctf', 'location': 0 },
    'source': { 'file': 'data/source_wl', 'location': 1 },
    'target': { 'file': 'data/target_wl', 'location': 1 }
}
AI 代码解读

词表我们需要转换一下格式,放到哈希表里:

# load dictionaries
source_wl = [line.rstrip('\n') for line in open(files['source']['file'])]
target_wl = [line.rstrip('\n') for line in open(files['target']['file'])]
source_dict = {source_wl[i]:i for i in range(len(source_wl))}
target_dict = {target_wl[i]:i for i in range(len(target_wl))}
AI 代码解读

下面是一些全局参数:

# number of words in vocab, slot labels, and intent labels
vocab_size = len(source_dict)
num_labels = len(target_dict)
epoch_size = 17.955*1000*1000
minibatch_size = 5000
emb_dim = 300
hidden_dim = 650
num_epochs = 10
AI 代码解读

下面我们定义x,y,t三个值,分别与输入词表、输出标签数和隐藏层有关

# Create the containers for input feature (x) and the label (y)
x = C.sequence.input_variable(vocab_size, name="x")
y = C.sequence.input_variable(num_labels, name="y")
t = C.sequence.input_variable(hidden_dim, name="t")
AI 代码解读

好,我们开始看下训练的流程:

model = create_model()
enc, dec = model(x, t)
trainer = create_trainer()
train()
AI 代码解读

训练模型

首先是一个词嵌入层:

def create_model():
    embed = C.layers.Embedding(emb_dim, name='embed')
AI 代码解读

然后是两个双向的循环神经网络(使用GRU),一个全连接网络,和一个dropout:

    encoder = BiRecurrence(C.layers.GRU(hidden_dim//2), C.layers.GRU(hidden_dim//2))
    recoder = BiRecurrence(C.layers.GRU(hidden_dim//2), C.layers.GRU(hidden_()dim//2))
    project = C.layers.Dense(num_labels, name='classify')
    do = C.layers.Dropout(0.5)
AI 代码解读

然后把上面的四项组合起来:

    def recode(x, t):
        inp = embed(x)
        inp = C.layers.LayerNormalization()(inp)
        
        enc = encoder(inp)
        rec = recoder(enc + t)
        proj = project(do(rec))
        
        dec = C.ops.softmax(proj)
        return enc, dec
    return recode
AI 代码解读

其中双向循环神经网络定义如下:

def BiRecurrence(fwd, bwd):
    F = C.layers.Recurrence(fwd)
    G = C.layers.Recurrence(bwd, go_backwards=True)
    x = C.placeholder()
    apply_x = C.splice(F(x), G(x))
    return apply_x
AI 代码解读

构建训练过程

首先定义下损失函数,由两部分组成,一部分是loss,另一部分是分类错误:

def criterion(model, labels):
    ce     = -C.reduce_sum(labels*C.ops.log(model))
    errs = C.classification_error(model, labels)
    return ce, errs
AI 代码解读

有了损失函数之后,我们使用带动量的Adam算法进行梯度下降训练:

def create_trainer():
    masked_dec = dec*C.ops.clip(C.ops.argmax(y), 0, 1)
    loss, label_error = criterion(masked_dec, y)
    loss *= C.ops.clip(C.ops.argmax(y), 0, 1)

    lr_schedule = C.learning_parameter_schedule_per_sample([1e-3]*2 + [5e-4]*2 + [1e-4], epoch_size=int(epoch_size))
    momentum_as_time_constant = C.momentum_as_time_constant_schedule(1000)
    learner = C.adam(parameters=dec.parameters,
                         lr=lr_schedule,
                         momentum=momentum_as_time_constant,
                         gradient_clipping_threshold_per_sample=15, 
                         gradient_clipping_with_truncation=True)

    progress_printer = C.logging.ProgressPrinter(tag='Training', num_epochs=num_epochs)
    trainer = C.Trainer(dec, (loss, label_error), learner, progress_printer)
    C.logging.log_number_of_parameters(dec)
    return trainer
AI 代码解读

训练

定义好模型之后,我们就可以训练了。
首先我们可以利用CNTK.io包的功能定义一个数据的读取器:

def create_reader(path, is_training):
    return C.io.MinibatchSource(C.io.CTFDeserializer(path, C.io.StreamDefs(
            source        = C.io.StreamDef(field='S0', shape=vocab_size, is_sparse=True), 
            slot_labels    = C.io.StreamDef(field='S1', shape=num_labels, is_sparse=True)
    )), randomize=is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)
AI 代码解读

然后我们就可以利用这个读取器读取数据开始训练了:

def train():
    train_reader = create_reader(files['train']['file'], is_training=True)
    step = 0
    pp = C.logging.ProgressPrinter(freq=10, tag='Training')
    for epoch in range(num_epochs):
        epoch_end = (epoch+1) * epoch_size
        while step < epoch_end:
            data = train_reader.next_minibatch(minibatch_size, input_map={
                x: train_reader.streams.source,
                y: train_reader.streams.slot_labels
            })
            # Enhance data
            enhance_data(data, enc)
            # Train model
            trainer.train_minibatch(data)
            pp.update_with_trainer(trainer, with_metric=True)
            step += data[y].num_samples
        pp.epoch_summary(with_metric=True)
        trainer.save_checkpoint("models/model-" + str(epoch + 1) + ".cntk")
        validate()
        evaluate()
AI 代码解读

上面的代码中,enhance_data需要解释一下。
我们的数据并非是完全线性的数据,还需要进行一个数据增强的处理过程:

def enhance_data(data, enc):
    guesses = enc.eval({x: data[x]})
    inputs = C.ops.argmax(x).eval({x: data[x]})
    tables = []
    for i in range(len(inputs)):
        ts = []
        table = {}
        counts = {}
        for j in range(len(inputs[i])):
            inp = int(inputs[i][j])
            if inp not in table:
                table[inp] = guesses[i][j]
                counts[inp] = 1
            else:
                table[inp] += guesses[i][j]
                counts[inp] += 1
        for inp in table:
            table[inp] /= counts[inp]
        for j in range(len(inputs[i])):
            inp = int(inputs[i][j])
            ts.append(table[inp])
        tables.append(np.array(np.float32(ts)))
    s = C.io.MinibatchSourceFromData(dict(t=(tables, C.layers.typing.Sequence[C.layers.typing.tensor])))
    mems = s.next_minibatch(minibatch_size)
    data[t] = mems[s.streams['t']]
AI 代码解读

测试和验证

测试和验证的过程中,也需要我们上面介绍的数据增强的过程:

def validate():
    valid_reader = create_reader(files['valid']['file'], is_training=False)
    while True:
        data = valid_reader.next_minibatch(minibatch_size, input_map={
                x: valid_reader.streams.source,
                y: valid_reader.streams.slot_labels
        })
        if not data:
            break
        enhance_data(data, enc)
        trainer.test_minibatch(data)
    trainer.summarize_test_progress()
AI 代码解读

evaluate与validate逻辑完全一样,只是读取的文件不同:

def evaluate():
    test_reader = create_reader(files['test']['file'], is_training=False)
    while True:
        data = test_reader.next_minibatch(minibatch_size, input_map={
            x: test_reader.streams.source,
            y: test_reader.streams.slot_labels
        })
        if not data:
            break
        # Enhance data
        enhance_data(data, enc)
        # Test model
        trainer.test_minibatch(data)
    trainer.summarize_test_progress()
AI 代码解读
目录
打赏
0
0
0
0
577
分享
相关文章
18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现
本文系统讲解从基本强化学习方法到高级技术(如PPO、A3C、PlaNet等)的实现原理与编码过程,旨在通过理论结合代码的方式,构建对强化学习算法的全面理解。
70 10
18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现
AI训练师入行指南(三):机器学习算法和模型架构选择
从淘金到雕琢,将原始数据炼成智能珠宝!本文带您走进数字珠宝工坊,用算法工具打磨数据金砂。从基础的经典算法到精密的深度学习模型,结合电商、医疗、金融等场景实战,手把手教您选择合适工具,打造价值连城的智能应用。掌握AutoML改装套件与模型蒸馏术,让复杂问题迎刃而解。握紧算法刻刀,为数字世界雕刻文明!
86 6
金融数据分析:解析JavaScript渲染的隐藏表格
本文详解了如何使用Python与Selenium结合代理IP技术,从金融网站(如东方财富网)抓取由JavaScript渲染的隐藏表格数据。内容涵盖环境搭建、代理配置、模拟用户行为、数据解析与分析等关键步骤。通过设置Cookie和User-Agent,突破反爬机制;借助Selenium等待页面渲染,精准定位动态数据。同时,提供了常见错误解决方案及延伸练习,帮助读者掌握金融数据采集的核心技能,为投资决策提供支持。注意规避动态加载、代理验证及元素定位等潜在陷阱,确保数据抓取高效稳定。
77 17
企业用网络监控软件中的 Node.js 深度优先搜索算法剖析
在数字化办公盛行的当下,企业对网络监控的需求呈显著增长态势。企业级网络监控软件作为维护网络安全、提高办公效率的关键工具,其重要性不言而喻。此类软件需要高效处理复杂的网络拓扑结构与海量网络数据,而算法与数据结构则构成了其核心支撑。本文将深入剖析深度优先搜索(DFS)算法在企业级网络监控软件中的应用,并通过 Node.js 代码示例进行详细阐释。
43 2
基于 Node.js 深度优先搜索算法的上网监管软件研究
在数字化时代,网络环境呈现出高度的复杂性与动态性,上网监管软件在维护网络秩序与安全方面的重要性与日俱增。此类软件依托各类数据结构与算法,实现对网络活动的精准监测与高效管理。本文将深度聚焦于深度优先搜索(DFS)算法,并结合 Node.js 编程语言,深入剖析其在上网监管软件中的应用机制与效能。
42 6
Javascript常见算法详解
本文介绍了几种常见的JavaScript算法,包括排序、搜索、递归和图算法。每种算法都提供了详细的代码示例和解释。通过理解这些算法,你可以在实际项目中有效地解决各种数据处理和分析问题。
79 21
基于 C# 的内网行为管理软件入侵检测算法解析
当下数字化办公环境中,内网行为管理软件已成为企业维护网络安全、提高办公效率的关键工具。它宛如一位恪尽职守的网络守护者,持续监控内网中的各类活动,以确保数据安全及网络稳定。在其诸多功能实现的背后,先进的数据结构与算法发挥着至关重要的作用。本文将深入探究一种应用于内网行为管理软件的 C# 算法 —— 基于二叉搜索树的入侵检测算法,并借助具体代码例程予以解析。
50 4
JavaScript 中通过Array.sort() 实现多字段排序、排序稳定性、随机排序洗牌算法、优化排序性能,JS中排序算法的使用详解(附实际应用代码)
Array.sort() 是一个功能强大的方法,通过自定义的比较函数,可以处理各种复杂的排序逻辑。无论是简单的数字排序,还是多字段、嵌套对象、分组排序等高级应用,Array.sort() 都能胜任。同时,通过性能优化技巧(如映射排序)和结合其他数组方法(如 reduce),Array.sort() 可以用来实现高效的数据处理逻辑。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
JS数组操作方法全景图,全网最全构建完整知识网络!js数组操作方法全集(实现筛选转换、随机排序洗牌算法、复杂数据处理统计等情景详解,附大量源码和易错点解析)
这些方法提供了对数组的全面操作,包括搜索、遍历、转换和聚合等。通过分为原地操作方法、非原地操作方法和其他方法便于您理解和记忆,并熟悉他们各自的使用方法与使用范围。详细的案例与进阶使用,方便您理解数组操作的底层原理。链式调用的几个案例,让您玩转数组操作。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
全网最全情景,深入浅出解析JavaScript数组去重:数值与引用类型的全面攻略
如果是基础类型数组,优先选择 Set。 对于引用类型数组,根据需求选择 Map 或 JSON.stringify()。 其余情况根据实际需求进行混合调用,就能更好的实现数组去重。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~

推荐镜像

更多
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等