train_test_split.py代码解释

简介: 这段代码用于将MovieLens 1M数据集的评分数据划分为训练集和测试集。• 首先,使用Path库获取当前文件的父级目录,也就是项目根目录。• 接着,定义输出训练集和测试集文件的路径。
import os
from pathlib import Path
p = Path(__file__).parents[1]
OUTPUT_DIR_TRAIN=os.path.abspath(os.path.join(p, '..', 'data/raw/ml-1m/train.dat'))
OUTPUT_DIR_TEST=os.path.abspath(os.path.join(p, '..', 'data/raw/ml-1m/test.dat'))
ROOT_DIR=os.path.abspath(os.path.join(p, '..', 'data/raw/ml-1m/ratings.dat'))
NUM_USERS=6040
NUM_TEST_RATINGS=10
def count_rating_per_user():
    rating_per_user={}
    for line in open(ROOT_DIR):
        line=line.split('::')
        user_nr=int(line[0])
        if user_nr in rating_per_user:
            rating_per_user[user_nr]+=1                       
        else:
            rating_per_user[user_nr]=1
    return rating_per_user
def train_test_split():
    user_rating=count_rating_per_user()
    test_counter=0
    next_user=1
    train_writer=open(OUTPUT_DIR_TRAIN, 'w')
    test_writer=open(OUTPUT_DIR_TEST, 'w')
    for line in open(ROOT_DIR):
        splitted_line=line.split('::')
        user_nr=int(splitted_line[0])
        if user_rating[user_nr]<=NUM_TEST_RATINGS*2:
            next_user+=1
            continue
        try:
            if user_nr==next_user:
                write_test_samples=True
                next_user+=1
            if write_test_samples==True:
                test_writer.write(line)
                test_counter+=1
                if test_counter>=NUM_TEST_RATINGS:
                    test_counter=0
                    write_test_samples=False        
            else:
                train_writer.write(line)
        except KeyError:   
            print('Key not found')
            continue
if __name__ == "__main__":
    train_test_split()

这段代码用于将MovieLens 1M数据集的评分数据划分为训练集和测试集。

  • 首先,使用Path库获取当前文件的父级目录,也就是项目根目录。
  • 接着,定义输出训练集和测试集文件的路径。
  • 然后,定义数据集中用户的数量和每个用户的测试评分数量。
  • 下一步是定义一个函数count_rating_per_user(),该函数将读取ratings.dat文件的每一行,拆分为user_id、movie_id、rating和timestamp,并计算每个用户的评分数。该函数返回一个字典,其中每个键表示用户ID,每个值表示该用户的评分数。
  • train_test_split()函数使用count_rating_per_user()函数获取每个用户的评分数量。该函数读取文件中的每一行,拆分为user_id、movie_id、rating和timestamp,并通过将其user_id与每个用户的评分数量进行比较来确定该行是否写入测试集。如果该行对应的user_id在当前的用户评分数量计数器中具有的评分数小于或等于 NUM_TEST_RATINGS * 2,则该行被跳过。否则,该行将被写入测试集中,并且当前用户的评分数量计数器将递增。如果计数器的值达到 NUM_TEST_RATINGS,则将write_test_samples标记设置为False,以停止将数据写入测试文件,并开始将数据写入训练文件。


这段代码是为了将ml-1m数据集的评分数据ratings.dat划分为训练集和测试集,并将它们保存为两个文件:train.dattest.dat。具体来说,它完成以下几个任务:

  1. 计算每个用户给了多少次评分,以便在划分数据集时可以控制每个用户的评分数。
  2. 按照每个用户的评分数来划分训练集和测试集,每个用户的前 NUM_TEST_RATINGS 个评分划分为测试集,其余评分划分为训练集。
  3. 将划分后的训练集和测试集保存为 train.dattest.dat 两个文件。

在这段代码中,首先使用了pathlibos两个库,用于处理文件路径。接下来,使用count_rating_per_user()函数计算每个用户的评分数。这个函数使用了ratings.dat文件,并将其按行读入。每行的第一项是用户编号,所以可以统计每个用户给出了多少次评分。

接下来,train_test_split()函数读取 ratings.dat 文件,并且按照之前的规则将数据划分为训练集和测试集。对于每个用户,如果其评分数小于等于 NUM_TEST_RATINGS 的两倍,则将所有评分都划分到训练集。否则,将其前 NUM_TEST_RATINGS 个评分划分到测试集,其余评分划分到训练集。通过一个 test_counter 变量来跟踪每个用户的测试集评分数是否达到了 NUM_TEST_RATINGS 的限制,从而保证测试集的大小。在写入数据时,使用两个文件句柄 train_writertest_writer,分别写入训练集和测试集。

最后,在 main 函数中调用 train_test_split() 函数,生成两个文件 train.dattest.dat,它们位于 ml-1m 数据集的 raw 文件夹中。

相关文章
|
机器学习/深度学习 测试技术 TensorFlow
dataset.py代码解释
这段代码主要定义了三个函数来创建 TensorFlow 数据集对象,这些数据集对象将被用于训练、评估和推断神经网络模型。
128 0
|
6月前
|
机器学习/深度学习 索引
yolov5--loss.py --v5.0版本-最新代码详细解释-2021-7-1更新
yolov5--loss.py --v5.0版本-最新代码详细解释-2021-7-1更新
298 0
|
6月前
yolov5--train.py --v5.0版本-2021-7-6更新
yolov5--train.py --v5.0版本-2021-7-6更新
52 0
|
6月前
yolov5--datasets.py --v5.0版本-数据集加载 最新代码详细解释2021-7-5更新
yolov5--datasets.py --v5.0版本-数据集加载 最新代码详细解释2021-7-5更新
277 0
|
机器学习/深度学习 JSON 数据格式
YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py
YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py
1722 1
YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py
|
机器学习/深度学习 搜索推荐 TensorFlow
inference.py的代码解释
这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。
466 0
python中line.split()的用法及实际使用示例
python中line.split()的用法及实际使用示例
|
机器学习/深度学习 Python
python中print参数sep和end 输出中的奥秘!
python中print参数sep和end 输出中的奥秘!
140 0
|
Python
PASCAL VOC数据集训练集、验证集、测试集的划分和提取,得到test.txt、train.txt、trainval.txt、val.txt文件代码
PASCAL VOC数据集训练集、验证集、测试集的划分和提取,得到test.txt、train.txt、trainval.txt、val.txt文件代码
445 0
|
存储 搜索推荐 Java
preprocess_data.py代码解释
循环遍历每个用户,对于每个用户,提取其对电影的评分。 创建一个与所有电影数量相同的评分数组,将相应的评分放置在数组的正确位置。 如果该用户没有评分电影,则跳过该用户。 返回所有用户的评分数组列表。
234 0