tflearn titanic

Introduction:

import glob

import numpy as np
import tflearn
from tflearn.datasets import titanic
from tflearn.data_utils import load_csv

# Download the dataset (does nothing if titanic_dataset.csv already exists)
titanic.download_dataset('titanic_dataset.csv')

# Column 0 ('survived') is the label; it is one-hot encoded into 2 classes
data, labels = load_csv('titanic_dataset.csv', target_column=0,
                        categorical_labels=True, n_classes=2)


def preprocess(passengers, columns_to_delete):
    # Delete columns from the highest index down so the lower indices stay valid
    for column_to_delete in sorted(columns_to_delete, reverse=True):
        for passenger in passengers:
            passenger.pop(column_to_delete)
    # After deletion, column 1 is 'sex': encode female as 1.0, male as 0.0
    for passenger in passengers:
        passenger[1] = 1. if passenger[1] == 'female' else 0.
    return np.array(passengers, dtype=np.float32)


# Drop 'name' (index 1) and 'ticket' (index 6), which are not numeric features
to_ignore = [1, 6]
data = preprocess(data, to_ignore)

# 6 input features -> two hidden layers of 32 units -> 2-class softmax
net = tflearn.input_data(shape=[None, 6])
net = tflearn.fully_connected(net, 32)  # default activation is 'linear'
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)

model = tflearn.DNN(net)
model_file = './model_titanic.tfl'

# Train and save on the first run; on later runs, load the saved checkpoint.
# TensorFlow writes the checkpoint as several files sharing this prefix.
if glob.glob(model_file + '*'):
    model.load(model_file=model_file)
else:
    model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
    model.save(model_file=model_file)

# Jack and Rose, in the CSV's column order without 'survived'
dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]
winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]
# Preprocess data
dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore)
print(dicaprio)
print(winslet)

# Predict surviving chances (class 1 = survived)
pred = model.predict([dicaprio, winslet])
print(pred)
print(pred.shape)
print("DiCaprio Surviving Rate:", pred[0][1])
print("Winslet Surviving Rate:", pred[1][1])
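To make the preprocessing and the network shape concrete without TensorFlow, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the tutorial's model: it repeats the `preprocess` logic on the two hand-written passengers, then pushes them through randomly initialized weights shaped like the 6 → 32 → 32 → 2 network. The helper names `forward` and `softmax`, the random seed, and the weights are hypothetical, so the resulting probabilities are not survival predictions. Because `tflearn.fully_connected` defaults to `activation='linear'`, the hidden layers add no non-linearity, and the sketch checks that the whole stack collapses to a single affine map followed by softmax.

```python
import numpy as np


def preprocess(passengers, columns_to_delete):
    # Same logic as the script's preprocess(), but copies rows so inputs are not mutated
    rows = [list(p) for p in passengers]
    for col in sorted(columns_to_delete, reverse=True):
        for row in rows:
            row.pop(col)
    for row in rows:
        row[1] = 1.0 if row[1] == 'female' else 0.0  # column 1 is now 'sex'
    return np.array(rows, dtype=np.float32)


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def forward(x, params):
    # Hypothetical stand-in for the tflearn net: hidden layers are linear (no activation)
    (w1, b1), (w2, b2), (w3, b3) = params
    h1 = x @ w1 + b1
    h2 = h1 @ w2 + b2
    return softmax(h2 @ w3 + b3)


# Random, untrained weights with the same layer sizes as the script (illustrative only)
rng = np.random.default_rng(0)
sizes = [6, 32, 32, 2]
params = [(rng.normal(scale=0.1, size=(m, n)), rng.normal(scale=0.1, size=n))
          for m, n in zip(sizes[:-1], sizes[1:])]

dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]
winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]
x = preprocess([dicaprio, winslet], [1, 6])
print(x)  # pclass, sex, age, sibsp, parch, fare for each passenger

probs = forward(x, params)
print(probs.shape)  # (2, 2): one row per passenger, column 1 = class "survived"

# With linear hidden layers, the three layers equal one affine map: x @ W + b
(w1, b1), (w2, b2), (w3, b3) = params
w = w1 @ w2 @ w3
b = (b1 @ w2 + b2) @ w3 + b3
collapsed = softmax(x @ w + b)
print(np.allclose(probs, collapsed))  # True
```

If the hidden layers are meant to add capacity beyond a linear classifier, an activation has to be passed explicitly, for example `tflearn.fully_connected(net, 32, activation='relu')`; the script above leaves them linear.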