作用
分层的K折交叉验证器。
提供训练/测试索引以将数据拆分为训练/测试集。
此交叉验证对象是KFold的变体,它返回分层的折痕。折叠是通过保留每个类别的样品百分比来进行的。
参数解析
n_splits int,默认= 5
折数。必须至少为2。在0.22版中更改:n_splits默认值从3更改为5。shuffle bool,默认= False
在拆分成批次之前是否对每个班级的样本进行混洗。请注意,每个拆分内的样本都不会被混洗。random_state int,RandomState实例或无,默认=无
当shuffle为True时,random_state会影响索引的顺序,从而控制每个类别的每个折叠的随机性。否则,保留random_state为None。为多个函数调用传递可重复输出的int值
举例使用
import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 1, 1])
skf = StratifiedKFold(n_splits=2)
skf.get_n_splits(X, y)
print(skf)
StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
for train_index, test_index in skf.split(X, y):
print("TRAIN:", train_index, "TEST:", test_index)