根据译文片段预测翻译作者

简介: 本教程的目的是带领大家学会,根据译文片段预测翻译作者本次用到的数据集是三个 txt 文本,分别是 cowper.txt、derby.txt、butler.txt ,该文本已经经过一些预处理,去除了表头,页眉等

本教程的目的是带领大家学会,根据译文片段预测翻译作者

本次用到的数据集是三个 txt 文本,分别是 cowper.txt、derby.txt、butler.txt ,该文本已经经过一些预处理,去除了表头,页眉等

接下来我们加载数据,这里我们使用 tf.data.TextLineDataset API,而不是之前使用的 text_dataset_from_directory,两者的区别是,前者加载 txt 文件里的每一行作为一个样本,后者是加载整个 txt 文件作为一个样本

DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']

for name in FILE_NAMES:
  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = pathlib.Path(text_dir).parent
list(parent_dir.iterdir())

def labeler(example, index):
  return example, tf.cast(index, tf.int64)

labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)

如上图所示,我们可以看到,txt 文件里的每一行确实是一个样本,其实上面的数据已经经过进一步处理了,变成 (example, label) pair 了

接下来我们需要对文本进行 standardize and tokenize,然后再使用 StaticVocabularyTable,建立 tokens 到 integers 的映射

这里我们使用 UnicodeScriptTokenizer 来 tokenize 数据集,代码如下所示

tokenizer = tf_text.UnicodeScriptTokenizer()

def tokenize(text, unused_label):
  lower_case = tf_text.case_fold_utf8(text)
  return tokenizer.tokenize(lower_case)

tokenized_ds = all_labeled_data.map(tokenize)

上图是 tokenize 的结果展示

下一步,我们需要建立 vocabulary,根据 tokens 的频率做一个排序,并取排名靠前的 VOCAB_SIZE 个元素

tokenized_ds = configure_dataset(tokenized_ds)

vocab_dict = collections.defaultdict(lambda: 0)
for toks in tokenized_ds.as_numpy_iterator():
  for tok in toks:
    vocab_dict[tok] += 1

vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCAB_SIZE]
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)
print("First five vocab entries:", vocab[:5])

接下来,我们需要用 vocab 创建 StaticVocabularyTable,因为 0 被保留用于表明 padding,1 被保留用于表明 OOV token,所以我们的实际 map tokens 的integer 是 [2, vocab_size+2],代码如下所示

keys = vocab
values = range(2, len(vocab) + 2)  # reserve 0 for padding, 1 for OOV

init = tf.lookup.KeyValueTensorInitializer(
    keys, values, key_dtype=tf.string, value_dtype=tf.int64)

num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)

最后我们要封装一个函数用于 standardize, tokenize and vectorize 数据集,通过 tokenizer and lookup table

def preprocess_text(text, label):
  standardized = tf_text.case_fold_utf8(text)
  tokenized = tokenizer.tokenize(standardized)
  vectorized = vocab_table.lookup(tokenized)
  return vectorized, label

上图是关于把 raw text 转化成 tokens 的展示结果

接下来,我们需要对数据集进行划分,然后再创建模型,最后就可以开始训练了,代码如下所示

all_encoded_data = all_labeled_data.map(preprocess_text)

train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
validation_data = all_encoded_data.take(VALIDATION_SIZE)

train_data = train_data.padded_batch(BATCH_SIZE)
validation_data = validation_data.padded_batch(BATCH_SIZE)

vocab_size += 2

train_data = configure_dataset(train_data)
validation_data = configure_dataset(validation_data)

model = create_model(vocab_size=vocab_size, num_labels=3)
model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
history = model.fit(train_data, validation_data=validation_data, epochs=3)

上图是训练的结果展示,在验证集上的准确率达到了 84.18%

inputs = [
    "Join'd to th' Ionians with their flowing robes,",  # Label: 1
    "the allies, and his armour flashed about him so that he seemed to all",  # Label: 2
    "And with loud clangor of his arms he fell.",  # Label: 0
]
predicted_scores = export_model.predict(inputs)
predicted_labels = tf.argmax(predicted_scores, axis=1)
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())

最后我们用训练后的模型进行预测,结果如下图所示

预测结果和实际标签都对应上了

代码地址: https://codechina.csdn.net/csdn_codechina/enterprise_technology/-/blob/master/predict_translations_author.ipynb

目录
相关文章
|
Cloud Native 网络协议 数据中心
Overlay网络与Underlay网络:深入探索与全面对比
在当今云原生的世界中🌍☁️,网络是构建和维护任何分布式系统的基石💎。了解Overlay网络和Underlay网络及其之间的区别🔍,对于设计高效、可扩展的云原生应用至关重要🚀。本文旨在全面解析Overlay和Underlay网络,揭示它们的工作原理、优缺点,并说明何种情况下应该使用哪一种网络📚。
Overlay网络与Underlay网络:深入探索与全面对比
|
12月前
|
数据采集 Web App开发 JavaScript
Selenium爬虫技术:如何模拟鼠标悬停抓取动态内容
本文介绍了如何使用Selenium爬虫技术抓取抖音评论,通过模拟鼠标悬停操作和结合代理IP、Cookie及User-Agent设置,有效应对动态内容加载和反爬机制。代码示例展示了具体实现步骤,帮助读者掌握这一实用技能。
522 0
Selenium爬虫技术:如何模拟鼠标悬停抓取动态内容
|
12月前
|
消息中间件 负载均衡 算法
聊聊 RocketMQ中 Topic,Queue,Consumer,Consumer Group的关系
本文详细解析了RocketMQ中Topic、Queue、Consumer及Consumer Group之间的关系。文中通过图表展示了Topic可包含多个Queue,Queue分布在不同Broker上;Consumer组内多个消费者共享消息;并深入探讨了集群消费与广播消费模式下Queue与Consumer的关系,以及Rebalancing机制在实例增减时如何确保负载均衡。理解这些关系有助于更好地掌握RocketMQ的工作原理,提升系统运维效率。
2540 2
|
7天前
|
存储 关系型数据库 分布式数据库
PostgreSQL 18 发布,快来 PolarDB 尝鲜!
PostgreSQL 18 发布,PolarDB for PostgreSQL 全面兼容。新版本支持异步I/O、UUIDv7、虚拟生成列、逻辑复制增强及OAuth认证,显著提升性能与安全。PolarDB-PG 18 支持存算分离架构,融合海量弹性存储与极致计算性能,搭配丰富插件生态,为企业提供高效、稳定、灵活的云数据库解决方案,助力企业数字化转型如虎添翼!
|
6天前
|
存储 人工智能 Java
AI 超级智能体全栈项目阶段二:Prompt 优化技巧与学术分析 AI 应用开发实现上下文联系多轮对话
本文讲解 Prompt 基本概念与 10 个优化技巧,结合学术分析 AI 应用的需求分析、设计方案,介绍 Spring AI 中 ChatClient 及 Advisors 的使用。
316 130
AI 超级智能体全栈项目阶段二:Prompt 优化技巧与学术分析 AI 应用开发实现上下文联系多轮对话
|
18天前
|
弹性计算 关系型数据库 微服务
基于 Docker 与 Kubernetes(K3s)的微服务:阿里云生产环境扩容实践
在微服务架构中,如何实现“稳定扩容”与“成本可控”是企业面临的核心挑战。本文结合 Python FastAPI 微服务实战,详解如何基于阿里云基础设施,利用 Docker 封装服务、K3s 实现容器编排,构建生产级微服务架构。内容涵盖容器构建、集群部署、自动扩缩容、可观测性等关键环节,适配阿里云资源特性与服务生态,助力企业打造低成本、高可靠、易扩展的微服务解决方案。
1330 8
|
5天前
|
监控 JavaScript Java
基于大模型技术的反欺诈知识问答系统
随着互联网与金融科技发展,网络欺诈频发,构建高效反欺诈平台成为迫切需求。本文基于Java、Vue.js、Spring Boot与MySQL技术,设计实现集欺诈识别、宣传教育、用户互动于一体的反欺诈系统,提升公众防范意识,助力企业合规与用户权益保护。