Introduction

到目前为止,在本课程中,我们已经学习了神经网络如何解决回归问题。现在,我们将把神经网络应用于另一个常见的机器学习问题:分类。到目前为止,我们学到的大部分内容仍然适用。主要区别在于我们使用的损失函数以及我们希望最终层产生的输出类型。

Binary Classification

将数据分为两类是机器学习中常见的问题。你可能想预测客户是否有可能购买,信用卡交易是否属于欺诈行为,深空信号是否显示新行星存在的证据,或者医学检测是否显示疾病存在的证据。这些都是二分类问题。

在原始数据中,类别可能用字符串表示,例如“是”和“否”,或者“狗”和“猫”。在使用这些数据之前,我们需要分配一个类别标签:一个类别为“0”,另一个类别为“1”。分配数字标签可以将数据转换为神经网络可以使用的形式。

Accuracy and Cross-Entropy

准确率是衡量分类问题成功与否的众多指标之一。准确率是正确预测数与总预测数之比:准确率 = 正确数 / 总数。一个始终预测正确的模型的准确率得分为 1.0。在其他条件相同的情况下,当数据集中的类别出现频率大致相同时,准确率是一个合理的指标。

准确率(以及大多数其他分类指标)的问题在于它不能用作损失函数。随机梯度下降 (SGD) 需要一个平滑变化的损失函数,但准确率作为计数的比率,变化幅度很大。因此,我们必须选择一个替代函数来充当损失函数。这个替代函数就是交叉熵函数。

现在,回想一下,损失函数定义了网络在训练过程中的“目标”。对于回归,我们的目标是最小化预期结果与预测结果之间的距离。我们选择 MAE 来衡量这个距离。

对于分类问题,我们想要的是概率之间的距离,而交叉熵正是提供这种距离的。交叉熵是一种度量从一个概率分布到另一个概率分布的距离的方法。

Graphs of accuracy and cross-entropy.交叉熵会惩罚错误的概率预测。

我们的想法是,我们希望网络以概率“1.0”预测正确的类别。预测概率与“1.0”的差距越大,交叉熵损失就越大。

我们使用交叉熵的技术原因有些微妙,但本节的主要内容是:使用交叉熵作为分类损失;您可能关心的其他指标(例如准确率)也会随之提高。

Making Probabilities with the Sigmoid Function

交叉熵和准确度函数都需要概率作为输入,即从 0 到 1 的数字。为了将密集层产生的实值输出转换为概率,我们附加了一种新的激活函数,即Sigmoid 激活

The sigmoid graph is an 'S' shape with horizontal asymptotes at 0 to the left and 1 to the right.

Example - Binary Classification

现在就来试试吧!

电离层 数据集包含从聚焦于地球大气电离层的雷达信号获取的特征。其任务是确定信号显示的是某个物体的存在,还是仅仅是空气。

import pandas as pd
from IPython.display import display

ion = pd.read_csv('../input/dl-course-data/ion.csv', index_col=0)
display(ion.head())

df = ion.copy()
df['Class'] = df['Class'].map({'good': 0, 'bad': 1})

df_train = df.sample(frac=0.7, random_state=0)
df_valid = df.drop(df_train.index)

max_ = df_train.max(axis=0)
min_ = df_train.min(axis=0)

df_train = (df_train - min_) / (max_ - min_)
df_valid = (df_valid - min_) / (max_ - min_)
df_train.dropna(axis=1, inplace=True) # drop the empty feature in column 2
df_valid.dropna(axis=1, inplace=True)

X_train = df_train.drop('Class', axis=1)
X_valid = df_valid.drop('Class', axis=1)
y_train = df_train['Class']
y_valid = df_valid['Class']
V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V26 V27 V28 V29 V30 V31 V32 V33 V34 Class
1 1 0 0.99539 -0.05889 0.85243 0.02306 0.83398 -0.37708 1.00000 0.03760 -0.51171 0.41078 -0.46168 0.21266 -0.34090 0.42267 -0.54487 0.18641 -0.45300 good
2 1 0 1.00000 -0.18829 0.93035 -0.36156 -0.10868 -0.93597 1.00000 -0.04549 -0.26569 -0.20468 -0.18401 -0.19040 -0.11593 -0.16626 -0.06288 -0.13738 -0.02447 bad
3 1 0 1.00000 -0.03365 1.00000 0.00485 1.00000 -0.12062 0.88965 0.01198 -0.40220 0.58984 -0.22145 0.43100 -0.17365 0.60436 -0.24180 0.56045 -0.38238 good
4 1 0 1.00000 -0.45161 1.00000 1.00000 0.71216 -1.00000 0.00000 0.00000 0.90695 0.51613 1.00000 1.00000 -0.20099 0.25682 1.00000 -0.32382 1.00000 bad
5 1 0 1.00000 -0.02401 0.94140 0.06531 0.92106 -0.23255 0.77152 -0.16399 -0.65158 0.13290 -0.53206 0.02431 -0.62197 -0.05707 -0.59573 -0.04608 -0.65697 good

5 rows × 35 columns

我们将像定义回归任务一样定义我们的模型,但有一个例外。在最后一层包含一个“sigmoid”激活函数,以便模型能够计算出类别概率。

from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
layers.Dense(4, activation='relu', input_shape=[33]),
layers.Dense(4, activation='relu'),
layers.Dense(1, activation='sigmoid'),
])

使用“compile”方法将交叉熵损失和准确率指标添加到模型中。对于二分类问题,请务必使用“binary”版本。(如果问题包含更多类别,情况会略有不同。)Adam 优化器在分类问题中也表现出色,因此我们将坚持使用它。

model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['binary_accuracy'],
)

这个特定问题中的模型可能需要相当多的时期才能完成训练,因此为了方便起见,我们将包含一个早期停止回调。

early_stopping = keras.callbacks.EarlyStopping(
patience=10,
min_delta=0.001,
restore_best_weights=True,
)

history = model.fit(
X_train, y_train,
validation_data=(X_valid, y_valid),
batch_size=512,
epochs=1000,
callbacks=[early_stopping],
verbose=0, # hide the output because we have so many epochs
)

我们将像往常一样查看学习曲线,并检查在验证集上获得的损失和准确率的最佳值。(请记住,提前停止会将权重恢复到获得这些值的权重。)

history_df = pd.DataFrame(history.history)
# Start the plot at epoch 5
history_df.loc[5:, ['loss', 'val_loss']].plot()
history_df.loc[5:, ['binary_accuracy', 'val_binary_accuracy']].plot()

print(("Best Validation Loss: {:0.4f}" +\
"\nBest Validation Accuracy: {:0.4f}")\
.format(history_df['val_loss'].min(),
history_df['val_binary_accuracy'].max()))
Best Validation Loss: 0.3534
Best Validation Accuracy: 0.8857

img

img