最大间隔分类与核技巧

前言

支持向量机(Support Vector Machine, SVM)是一种强大的分类算法,核心思想是找到最大间隔的决策超平面。通过核技巧,SVM可以处理非线性问题。


线性可分SVM

最大间隔

给定线性可分数据,存在无穷多个分割超平面。SVM选择间隔最大的那个。

超平面方程:$\mathbf{w}^T\mathbf{x} + b = 0$

点到超平面的距离:$\frac{ \mathbf{w}^T\mathbf{x} + b }{|\mathbf{w}|}$
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification, make_circles, make_moons
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC, LinearSVC
from sklearn.preprocessing import StandardScaler

np.random.seed(42)

# 生成线性可分数据
X_linear, y_linear = make_classification(n_samples=100, n_features=2, n_informative=2,
                                          n_redundant=0, n_clusters_per_class=1,
                                          class_sep=2, random_state=42)

# 可视化
plt.figure(figsize=(10, 8))
plt.scatter(X_linear[y_linear==0, 0], X_linear[y_linear==0, 1], c='blue', label='Class 0')
plt.scatter(X_linear[y_linear==1, 0], X_linear[y_linear==1, 1], c='red', label='Class 1')

# 绘制几条可能的分割线
for i, (w, b, color) in enumerate([
    ([1, 0.5], -0.5, 'gray'),
    ([0.5, 1], 0, 'gray'),
    ([0.8, 0.6], -0.3, 'gray')
]):
    w = np.array(w)
    x_line = np.linspace(-3, 3, 100)
    y_line = -(w[0] * x_line + b) / w[1]
    plt.plot(x_line, y_line, color=color, linestyle='--', alpha=0.5)

plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('线性可分数据:哪条分割线最好?')
plt.legend()
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.show()

优化目标

\[\max_{\mathbf{w}, b} \frac{2}{\|\mathbf{w}\|}\]

等价于:

\[\min_{\mathbf{w}, b} \frac{1}{2}\|\mathbf{w}\|^2\]

约束条件:$y_i(\mathbf{w}^T\mathbf{x}_i + b) \geq 1$

# 训练SVM
svm = SVC(kernel='linear', C=1000)  # C大表示硬间隔
svm.fit(X_linear, y_linear)

# 可视化决策边界和间隔
def plot_svm_decision_boundary(clf, X, y, ax, title):
    xx, yy = np.meshgrid(np.linspace(X[:, 0].min()-1, X[:, 0].max()+1, 200),
                         np.linspace(X[:, 1].min()-1, X[:, 1].max()+1, 200))
    
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    ax.contourf(xx, yy, Z, levels=np.linspace(Z.min(), Z.max(), 50), cmap='RdYlBu', alpha=0.5)
    ax.contour(xx, yy, Z, levels=[-1, 0, 1], colors=['blue', 'black', 'red'], 
               linestyles=['--', '-', '--'], linewidths=[1, 2, 1])
    
    ax.scatter(X[y==0, 0], X[y==0, 1], c='blue', edgecolors='k', label='Class 0')
    ax.scatter(X[y==1, 0], X[y==1, 1], c='red', edgecolors='k', label='Class 1')
    
    # 标记支持向量
    ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],
               s=100, facecolors='none', edgecolors='black', linewidths=2,
               label='支持向量')
    
    ax.set_title(title)
    ax.legend()

fig, ax = plt.subplots(1, 1, figsize=(10, 8))
plot_svm_decision_boundary(svm, X_linear, y_linear, ax, 'SVM决策边界与间隔')
plt.show()

print(f"支持向量数量: {len(svm.support_vectors_)}")
print(f"权重向量: {svm.coef_}")
print(f"偏置: {svm.intercept_}")

软间隔SVM

松弛变量

现实数据往往线性不可分,引入松弛变量 $\xi_i$ 允许一些点违反约束。

\[\min_{\mathbf{w}, b, \xi} \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{i=1}^{N}\xi_i\]

约束:$y_i(\mathbf{w}^T\mathbf{x}_i + b) \geq 1 - \xi_i$,$\xi_i \geq 0$

C参数的影响

# 带噪声的数据
X_noisy, y_noisy = make_classification(n_samples=100, n_features=2, n_informative=2,
                                        n_redundant=0, n_clusters_per_class=1,
                                        class_sep=1, flip_y=0.1, random_state=42)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

C_values = [0.1, 1, 100]

for ax, C in zip(axes, C_values):
    svm_temp = SVC(kernel='linear', C=C)
    svm_temp.fit(X_noisy, y_noisy)
    plot_svm_decision_boundary(svm_temp, X_noisy, y_noisy, ax, f'C={C}\n支持向量数={len(svm_temp.support_vectors_)}')

plt.tight_layout()
plt.show()
C值 特点
宽间隔,允许更多违反
窄间隔,更严格分类

核技巧

非线性问题

# 生成非线性可分数据
X_circles, y_circles = make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=42)

# 线性SVM
svm_linear = SVC(kernel='linear')
svm_linear.fit(X_circles, y_circles)

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
xx, yy = np.meshgrid(np.linspace(-1.5, 1.5, 100), np.linspace(-1.5, 1.5, 100))
Z = svm_linear.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.3, cmap='RdYlBu')
plt.scatter(X_circles[y_circles==0, 0], X_circles[y_circles==0, 1], c='blue')
plt.scatter(X_circles[y_circles==1, 0], X_circles[y_circles==1, 1], c='red')
plt.title(f'线性SVM准确率: {svm_linear.score(X_circles, y_circles):.2f}')

# RBF核SVM
svm_rbf = SVC(kernel='rbf', gamma=1)
svm_rbf.fit(X_circles, y_circles)

plt.subplot(1, 2, 2)
Z = svm_rbf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.3, cmap='RdYlBu')
plt.scatter(X_circles[y_circles==0, 0], X_circles[y_circles==0, 1], c='blue')
plt.scatter(X_circles[y_circles==1, 0], X_circles[y_circles==1, 1], c='red')
plt.title(f'RBF核SVM准确率: {svm_rbf.score(X_circles, y_circles):.2f}')

plt.tight_layout()
plt.show()

核函数

核函数 $K(\mathbf{x}_i, \mathbf{x}_j) = \phi(\mathbf{x}_i)^T\phi(\mathbf{x}_j)$

隐式地在高维空间计算内积,无需显式映射。

核函数 公式 适用场景
线性核 $\mathbf{x}_i^T\mathbf{x}_j$ 线性可分
多项式核 $(γ\mathbf{x}_i^T\mathbf{x}_j + r)^d$ 多项式关系
RBF/高斯核 $\exp(-γ|\mathbf{x}_i-\mathbf{x}_j|^2)$ 通用
Sigmoid核 $\tanh(γ\mathbf{x}_i^T\mathbf{x}_j + r)$ 类神经网络
# 不同核函数对比
X_moons, y_moons = make_moons(n_samples=200, noise=0.15, random_state=42)

fig, axes = plt.subplots(2, 2, figsize=(12, 12))

kernels = ['linear', 'poly', 'rbf', 'sigmoid']
params = [
    {},
    {'degree': 3, 'gamma': 'auto'},
    {'gamma': 'auto'},
    {'gamma': 'auto'}
]

xx, yy = np.meshgrid(np.linspace(-1.5, 2.5, 100), np.linspace(-1, 1.5, 100))

for ax, kernel, param in zip(axes.ravel(), kernels, params):
    svm_temp = SVC(kernel=kernel, **param)
    svm_temp.fit(X_moons, y_moons)
    
    Z = svm_temp.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    
    ax.contourf(xx, yy, Z, alpha=0.3, cmap='RdYlBu')
    ax.scatter(X_moons[y_moons==0, 0], X_moons[y_moons==0, 1], c='blue')
    ax.scatter(X_moons[y_moons==1, 0], X_moons[y_moons==1, 1], c='red')
    ax.scatter(svm_temp.support_vectors_[:, 0], svm_temp.support_vectors_[:, 1],
               s=50, facecolors='none', edgecolors='black')
    ax.set_title(f'{kernel}核, 准确率={svm_temp.score(X_moons, y_moons):.2f}')

plt.tight_layout()
plt.show()

RBF核参数

gamma参数

gamma控制单个样本的影响范围:

  • gamma大:影响范围小,决策边界复杂
  • gamma小:影响范围大,决策边界平滑
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

gammas = [0.01, 0.1, 1, 10, 100, 1000]

for ax, gamma in zip(axes.ravel(), gammas):
    svm_temp = SVC(kernel='rbf', gamma=gamma, C=1)
    svm_temp.fit(X_moons, y_moons)
    
    Z = svm_temp.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    
    ax.contourf(xx, yy, Z, alpha=0.3, cmap='RdYlBu')
    ax.scatter(X_moons[y_moons==0, 0], X_moons[y_moons==0, 1], c='blue')
    ax.scatter(X_moons[y_moons==1, 0], X_moons[y_moons==1, 1], c='red')
    ax.set_title(f'gamma={gamma}\n支持向量数={len(svm_temp.support_vectors_)}')

plt.suptitle('gamma参数的影响', fontsize=14)
plt.tight_layout()
plt.show()

C和gamma的联合影响

from sklearn.model_selection import cross_val_score

# 网格搜索
C_range = [0.1, 1, 10, 100]
gamma_range = [0.01, 0.1, 1, 10]

scores = np.zeros((len(C_range), len(gamma_range)))

for i, C in enumerate(C_range):
    for j, gamma in enumerate(gamma_range):
        svm_temp = SVC(kernel='rbf', C=C, gamma=gamma)
        cv_scores = cross_val_score(svm_temp, X_moons, y_moons, cv=5)
        scores[i, j] = cv_scores.mean()

plt.figure(figsize=(8, 6))
plt.imshow(scores, interpolation='nearest', cmap='viridis')
plt.colorbar(label='CV准确率')
plt.xticks(range(len(gamma_range)), gamma_range)
plt.yticks(range(len(C_range)), C_range)
plt.xlabel('gamma')
plt.ylabel('C')
plt.title('C和gamma的网格搜索')

# 标注最佳位置
best_idx = np.unravel_index(scores.argmax(), scores.shape)
plt.scatter(best_idx[1], best_idx[0], marker='*', s=300, c='red')

for i in range(len(C_range)):
    for j in range(len(gamma_range)):
        plt.text(j, i, f'{scores[i,j]:.2f}', ha='center', va='center', color='white')

plt.show()

SVM回归(SVR)

from sklearn.svm import SVR

# 生成回归数据
np.random.seed(42)
X_reg = np.sort(np.random.rand(100) * 10).reshape(-1, 1)
y_reg = np.sin(X_reg).ravel() + np.random.randn(100) * 0.2

X_plot = np.linspace(0, 10, 100).reshape(-1, 1)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

kernels = ['linear', 'poly', 'rbf']

for ax, kernel in zip(axes, kernels):
    svr = SVR(kernel=kernel, C=100, epsilon=0.1)
    svr.fit(X_reg, y_reg)
    
    y_pred = svr.predict(X_plot)
    
    ax.scatter(X_reg, y_reg, c='blue', alpha=0.5, label='数据')
    ax.plot(X_plot, y_pred, 'r-', linewidth=2, label='SVR预测')
    ax.plot(X_plot, np.sin(X_plot), 'g--', alpha=0.5, label='真实函数')
    ax.set_title(f'SVR ({kernel}核)\nR²={svr.score(X_reg, y_reg):.3f}')
    ax.legend()

plt.tight_layout()
plt.show()

多分类SVM

SVM本身是二分类器,多分类策略:

  • One-vs-Rest (OvR):K个分类器
  • One-vs-One (OvO):K(K-1)/2个分类器
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report

# 多分类示例
iris = load_iris()
X_iris, y_iris = iris.data[:, [2, 3]], iris.target

X_train_i, X_test_i, y_train_i, y_test_i = train_test_split(
    X_iris, y_iris, test_size=0.2, random_state=42)

# 训练
svm_multi = SVC(kernel='rbf', gamma='auto', decision_function_shape='ovr')
svm_multi.fit(X_train_i, y_train_i)

print(f"测试准确率: {svm_multi.score(X_test_i, y_test_i):.4f}")

# 可视化
xx, yy = np.meshgrid(np.linspace(X_iris[:, 0].min()-0.5, X_iris[:, 0].max()+0.5, 100),
                     np.linspace(X_iris[:, 1].min()-0.5, X_iris[:, 1].max()+0.5, 100))

Z = svm_multi.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
for i, color in enumerate(['blue', 'orange', 'green']):
    plt.scatter(X_iris[y_iris==i, 0], X_iris[y_iris==i, 1], 
                c=color, label=iris.target_names[i])
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.title('SVM多分类')
plt.legend()
plt.show()

实战:手写数字识别

from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix
import seaborn as sns

# 加载数据
digits = load_digits()
X_digits, y_digits = digits.data, digits.target

X_train_d, X_test_d, y_train_d, y_test_d = train_test_split(
    X_digits, y_digits, test_size=0.2, random_state=42)

# 标准化
scaler = StandardScaler()
X_train_d = scaler.fit_transform(X_train_d)
X_test_d = scaler.transform(X_test_d)

# 训练SVM
svm_digits = SVC(kernel='rbf', gamma='scale', C=10)
svm_digits.fit(X_train_d, y_train_d)

y_pred_d = svm_digits.predict(X_test_d)

print(f"测试准确率: {svm_digits.score(X_test_d, y_test_d):.4f}")

# 混淆矩阵
plt.figure(figsize=(10, 8))
cm = confusion_matrix(y_test_d, y_pred_d)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('预测')
plt.ylabel('实际')
plt.title('SVM手写数字识别混淆矩阵')
plt.show()

SVM vs 其他算法

from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
import time

classifiers = {
    'SVM (linear)': SVC(kernel='linear'),
    'SVM (rbf)': SVC(kernel='rbf'),
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'KNN': KNeighborsClassifier(),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42)
}

results = []
for name, clf in classifiers.items():
    start = time.time()
    clf.fit(X_train_d, y_train_d)
    train_time = time.time() - start
    
    accuracy = clf.score(X_test_d, y_test_d)
    
    results.append({
        'Model': name,
        'Accuracy': accuracy,
        'Train Time': train_time
    })
    print(f"{name}: Acc={accuracy:.4f}, Time={train_time:.3f}s")

import pandas as pd
pd.DataFrame(results)

常见问题

Q1: SVM的优缺点?

优点 缺点
高维数据表现好 大规模数据训练慢
核技巧处理非线性 对参数敏感
泛化能力强 难以解释
内存效率高(只存支持向量) 多分类需要扩展

Q2: 如何选择核函数?

  1. 先尝试线性核(快速基准)
  2. 数据量小、特征多:线性核
  3. 通用情况:RBF核
  4. 可用交叉验证选择

Q3: SVM需要标准化吗?

必须标准化! SVM基于距离,不同尺度的特征会影响结果。

Q4: 如何处理大规模数据?

  • 使用LinearSVC(基于liblinear,更快)
  • 使用SGDClassifier(随机梯度下降)
  • 采样子集训练

总结

概念 说明
核心思想 最大化分类间隔
硬间隔 线性可分,严格分类
软间隔 允许违反,参数C控制
核技巧 隐式映射到高维空间
关键参数 C(惩罚)、gamma(RBF影响范围)

参考资料

  • Cortes, C., & Vapnik, V. (1995). “Support-vector networks”
  • 《统计学习方法》李航 第7章
  • scikit-learn 文档:SVM

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:《 机器学习基础系列——支持向量机 》

本文链接:http://localhost:3015/ai/%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA.html

本文最后一次更新为 天前,文章中的某些内容可能已过时!