训练误差与泛化误差的说明

简介: 训练误差与泛化误差的说明

1. 训练误差与泛化误差的定义

训练误差和泛化误差是评估机器学习模型性能的两个重要概念。

1.1 训练误差(Training Error)

训练误差是指模型在训练数据集上表现的误差,也就是模型在学习过程中对已知样本的预测错误程度。它是通过计算模型在训练集上的损失函数(如均方误差、交叉熵等)来度量的。训练误差越低,通常表示模型在学习过程中对训练数据的拟合越好。

1.2 泛化误差(Generalization Error):

泛化误差则是指模型在未见过的新数据样本上的预期误差,也就是模型对未知数据的预测能力。它反映了模型在实际应用中对新输入的处理能力。因为我们无法直接计算泛化误差,因为未来的数据是未知的,所以通常通过在独立的测试集(与训练集不同的数据集)上评估模型的性能来估计泛化误差。

训练误差和泛化误差的关系:

  • 在训练过程中,我们通常希望尽可能降低训练误差,但是过于关注降低训练误差可能会导致模型过度拟合(Overfitting),即模型过于复杂,对训练数据的学习过于精细,以至于对训练数据以外的新数据表现不佳,即泛化误差较高。
  • 相反,如果模型过于简单,可能无法充分捕捉训练数据的特性,导致训练误差和泛化误差都较高,这种情况称为欠拟合(Underfitting)。
  • 为了获得良好的泛化性能,我们需要找到一个平衡点,使得模型既能很好地拟合训练数据,又能对未见过的数据有较好的预测能力。这通常通过使用正则化技术(如L2正则化、Dropout等)、早期停止(Early Stopping)、增加数据集大小或使用更复杂的模型结构等方式来实现。

讲定义太无聊了,下面直接通过实例说明。

2. 通过实例说明训练误差与泛化误差

假设我们要构建一个网络模型来拟合下面的训练数据:

- 输入:trianset_x = [1, 2, 3, 4, 5]

- 输出:trainset_y = [1, 4, 9, 16, 25]

然后我们构建了一个网络模型A进行训练,经过训练后测试,输入trainset_x输出为:

- A_output = [1.1, 4.3, 8.9, 15.6, 26.1]

这里网络模型A训练误差即为loss(trainset_y, A_output), 其中loss为损失函数。

这里,我们可以看到网络模型A的效果并不太好,它的训练误差还是比较大的。于是我们又构建一个网络模型B

网络模型B训练后输入trainset_x输出为:

- B_output = [1.00001, 4.00001, 9.00001, 16.00001, 25.00001]

可见,网络模型B的训练误差已经远远小于网络模型A。

那么网络模型B就一定比网络模型A好吗?

肯定不是!!

要知道,我们构建网络模型的目的是实现输入到输出的预测,是要“泛化”到所有的输入数据都能有准确的输出,而不是仅仅关注在训练数据上。

回到上面的例子,如果再输入一个测试数据testset_x = [10, 20],真实的期望输出testset_y = [100, 400],而此时

网络模型A的输出为:A_test_output = [93.7, 412.8]

网络模型B的输出为:B_test_output = [-12.89, 1023,432]

此时的误差loss(A_test_output, testset_y)loss(B_test_output, testset_y)即为泛化误差。可见网络模型A的泛化误差明显小于网络模型B的泛化误差。

可以直观看出网络模型A是更好的模型,所以机器学习模型应该更加关注泛化误差的降低

实际应用中,类似网络模型B这种离谱输出的情况很有可能是因为过拟合导致的。

3. 通过PyTorch实例说明训练误差与泛化误差

以下是一个简单的PyTorch实例,使用线性回归模型来说明训练误差与泛化误差:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
 
# 生成一个简单的线性回归数据集
X, y = make_regression(n_samples=1000, n_features=1, noise=20, random_state=42)
 
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
# 将数据转换为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)
 
# 定义线性回归模型
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)
 
    def forward(self, x):
        return self.linear(x)
 
model = LinearRegression()
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    # 前向传播
    y_pred = model(X_train)
 
    # 计算训练误差(均方误差)
    train_loss = criterion(y_pred, y_train)
 
    # 反向传播和优化
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
 
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss.item():.4f}")
 
# 测试模型并计算泛化误差
with torch.no_grad():
    y_pred_test = model(X_test)
    test_loss = criterion(y_pred_test, y_test)
 
print(f"Test Loss (Generalization Error): {test_loss.item():.4f}")

在这个例子中:

  • 我们使用sklearn.datasets.make_regression生成了一个简单的线性回归数据集,并将其划分为训练集和测试集。
  • 定义了一个简单的线性回归模型,该模型只有一个输入特征和一个输出特征。
  • 使用均方误差(MSE)作为损失函数,并使用SGD优化器进行训练。
  • 在每个训练epoch结束后,我们计算了训练误差(即均方误差)。
  • 训练完成后,我们在测试集上评估模型性能,计算了测试误差(也称为泛化误差)。

在这个例子中,训练误差是在训练集上计算的模型预测值与真实值之间的差异,而泛化误差则是模型在未见过的测试集上的表现,反映了模型对新数据的预测能力。我们的目标是通过训练找到一个能够在训练集上表现良好并且在测试集上具有较低泛化误差的模型。


相关文章
|
机器学习/深度学习 数据采集 监控
构建高效机器学习模型的五大关键步骤
在数据科学领域,搭建一个高效的机器学习模型是实现数据驱动决策的核心。本文详细阐述了从数据预处理到模型评估五个关键步骤,旨在为读者提供一个清晰的建模流程。文中不仅介绍了各个步骤的理论依据,还结合了实用的技术细节,以期帮助读者在实际工作中构建出既健壮又精确的机器学习系统。
418 5
|
机器学习/深度学习 算法
大模型开发:什么是过拟合和欠拟合?你如何防止它们?
机器学习中,过拟合和欠拟合影响模型泛化能力。过拟合是模型对训练数据过度学习,测试集表现差,可通过正则化、降低模型复杂度或增加训练数据来缓解。欠拟合则是模型未能捕捉数据趋势,解决方案包括增加模型复杂度、添加特征或调整参数。平衡两者需通过实验、交叉验证和超参数调优。
1833 0
|
存储 机器学习/深度学习 前端开发
通义灵码的技术架构
通义灵码的技术架构
|
机器学习/深度学习 算法
机器学习算法之欠拟合和过拟合
机器学习算法之欠拟合和过拟合
|
消息中间件 安全 API
《阿里云产品四月刊》—Apache RocketMQ ACL 2.0 全新升级(1)
阿里云瑶池数据库云原生化和一体化产品能力升级,多款产品更新迭代
528 1
《阿里云产品四月刊》—Apache RocketMQ ACL 2.0 全新升级(1)
|
机器学习/深度学习 算法 PyTorch
SGD、Adam
【9月更文挑战第23天】
224 6
|
存储 安全 区块链
区块链与游戏:颠覆传统的数字娱乐新纪元
**区块链技术颠覆游戏行业,赋予玩家真实所有权,增强资产安全与经济系统创新。去中心化、不可篡改的特性确保公平性,智能合约驱动新盈利模式。虽有技术复杂性与扩展性挑战,但未来区块链游戏有望带来更丰富、安全、公平的体验,推动行业持续革新。**
区块链与游戏:颠覆传统的数字娱乐新纪元
|
存储 数据可视化 PyTorch
PyTorch中 Datasets & DataLoader 的介绍
PyTorch中 Datasets & DataLoader 的介绍
326 0
|
机器学习/深度学习 负载均衡 算法
训练Backbone你还用EMA?ViT训练的大杀器EWA升级来袭
训练Backbone你还用EMA?ViT训练的大杀器EWA升级来袭
410 1
|
机器学习/深度学习 数据可视化 PyTorch
Pytorch 最全入门介绍,Pytorch入门看这一篇就够了(二)
Pytorch 最全入门介绍,Pytorch入门看这一篇就够了
376 2