—— 25.2.23
ReLU广泛应用于卷积神经网络(CNN)和全连接网络,尤其在图像分类(如ImageNet)、语音识别等领域表现优异。其高效性和非线性特性使其成为深度学习默认激活函数的首选
一、定义与数学表达式
ReLU(Rectified Linear Unit,修正线性单元)是一种分段线性激活函数,
其数学表达式为:ReLU(x)=max(0,x)
即当输入 x 大于 0 时,输出为 x;当 x≤0 时,输出为 0。
二、核心特点
非线性特性:通过引入分段线性特性,ReLU为神经网络引入非线性,使其能拟合复杂函数。
计算高效:仅通过阈值判断(x>0)即可完成计算,避免了指数运算(如Sigmoid、Tanh),显著提升速度。
缓解梯度消失:在 x>0 时梯度恒为 1,反向传播时梯度不会饱和,加速收敛。
稀疏激活性:负输入时输出为 0,导致部分神经元“休眠”,减少参数依赖和过拟合风险。
三、优点
简单高效:实现和计算成本低,适合深度网络。
收敛速度快:相比Sigmoid/Tanh,ReLU在训练中梯度更稳定,收敛更快。
非零中心性:输出范围为 [0,+∞),虽非严格零中心,但简化了优化过程
四、局限性
Dead ReLU问题:若神经元输入长期为负,梯度恒为 0,导致权重无法更新,神经元“死亡”。
非零中心性:输出偏向非负值,可能影响梯度下降效率。
对初始化敏感:若学习率过高,负输入区域可能使神经元永久失效。
五、变体
Leaky ReLU:允许负输入时输出 αx(α为小常数,如0.01)。
PReLU(Parametric ReLU):将 α 设为可学习参数,动态调整负区斜率。
ELU(Exponential Linear Unit):负输入时输出 α(ex−1),使输出均值接近零。
Swish:自门控激活函数,结合ReLU和Sigmoid特性,平滑且无上界。
六、代码示例
1.通过 nn.ReLU()
作为网络层
nn.ReLU() :PyTorch 中的修正线性单元(ReLU)激活函数模块,用于神经网络中引入非线性。其功能是将输入张量中所有负值置为 0,保留正值不变
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
inplace | bool | 否 | 是否原地操作(直接修改输入张量)。 |
默认值为 False ,此时会返回新张量。若设为 True ,则直接在原张量上操作。 |
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的网络,包含两个线性层和 ReLU 激活
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784, 256) # 输入层:784 → 256
self.relu = nn.ReLU() # ReLU 激活层
self.fc2 = nn.Linear(256, 10) # 输出层:256 → 10(如分类任务)
def forward(self, x):
x = self.relu(self.fc1(x)) # 在第一层后应用 ReLU
x = self.fc2(x)
return x
# 初始化网络、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 输入数据示例(如 MNIST 图像,形状为 [batch_size, 784])
input_data = torch.randn(32, 784)
# 前向传播
output = model(input_data)
print(output.shape) # 输出形状: (32, 10)
2. 直接使用 torch.relu()
函数
torch.relu(): PyTorch 中实现修正线性单元(ReLU)激活函数的函数
其数学表达式为:ReLU(x)=max(0,x)
参数名称 | 类型 | 是否必填 | 说明 |
---|---|---|---|
inplace | bool | 否 | 是否原地修改输入张量。若为 True ,则直接修改输入张量以节省内存;若为 False (默认),则返回新张量。 |
import torch
# 示例输入
x = torch.tensor([-1.0, 0.0, 1.0])
# 应用 ReLU 函数(非原地)
y = torch.relu(x)
print(y) # 输出: tensor([0., 0., 1.])
# 应用 ReLU 函数(原地)
torch.relu_(x)
print(x) # 输出: tensor([0., 0., 1.]),原始张量被修改[1,7](@ref)。