首頁人工智能常見問題正文

ResNet解決了什么問題?結(jié)構(gòu)有何特點(diǎn)?

更新時間:2023-07-21 來源:黑馬程序員 瀏覽量:

IT培訓(xùn)班

  ResNet(Residual Network)是由Kaiming He等人提出的深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),它在2015年的ImageNet圖像識別競賽中取得了非常顯著的成績,引起了廣泛的關(guān)注。ResNet的主要貢獻(xiàn)是解決了深度神經(jīng)網(wǎng)絡(luò)的梯度消失問題,使得可以訓(xùn)練更深的網(wǎng)絡(luò),從而獲得更好的性能。

  問題:在傳統(tǒng)的深度神經(jīng)網(wǎng)絡(luò)中,隨著網(wǎng)絡(luò)層數(shù)的增加,梯度在反向傳播過程中逐漸變小,導(dǎo)致淺層網(wǎng)絡(luò)的權(quán)重更新幾乎沒有效果,難以訓(xùn)練。這被稱為梯度消失問題。

  ResNet的解決方法:ResNet引入了“殘差塊”(residual block),每個殘差塊包含了一條“跳躍連接”(shortcut connection),它允許梯度能夠直接穿過塊,從而避免了梯度消失問題。因此,深度網(wǎng)絡(luò)可以通過恒等映射(identity mapping)來學(xué)習(xí)殘差,使得網(wǎng)絡(luò)在增加深度時反而變得更容易訓(xùn)練。

  ResNet結(jié)構(gòu)特點(diǎn):

  1.殘差塊:每個殘差塊由兩個或三個卷積層組成,它們的輸出通過跳躍連接與塊的輸入相加,形成殘差(residual)。

  2.跳躍連接:跳躍連接允許梯度直接流過塊,有助于避免梯度消失問題。

  3.批量歸一化:ResNet中廣泛使用批量歸一化層來加速訓(xùn)練并穩(wěn)定網(wǎng)絡(luò)。

  4.殘差塊堆疊:ResNet通過堆疊多個殘差塊來構(gòu)建深層網(wǎng)絡(luò)。深度可以根據(jù)任務(wù)的復(fù)雜性而自由選擇。

  接下來我們看一個簡化的ResNet代碼演示(使用TensorFlow):

import tensorflow as tf
from tensorflow.keras import layers, models

# 定義一個基本的殘差塊
def residual_block(x, filters, downsample=False):
    # 如果downsample為True,使用步長為2的卷積層實(shí)現(xiàn)降采樣
    stride = 2 if downsample else 1
    
    # 記錄輸入,以便在跳躍連接時使用
    identity = x
    
    # 第一個卷積層
    x = layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    # 第二個卷積層
    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    # 如果進(jìn)行了降采樣,需要對identity進(jìn)行相應(yīng)處理,保證維度一致
    if downsample:
        identity = layers.Conv2D(filters, kernel_size=1, strides=stride, padding='same')(identity)
        identity = layers.BatchNormalization()(identity)
    
    # 跳躍連接:將卷積層的輸出與輸入相加
    x = layers.add([x, identity])
    x = layers.Activation('relu')(x)
    
    return x

# 構(gòu)建ResNet網(wǎng)絡(luò)
def ResNet(input_shape, num_classes):
    input_img = layers.Input(shape=input_shape)
    
    # 第一個卷積層
    x = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')(input_img)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
    
    # 堆疊殘差塊組成網(wǎng)絡(luò)
    x = residual_block(x, filters=64)
    x = residual_block(x, filters=64)
    x = residual_block(x, filters=64)
    
    x = residual_block(x, filters=128, downsample=True)
    x = residual_block(x, filters=128)
    x = residual_block(x, filters=128)
    
    x = residual_block(x, filters=256, downsample=True)
    x = residual_block(x, filters=256)
    x = residual_block(x, filters=256)
    
    x = residual_block(x, filters=512, downsample=True)
    x = residual_block(x, filters=512)
    x = residual_block(x, filters=512)
    
    # 全局平均池化
    x = layers.GlobalAveragePooling2D()(x)
    # 全連接層輸出
    x = layers.Dense(num_classes, activation='softmax')(x)
    
    # 創(chuàng)建模型
    model = models.Model(inputs=input_img, outputs=x)
    return model

# 在這里定義輸入圖像的形狀和類別數(shù)
input_shape = (224, 224, 3)
num_classes = 1000

# 構(gòu)建ResNet模型
model = ResNet(input_shape, num_classes)
model.summary()

  請注意,上述代碼是一個簡化版本的ResNet網(wǎng)絡(luò),實(shí)際上,ResNet有不同的變體,可以根據(jù)任務(wù)的復(fù)雜性和資源的可用性選擇適合的ResNet結(jié)構(gòu)。

分享到:
在線咨詢 我要報(bào)名
和我們在線交談!