【阿里天池-医学影像报告异常检测】3 机器学习模型训练及集成学习Baseline开源

简介: 本文介绍了一个基于XGBoost、LightGBM和逻辑回归的集成学习模型,用于医学影像报告异常检测任务,并公开了达到0.83+准确率的基线代码。

引言

采用机器学习分类算法XGBClassifier、LGBMClassifier、LogisticRegression集成学习线上得到0.83+的准确率
开源源码:https://github.com/823316627bandeng/TIANCHI-2021-AI-Compition

模型实现

(1)导入包

import os
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, ClassifierMixin, MultiOutputClassifier
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from mlxtend.classifier import StackingClassifier
from utils import *
import os
import pickle
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

(2)准备数据

#加载数据
label= np.array(pd.read_csv('./data/label.csv'))
#train_sampel = pd.read_csv('./data/train_sample_500.csv')
train_sampel = pd.read_csv('./data/train_sample.csv')
test_sampel = pd.read_csv('./data/test_sample.csv')
#test_sampel = pd.read_csv('./data/test_sample_500.csv')
#数据归一化
stdScalar = StandardScaler()
train_df = stdScalar.fit_transform(np.float_(train_sampel))
test_df = stdScalar.fit_transform(np.float_(test_sampel))

losslist = []
nfold = 5
kf = MultilabelStratifiedKFold(n_splits=nfold, shuffle=True, random_state=2020)
lr_oof = np.zeros(label.shape)
# 存储测试集的概率
probility = np.zeros((len(test_df), label.shape[1]))
i = 0
model_type = 'ensemble'
# model_type ='single'
# K折交叉划分训练
for train_index, valid_index in kf.split(train_df, label):
    print("\nFold {}".format(i + 1))
    X_train, label_train = train_df[train_index], label[train_index]
    X_valid, label_valid = train_df[valid_index], label[valid_index]
    # 三个模型
    clf1 = OneVsRestClassifier(XGBClassifier(eval_metric= 'mlogloss',use_label_encoder=False,n_estimators=150))
    clf2 = LGBMClassifier()
    clf3 = LogisticRegression(max_iter =500, n_jobs=20)
    # 集成学习方法1
    if model_type == 'ensemble':
        model = OneVsRestClassifier(EnsembleVoteClassifier(clfs=[clf1, clf2, clf3],weights=[2, 1, 1], voting='soft', verbose=2))
    # 集成学习方法2
    elif model_type == 'stacking':
        lr = LogisticRegression()
        base = StackingClassifier(classifiers=[clf1, clf2, clf3],use_probas=True,average_probas=False, meta_classifier=lr,verbose=2)
        model = OneVsRestClassifier(base)
    else:
    # 单模型训练
        model = OneVsRestClassifier(XGBClassifier(eval_metric= 'mlogloss',use_label_encoder=False,n_estimators=150))

    model.fit(X_train, label_train)
    # 预测结果
    lr_oof[valid_index] = model.predict_proba(X_valid,)
    # 计算mlogloss
    loss = Mutilogloss(label_valid[:,:-1,], lr_oof[valid_index][:,:-1,])
    losslist.append(loss)
    # 多个flod预测结果叠加
    probility += model.predict_proba(test_df) / nfold
    i += 1
    print(losslist)

print(np.mean(losslist))
print()

# 保存存提交数据
submit_dir='submits/'
if not os.path.exists(submit_dir): os.makedirs(submit_dir)
str_w=''
with open(submit_dir+'machine_model_submit.csv','w') as f:
    for i in range(len(probility)):
        list_to_str = [str(x) for x in list(probility[i])][0:-1]
        liststr = " ".join(list_to_str)
        str_w+=str(i)+'|'+','+'|'+liststr+'\n'
    str_w=str_w.strip('\n')
    f.write(str_w)
print()
目录
相关文章
|
2天前
|
消息中间件 Java 数据库
新版 Seata 集成 RocketMQ事务消息,越来越 牛X 了!阿里的 Seata , yyds !
这里 借助 Seata 集成 RocketMQ 事务消息的 新功能,介绍一下一个新遇到的面试题:如果如何实现 **强弱一致性 结合**的分布式事务?
新版 Seata 集成 RocketMQ事务消息,越来越 牛X 了!阿里的 Seata , yyds !
|
7天前
|
数据采集 移动开发 数据可视化
模型预测笔记(一):数据清洗分析及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)
这篇文章介绍了数据清洗、分析、可视化、模型搭建、训练和预测的全过程,包括缺失值处理、异常值处理、特征选择、数据归一化等关键步骤,并展示了模型融合技术。
24 1
模型预测笔记(一):数据清洗分析及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)
|
15天前
|
测试技术
软件质量保护与测试(第2版)学习总结第十三章 集成测试
本文是《软件质量保护与测试》(第2版)第十三章的学习总结,介绍了集成测试的概念、主要任务、测试层次与原则,以及集成测试的不同策略,包括非渐增式集成和渐增式集成(自顶向下和自底向上),并通过图示详细解释了集成测试的过程。
36 1
软件质量保护与测试(第2版)学习总结第十三章 集成测试
|
2天前
|
存储 JSON Ubuntu
时序数据库 TDengine 支持集成开源的物联网平台 ThingsBoard
本文介绍了如何结合 Thingsboard 和 TDengine 实现设备管理和数据存储。Thingsboard 中的“设备配置”与 TDengine 中的超级表相对应,每个设备对应一个子表。通过创建设备配置和设备,实现数据的自动存储和管理。具体操作包括创建设备配置、添加设备、写入数据,并展示了车辆实时定位追踪和车队维护预警两个应用场景。
15 3
|
7天前
|
前端开发 Java 程序员
springboot 学习十五:Spring Boot 优雅的集成Swagger2、Knife4j
这篇文章是关于如何在Spring Boot项目中集成Swagger2和Knife4j来生成和美化API接口文档的详细教程。
21 1
|
15天前
|
人工智能 自然语言处理 关系型数据库
阿里云云原生数据仓库 AnalyticDB PostgreSQL 版已完成和开源LLMOps平台Dify官方集成
近日,阿里云云原生数据仓库 AnalyticDB PostgreSQL 版已完成和开源LLMOps平台Dify官方集成。
|
1月前
|
机器学习/深度学习 算法 数据挖掘
Python数据分析革命:Scikit-learn库,让机器学习模型训练与评估变得简单高效!
在数据驱动时代,Python 以强大的生态系统成为数据科学的首选语言,而 Scikit-learn 则因简洁的 API 和广泛的支持脱颖而出。本文将指导你使用 Scikit-learn 进行机器学习模型的训练与评估。首先通过 `pip install scikit-learn` 安装库,然后利用内置数据集进行数据准备,选择合适的模型(如逻辑回归),并通过交叉验证评估其性能。最终,使用模型对新数据进行预测,简化整个流程。无论你是新手还是专家,Scikit-learn 都能助你一臂之力。
96 8
|
7天前
|
Java Spring
springboot 学习十一:Spring Boot 优雅的集成 Lombok
这篇文章是关于如何在Spring Boot项目中集成Lombok,以简化JavaBean的编写,避免冗余代码,并提供了相关的配置步骤和常用注解的介绍。
38 0
|
8天前
|
机器学习/深度学习 算法 前端开发
集成学习任务七和八、投票法与bagging学习
集成学习任务七和八、投票法与bagging学习
8 0
|
9天前
|
JSON 测试技术 API
阿里云PAI-Stable Diffusion开源代码浅析之(二)我的png info怎么有乱码
阿里云PAI-Stable Diffusion开源代码浅析之(二)我的png info怎么有乱码