前篇得出初始矩阵、转移矩阵、发射矩阵
通过归一化得出每个状态的概率。
然后通过 pickle 将三个数组序列化到文件中,用的时候反序列化
# 训练数据 [ '今天 天气 真 不错 。', '麻辣肥牛 好吃 !', '我 喜欢 吃 好吃 的 !' ] # 标注 [ 'BE BE S BE S', 'BMME BE S', 'S BE S BE S S ' ] # 初始矩阵 [2, 0, 1, 0] # 转移矩阵 [ [0, 1, 0, 6], [0, 1, 0, 1], [3, 0, 1, 0], [2, 0, 5, 0] ] # 发射矩阵 { 'B': {'total': 7, '今': 1, '天': 1, '不': 1, '麻': 1, '好': 2, '喜': 1}, 'M': {'total': 2, '辣': 1, '肥': 1}, 'S': {'total': 7, '真': 1, '。': 1, '!': 2, '我': 1, '吃': 1, '的': 1}, 'E': {'total': 7, '天': 1, '气': 1, '错': 1, '牛': 1, '吃': 2, '欢': 1} }
import pickle from tqdm import tqdm import numpy as np import os # 定义 HMM类, 其实最关键的就是三大矩阵 class HMM: def __init__(self, file_text, file_state): # 初始矩阵 : 1 * 4 , 对应的是 BMSE, self.init_matrix = [2, 0, 1, 0, ] # 转移状态矩阵: 4 * 4 , self.transfer_matrix = [[0, 1, 0, 6], [0, 1, 0, 1], [3, 0, 1, 0], [2, 0, 5, 0]] # 发射矩阵 self.emit_matrix = { 'B': {'total': 7, '今': 1, '天': 1, '不': 1, '麻': 1, '好': 2, '喜': 1}, 'M': {'total': 2, '辣': 1, '肥': 1}, 'S': {'total': 7, '真': 1, '。': 1, '!': 2, '我': 1, '吃': 1, '的': 1}, 'E': {'total': 7, '天': 1, '气': 1, '错': 1, '牛': 1, '吃': 2, '欢': 1} } # 将矩阵归一化,得出概率 def normalize(self): self.init_matrix = self.init_matrix / np.sum(self.init_matrix) self.transfer_matrix = self.transfer_matrix / np.sum(self.transfer_matrix, axis=1, keepdims=True) self.emit_matrix = {state: {word: t / word_times["total"] * 1000 for word, t in word_times.items() if word != "total"} for state, word_times in self.emit_matrix.items()} # 训练开始, 其实就是3个矩阵的求解过程 def train(self): self.normalize() # 矩阵求完之后进行归一化 pickle.dump([self.init_matrix, self.transfer_matrix, self.emit_matrix], open("data/three_matrix.pkl", "wb")) # 保存参数 if __name__ == "__main__": train_file = "data/train_data.txt" state_file = "data/train_state.txt" hmm = HMM(train_file, state_file) hmm.train()
# 初始矩阵 -- 概率 [0.66666667, 0,0.33333333,0] # 转移矩阵 -- 概率 [[0. , 0.14285714, 0. , 0.85714286] [0. , 0.5 , 0. , 0.5 ] [0.75 , 0. , 0.25 , 0. ] [0.28571429, 0. , 0.71428571, 0. ]] # 转移矩阵 -- 概率 { 'B': {'今': 142.85714285714286, '天': 142.85714285714286, '不': 142.85714285714286, '麻': 142.85714285714286, '好': 285.7142857142857, '喜': 142.85714285714286}, 'M': {'辣': 500.0, '肥': 500.0}, 'S': {'真': 142.85714285714286, '。': 142.85714285714286, '!': 285.7142857142857, '我': 142.85714285714286, '吃': 142.85714285714286, '的': 142.85714285714286}, 'E': {'天': 142.85714285714286, '气': 142.85714285714286, '错': 142.85714285714286, '牛': 142.85714285714286, '吃': 285.7142857142857, '欢': 142.85714285714286} }