1.包含数据(即文本描述)以及分类标签的CSV
df = pd.read_csv('./output/csv_sanitized_16_.csv', dtype=str)
X = df['description_plus']
y = df['category_id']
2.此CSV包含看不见的数据(即文本描述),需要对其预测标签
df_2 = pd.read_csv('./output/csv_sanitized_2.csv', dtype=str)
X2 = df_2['description_plus']
对以上训练数据(项目1)进行操作的交叉验证功能。
def cross_val():
cv = KFold(n_splits=20)
vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,
stop_words='english')
X_train = vectorizer.fit_transform(X)
clf = make_pipeline(preprocessing.StandardScaler(with_mean=False), svm.SVC(C=1))
scores = cross_val_score(clf, X_train, y, cv=cv)
print(scores)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
cross_val()
我需要知道如何将看不见的数据(项目2)传递给交叉验证功能,以及如何预测标签?
问题来源:stackoverflow
使用scores = cross_val_score(clf,X_train,y,cv = cv)
只能得到模型的交叉验证分数。cross_val_score将根据cv参数在内部将数据分为训练和测试。
因此,您获得的值是SVC的交叉验证精度。
要获得看不见的数据的分数,您可以首先拟合模型,例如
clf = make_pipeline(preprocessing.StandardScaler(with_mean=False), svm.SVC(C=1))
clf.fit(X_train, y) # the model is trained now
然后执行clf.score(X_unseen,y)
最后一个将在看不见的数据上返回模型的准确性。
* 编辑:做您想要的最好的方法是下面使用GridSearch 首先使用训练数据找到最佳模型,然后使用看不见的(测试)数据评估最佳模型: from sklearn import svm, datasets from sklearn.model_selection import GridSearchCV from sklearn.model_selection import train_test_split from sklearn.model_selection import cross_val_score
# load some data
iris = datasets.load_iris()
X, y = iris.data, iris.target
#split data to training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
# hyperparameter tunig of the SVC model
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
svc = svm.SVC()
# fit the GridSearch using the TRAINING data
grid_searcher = GridSearchCV(svc, parameters)
grid_searcher.fit(X_train, y_train)
#recover the best estimator (best parameters for the SVC, based on the GridSearch)
best_SVC_model = grid_searcher.best_estimator_
# Now, check how this best model behaves on the test set
cv_scores_on_unseen = cross_val_score(best_SVC_model, X_test, y_test, cv=5)
print(cv_scores_on_unseen.mean())
回答来源:stackoverflow
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。