模型提取攻击摸索

本文记录了我针对“模型提取攻击(Model Extraction Attacks, MEAs)”进行的摸索。模型窃取本质上是利用黑盒接口的查询-响应机制,通过训练影子模型(Substitute Model)来复制靶机模型(Victim Model)的决策逻辑

   整个研究从最初的简单的基准模型的训练开始,到软标签(Soft-label)窃取,再到硬标签(Hard-label)环境下的决策边界探测以及有防御策略下的窃取.相关的攻击和防御策略如下

攻击

攻击方法

  • 替代模型训练 (Substitute Model Training): 利用查询反馈训练功能等价的影子模型。
  • 方程求解与参数还原: 针对简单模型或特定激活函数(如 ReLU)进行权重窃取。
  • 解释引导与梯度估计: 利用热力图或样本梯度加速窃取(如 SPSG 策略)。

攻击数据选择

  • 问题域攻击 (Problem Domain): 使用与靶机同分布的数据(如本项目对 CIFAR-10 的划分)。
  • 非问题域攻击 (Non-problem Domain): 利用无关数据通过领域自适应完成提取(如 Marich 策略)。
  • 无数据攻击 (Data-free): 完全依赖 GAN 生成查询样本。

防御

  • 攻击预防 (Prevention): 包含 对抗训练(增强鲁棒性)、输出/数据扰动(如加入 $\epsilon$-DP 噪声)、访问控制(基于速率限制或 PoW)。
  • 监测与验证 (Detection & Verification):
    • 查询模式监控:PRADA 监测查询分布的统计偏差 。
    • 所有权水印 (Watermarking): 在模型中嵌入特定 trigger 以便后事追溯所有权。
    • 指纹识别 (Fingerprinting): 利用边界特征生成唯一模型标识。

环境配置

Python

torch==2.5.1
torchvision==0.20.1
numpy==2.2.6
scipy==1.15.3
scikit-learn==1.7.2
pillow==12.1.0
scikit-image==0.25.2
tqdm==4.67.1
matplotlib==3.10.8

Stage 0:基准模型的训练

靶机采用 ResNet-18 架构,在 CIFAR-10 数据集上进行训练

选择一些常见的训练策略:比如通过TrivialAugmentWide进行数据增强;引入标签平滑;还有经典的针对cifar的7x7卷积层改为3x3卷积层.使用单周期学习策略加速训练和Nesterov加速梯度减少震荡

50epoch后的基准模型的准确率为94% 因为我采用的是问题域攻击,所以我采取了cifai-10的前50%训练集进行训练,后50%训练集进行攻击

Stage 1: 软标签窃取

软标签(Soft Label)指的是模型输出的概率分布(比如:“90% 概率是猫,8% 概率是狗,2% 概率是老虎”,此时不仅告诉你了label,同时告诉你了KL散度

使用vgg-11作为窃取模型来对基准模型进行攻击

软标签窃取使用了四种不同的策略,分别是:

1.基准软标签提取(基准策略):使用KL 散度作为损失函数,让替代模型的输出尽可能接近目标模型的输出分布。

  # 核心逻辑:利用 KL 散度让替代模型学习目标模型的输出分布
T = 3.0  # 温度系数,用于平滑概率分布

with torch.no_grad():
    # 1. 获取受害者的软标签(包含类间关系信息)
    target_logits = target_model(inputs)
    target_soft = F.softmax(target_logits / T, dim=1)

# 2. 替代模型进行预测并计算损失
sub_logits = substitute_model(inputs)
sub_log_soft = F.log_softmax(sub_logits / T, dim=1)

# 3. KLDivLoss 核心:最小化两个分布之间的差异
loss = criterion_kl(sub_log_soft, target_soft) * (T * T)

2.基于随机擦除的提取:增加了训练中对原图随机擦掉一块区域方法进行的窃取

# 核心逻辑:定义擦除变换
erasure_transform = transforms.RandomErasing(p=0.5, scale=(0.02, 0.33))

# 攻击循环:
with torch.no_grad():
    target_soft = F.softmax(target_model(inputs) / T, dim=1)

# 将图片擦除部分后再喂给替代模型进行训练
erased_inputs = torch.stack([erasure_transform(img) for img in inputs])
sub_logits = substitute_model(erased_inputs) 

3.MARICH (Multi-stage Active Model Extraction):通过主动学习(熵采样;K-Means)筛选出最“值得”查询的样本,提高查询效率。

# 核心逻辑:基于熵的样本筛选
def marich_sampling(substitute_model, pool_loader, budget):
    # 计算替代模型对样本的困惑度(熵)
    logits = substitute_model(imgs)
    probs = F.softmax(logits, dim=1)
    # 熵公式:H(p) = -Σ p(x) log p(x)
    entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
    
    # 挑选熵最高(即最不确定)的前 B 个点进行查询
    top_b_indices = np.argsort(all_entropies)[-budget:]
    return top_b_indices

4.SPSG (SuperPixel Soft-label Gradient) 梯度纯化提取:通过超像素分割 (SLIC),通过梯度提取和梯度纯化进行窃取

# 核心逻辑:超像素梯度提取
from skimage.segmentation import slic

# 1. 将图像切分为超像素块(类似拼图)
segments = slic(img_np, n_segments=16, compactness=10)

# 2. 超像素层级的前向差分:探测目标模型对特定区域的反应
for sp_idx in range(num_sp):
    mask = torch.from_numpy(segments == sp_idx).float()
    perturbed_img = image + epsilon * mask # 微调特定块
    perturbed_output = target_model(perturbed_img.unsqueeze(0)).softmax(dim=1)
    # 计算该块对预测结果的影响力(梯度模拟)
    sp_grads = (perturbed_output - base_output).abs().mean() / epsilon

实验结果: cnn软标签基准方法在40轮次下达到了83.74%的准确率以及84.85%的一致率 基于擦除的软标签在40轮次下达到了84.16%的准确率以及85.29%的一致率 MARICH方法在进行4785次查询,即0.47轮次时达到了71.15%的准确率,和72.24%的一致率.远高于基准模型的01轮次的50%,大幅降低了攻击的开销 SPSG方法在进行3203次查询,即0.32轮次时达到了68.93的准确率,和69.78%的一致率.

Stage 2: 硬标签窃取

硬标签指基准模型只输出label

使用vgg-11为窃取模型来对基准模型进行攻击

采用了下述五种方法:

1.基础硬标签提取:朴素监督学习。直接将目标模型返回的分类索引作为标签,用交叉熵损失训练替代模型

# 核心逻辑:直接利用 Hard API 返回的索引进行监督学习
for inputs, _ in tqdm(query_loader):
    inputs = inputs.to(device)
    # 1. 询问受害者模型:这张图是什么?
    target_labels = api.query(inputs)
    
    # 2. 替代模型强制“背诵”答案
    optimizer.zero_grad()
    loss = nn.CrossEntropyLoss()(substitute_model(inputs), target_labels)
    loss.backward()
    optimizer.step()

2.数据增强硬标签提取:在查询前对图片进行旋转、裁剪等强增强处理,增加样本覆盖面

# 核心逻辑:利用强数据增强制造“样本多样性”
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomRotation(15), # 增加旋转,让替代模型学习旋转不变性
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# 每次迭代,api.query 处理的都是经过随机扰动的图像

3.边界探测提取:二分查找边界。找到两个预测类别不同的样本,通过二分查找合成一张处于“决策边界”上的新图片进行查询

# 核心逻辑:二分搜索决策边界
def binary_search_boundary(api, img1, img2, label1, steps=10):
    low, high = 0.0, 1.0
    for _ in range(steps):
        mid = (low + high) / 2
        # 线性插值生成中间图像
        mid_img = (1 - mid) * img1 + mid * img2
        if api.query(mid_img.unsqueeze(0)).item() == label1.item():
            low = mid  # 仍在 label1 区域
        else:
            high = mid # 已跨越边界
    return mid_img # 返回最接近边界的样本

4.主动学习硬标签提取:MARICH。利用信息熵和 K-Means 聚类,在海量数据池中只选最“令人困惑”且“不重复”的样本进行查询

# 核心逻辑:熵筛选 + 多样性聚类
# 1. 挑选替代模型最“困惑”的样本
p = torch.softmax(model(x), dim=1)
entropies = -torch.sum(p * torch.log(p + 1e-10), dim=1)

# 2. 使用 K-Means 对梯度进行聚类,确保选出的样本分布均匀
kmeans = KMeans(n_clusters=budget).fit(grads)
# 选取每个簇中心最近的样本,避免冗余查询

5.线性插值提取:将两张图片按比例混合,同时按比例混合它们的标签。

# 核心逻辑:将硬标签转化为“伪软标签”
def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0))
    # 图像混合
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

# 训练时,损失函数同时兼顾两个原始标签
loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)

实验结果:

基准硬标签提取:准确率69.39%,一致率70.27%,1,250,000次查询 数据增强提取:准确率77.04%,一致率78.23%,1,250,000查询 边界探测提取:准确率72.12%,一致率72.97%,1,250,000查询 主动学习提取:准确率47.50%,一致率47.62%,3,000查询 线性插值提取:准确率86.30%,一致率87.43%,1,250,000查询 表明线性插值查询优于其他方法,主动学习提取能够极大减少攻击开销,但是,主动学习提取硬标签真的很慢,3000次查询也耗费了3~4小时的时间,实在是无法长时间训练

Stage 3: 防御干扰下的窃取

防御措施采用Honey-labels:API 在返回预测结果前,会检查 Top-1 和 Top-2 类别之间的置信度差异,如果两者的差距小于预设阈值(margin_threshold=1.5),意味着模型处于“犹豫不决”的决策边界附近,此时 API 会故意返回**错误(Top-2)**的标签来误导攻击者。

# 核心逻辑:在决策边界附近返回错误标签
def query(self, x):
    with torch.no_grad():
        logits = self.model(x.to(self.device))
        # 获取前两名的置信度及其索引
        values, indices = torch.topk(logits, 2, dim=1)
        predicted = indices[:, 0].clone()
        
        # Honey-label 逻辑:如果第一名和第二名差距(Margin)太小,说明在边界附近
        # 此时 API 故意返回第二名(错误标签)来干扰攻击者
        mask = (values[:, 0] - values[:, 1]) < self.margin_threshold
        predicted[mask] = indices[mask, 1]
    return predicted

此外还添加了水印,它通过在图像右下角添加 $3 \times 3$ 的白色像素块来观察替代模型是否学会了某种“后门”行为。

# 核心逻辑:给图像添加特定像素块作为“后门”水印
def apply_watermark(x):
    x_wm = x.clone()
    # 在右下角添加 3x3 的白色像素块
    x_wm[:, 29:32, 29:32] = 1.0
    return x_wm

# 检测影子模型是否对带水印的图片给出了特定的(错误)响应
wm_hits = sub_model(test_img).max(1)[1].eq(0).sum().item()

采用下述四种方式进行攻击,实现逻辑不再赘述:

1.朴素硬标签攻击

2.线性插值增强

3.边界探测

4.传统数据增强

结果:

Model Strategy (Stage 3) | Test Acc | Fidelity | WSR

Basic (Honey-Def) | 10.00% | 10.18% | 0.00%

Augmented (Honey-Def) | 79.85% | 80.61% | 10.53%

Boundary (Honey-Def) | 71.79% | 72.61% | 13.54%

Mixup (Honey-Def) | 85.23% | 86.20% | 11.68%

Mixup 通过线性约束迫使模型学习全局连续性,去噪能力能力强,达到了较好的水平,而边界探测由于Honey-Def,即使能够捕捉几何特征,准确率仍然下滑.传统数据增强一张图会被多次变体后查询,虽然单次查询可能拿到错误标签,但多样的视角也提升了模型的泛化能力.

P.S.为什么水印这么不尽人意,因为查询集中不包含水印样本,难绷,攻击者没见过触发器,自然无法学习到这个后门逻辑.