张芷铭的个人博客

K 折交叉验证的步骤

#机器学习 #训练

K 折交叉验证(K-Fold Cross Validation) 是一种常用的模型评估方法,特别适用于数据量较小的情况。它可以充分利用数据,提高模型的泛化能力,并减少因数据划分不同导致的模型性能波动。


K 折交叉验证的步骤

  1. 将数据集随机拆分成 K 份(folds)
  2. 进行 K 次训练和验证
    1. 每次选择 1 份作为验证集(Eval Set),其余 K-1 份作为训练集(Train Set)
    2. 训练模型后,在验证集上评估性能。
  3. 计算 K 次的评估指标均值 作为最终性能指标。

示例(K=5 时的划分方式)

数据集:  [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ]
折数K = 5:
第1次:训练集 [ 3, 4, 5, 6, 7, 8, 9, 10 ],验证集 [ 1, 2 ]
第2次:训练集 [ 1, 2, 5, 6, 7, 8, 9, 10 ],验证集 [ 3, 4 ]
第3次:训练集 [ 1, 2, 3, 4, 7, 8, 9, 10 ],验证集 [ 5, 6 ]
第4次:训练集 [ 1, 2, 3, 4, 5, 6, 9, 10 ],验证集 [ 7, 8 ]
第5次:训练集 [ 1, 2, 3, 4, 5, 6, 7, 8 ],验证集 [ 9, 10 ]

最终结果:计算 5 次训练的平均性能(如准确率、F1-score)。


2. 为什么使用 K 折交叉验证?

方法训练集验证集评估稳定性
简单划分(Hold-Out)80%20%可能受数据划分影响
K 折交叉验证K-1 份1 份(K 次)更稳定,减少数据划分影响

适用于小数据集:能充分利用数据,提高评估稳定性。

避免过拟合/欠拟合:防止模型依赖特定的数据划分。


代码实现

使用 sklearn.model_selection.KFold

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
from sklearn.model_selection import KFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# 加载数据
X, y = load_iris(return_X_y=True)

# 定义K折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# 选择模型
model = RandomForestClassifier()

# 计算交叉验证分数
scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')

# 输出结果
print("每折的准确率:", scores)
print("平均准确率:", scores.mean())

参数解释

• n_splits=5:5 折交叉验证。

• shuffle=True:打乱数据,提高随机性。

• cross_val_score():自动进行 K 次训练和评估。


K 值的选择

K 值适用情况优缺点
K=5 或 K=10常用,适用于大部分情况训练开销适中,评估稳定
K=N([[留一交叉验证(LOOCV)]],LOOCV)超小数据集(如医学)计算量大,评估最稳定
K 过大(如 K=20, 30)数据量较大但不均匀计算成本高
K 过小(如 K=2, 3)数据少,训练开销小评估不稳定

推荐

• 一般选择 K=5 或 K=10,平衡计算成本和评估稳定性。

• 若数据量特别小,可以考虑 K=N(LOOCV)


不同类型的交叉验证

分层 K 折交叉验证(Stratified K-Fold)

• 适用于 类别不均衡的分类任务(例如少数类别占比很低)。

确保每一折的类别分布与整体数据集相似

示例

1
2
3
4
5
6
7
from sklearn.model_selection import StratifiedKFold

# 分层K折交叉验证
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=skf, scoring='accuracy')

print("平均准确率:", scores.mean())

时间序列 K 折交叉验证(Time Series Split)

• 适用于 时间序列数据(如股票预测、天气预报)。

• 确保训练集的时间点 始终早于 验证集。

示例

1
2
3
4
5
6
7
from sklearn.model_selection import TimeSeriesSplit

# 时间序列 K 折交叉验证
tscv = TimeSeriesSplit(n_splits=5)
scores = cross_val_score(model, X, y, cv=tscv, scoring='accuracy')

print("时间序列交叉验证得分:", scores)

K 折交叉验证 vs 超参数调优

K 折交叉验证可以结合 超参数搜索,如 网格搜索(Grid Search)贝叶斯优化

示例:K 折 + 网格搜索

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
from sklearn.model_selection import GridSearchCV

# 定义超参数搜索范围
param_grid = {'n_estimators': [10, 50, 100], 'max_depth': [3, 5, 10]}

# 交叉验证+网格搜索
grid_search = GridSearchCV(model, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
grid_search.fit(X, y)

# 输出最优超参数
print("最佳参数:", grid_search.best_params_)
print("最佳分数:", grid_search.best_score_)

总结

方法适用场景计算开销适用数据类型
简单划分(Hold-Out)大数据集,简单任务分类、回归
K 折交叉验证(K-Fold)适中数据量,标准任务中等分类、回归
分层 K 折交叉验证(Stratified K-Fold)类别不均衡的分类问题中等分类
时间序列 K 折(Time Series Split)时间序列预测时间序列

选择建议

• 若数据量 充足,用 K=5 或 K=10 进行 K 折交叉验证。

• 若数据类别 不均衡,用 Stratified K-Fold

• 若数据是 时间序列,用 Time Series Split

• 若数据量 极小,用 K=N(LOOCV),但计算量大。


总结一句话:K 折交叉验证是一种强大的模型评估方法,适用于数据量较小的情况,可以提高模型泛化能力,减少数据划分的影响!

💬 评论