// Blog Entry
模型提取攻击摸索
本文记录了我针对“模型提取攻击(Model Extraction Attacks, MEAs)”进行的摸索。模型窃取本质上是利用黑盒接口的查询-响应机制,通过训练影子模型(Substitute Model)来复制靶机模型(Victim Model)的决策逻辑
整个研究从最初的简单的基准模型的训练开始,到软标签(Soft-label)窃取,再到硬标签(Hard-label)环境下的决策边界探测以及有防御策略下的窃取.相关的攻击和防御策略如下
攻击
攻击方法
- 替代模型训练 (Substitute Model Training): 利用查询反馈训练功能等价的影子模型。
- 方程求解与参数还原: 针对简单模型或特定激活函数(如 ReLU)进行权重窃取。
- 解释引导与梯度估计: 利用热力图或样本梯度加速窃取(如 SPSG 策略)。
攻击数据选择
- 问题域攻击 (Problem Domain): 使用与靶机同分布的数据(如本项目对 CIFAR-10 的划分)。
- 非问题域攻击 (Non-problem Domain): 利用无关数据通过领域自适应完成提取(如 Marich 策略)。
- 无数据攻击 (Data-free): 完全依赖 GAN 生成查询样本。
防御
- 攻击预防 (Prevention): 包含 对抗训练(增强鲁棒性)、输出/数据扰动(如加入 $\epsilon$-DP 噪声)、访问控制(基于速率限制或 PoW)。
- 监测与验证 (Detection & Verification):
- 查询模式监控: 如 PRADA 监测查询分布的统计偏差 。
- 所有权水印 (Watermarking): 在模型中嵌入特定 trigger 以便后事追溯所有权。
- 指纹识别 (Fingerprinting): 利用边界特征生成唯一模型标识。
环境配置
Python
torch==2.5.1
torchvision==0.20.1
numpy==2.2.6
scipy==1.15.3
scikit-learn==1.7.2
pillow==12.1.0
scikit-image==0.25.2
tqdm==4.67.1
matplotlib==3.10.8
Stage 0:基准模型的训练
靶机采用 ResNet-18 架构,在 CIFAR-10 数据集上进行训练
选择一些常见的训练策略:比如通过TrivialAugmentWide进行数据增强;引入标签平滑;还有经典的针对cifar的7x7卷积层改为3x3卷积层.使用单周期学习策略加速训练和Nesterov加速梯度减少震荡
50epoch后的基准模型的准确率为94%
因为我采用的是问题域攻击,所以我采取了cifai-10的前50%训练集进行训练,后50%训练集进行攻击
Stage 1: 软标签窃取
软标签(Soft Label)指的是模型输出的概率分布(比如:“90% 概率是猫,8% 概率是狗,2% 概率是老虎”,此时不仅告诉你了label,同时告诉你了KL散度
使用vgg-11作为窃取模型来对基准模型进行攻击
软标签窃取使用了四种不同的策略,分别是:
1.基准软标签提取(基准策略):使用KL 散度作为损失函数,让替代模型的输出尽可能接近目标模型的输出分布。
# 核心逻辑:利用 KL 散度让替代模型学习目标模型的输出分布
T = 3.0 # 温度系数,用于平滑概率分布
with torch.no_grad():
# 1. 获取受害者的软标签(包含类间关系信息)
target_logits = target_model(inputs)
target_soft = F.softmax(target_logits / T, dim=1)
# 2. 替代模型进行预测并计算损失
sub_logits = substitute_model(inputs)
sub_log_soft = F.log_softmax(sub_logits / T, dim=1)
# 3. KLDivLoss 核心:最小化两个分布之间的差异
loss = criterion_kl(sub_log_soft, target_soft) * (T * T)
2.基于随机擦除的提取:增加了训练中对原图随机擦掉一块区域方法进行的窃取
# 核心逻辑:定义擦除变换
erasure_transform = transforms.RandomErasing(p=0.5, scale=(0.02, 0.33))
# 攻击循环:
with torch.no_grad():
target_soft = F.softmax(target_model(inputs) / T, dim=1)
# 将图片擦除部分后再喂给替代模型进行训练
erased_inputs = torch.stack([erasure_transform(img) for img in inputs])
sub_logits = substitute_model(erased_inputs)
3.MARICH (Multi-stage Active Model Extraction):通过主动学习(熵采样;K-Means)筛选出最“值得”查询的样本,提高查询效率。
# 核心逻辑:基于熵的样本筛选
def marich_sampling(substitute_model, pool_loader, budget):
# 计算替代模型对样本的困惑度(熵)
logits = substitute_model(imgs)
probs = F.softmax(logits, dim=1)
# 熵公式:H(p) = -Σ p(x) log p(x)
entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
# 挑选熵最高(即最不确定)的前 B 个点进行查询
top_b_indices = np.argsort(all_entropies)[-budget:]
return top_b_indices
4.SPSG (SuperPixel Soft-label Gradient) 梯度纯化提取:通过超像素分割 (SLIC),通过梯度提取和梯度纯化进行窃取
# 核心逻辑:超像素梯度提取
from skimage.segmentation import slic
# 1. 将图像切分为超像素块(类似拼图)
segments = slic(img_np, n_segments=16, compactness=10)
# 2. 超像素层级的前向差分:探测目标模型对特定区域的反应
for sp_idx in range(num_sp):
mask = torch.from_numpy(segments == sp_idx).float()
perturbed_img = image + epsilon * mask # 微调特定块
perturbed_output = target_model(perturbed_img.unsqueeze(0)).softmax(dim=1)
# 计算该块对预测结果的影响力(梯度模拟)
sp_grads = (perturbed_output - base_output).abs().mean() / epsilon
实验结果: cnn软标签基准方法在40轮次下达到了83.74%的准确率以及84.85%的一致率 基于擦除的软标签在40轮次下达到了84.16%的准确率以及85.29%的一致率 MARICH方法在进行4785次查询,即0.47轮次时达到了71.15%的准确率,和72.24%的一致率.远高于基准模型的01轮次的50%,大幅降低了攻击的开销 SPSG方法在进行3203次查询,即0.32轮次时达到了68.93的准确率,和69.78%的一致率.
Stage 2: 硬标签窃取
硬标签指基准模型只输出label
使用vgg-11为窃取模型来对基准模型进行攻击
采用了下述五种方法:
1.基础硬标签提取:朴素监督学习。直接将目标模型返回的分类索引作为标签,用交叉熵损失训练替代模型
# 核心逻辑:直接利用 Hard API 返回的索引进行监督学习
for inputs, _ in tqdm(query_loader):
inputs = inputs.to(device)
# 1. 询问受害者模型:这张图是什么?
target_labels = api.query(inputs)
# 2. 替代模型强制“背诵”答案
optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(substitute_model(inputs), target_labels)
loss.backward()
optimizer.step()
2.数据增强硬标签提取:在查询前对图片进行旋转、裁剪等强增强处理,增加样本覆盖面
# 核心逻辑:利用强数据增强制造“样本多样性”
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.RandomRotation(15), # 增加旋转,让替代模型学习旋转不变性
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# 每次迭代,api.query 处理的都是经过随机扰动的图像
3.边界探测提取:二分查找边界。找到两个预测类别不同的样本,通过二分查找合成一张处于“决策边界”上的新图片进行查询
# 核心逻辑:二分搜索决策边界
def binary_search_boundary(api, img1, img2, label1, steps=10):
low, high = 0.0, 1.0
for _ in range(steps):
mid = (low + high) / 2
# 线性插值生成中间图像
mid_img = (1 - mid) * img1 + mid * img2
if api.query(mid_img.unsqueeze(0)).item() == label1.item():
low = mid # 仍在 label1 区域
else:
high = mid # 已跨越边界
return mid_img # 返回最接近边界的样本
4.主动学习硬标签提取:MARICH。利用信息熵和 K-Means 聚类,在海量数据池中只选最“令人困惑”且“不重复”的样本进行查询
# 核心逻辑:熵筛选 + 多样性聚类
# 1. 挑选替代模型最“困惑”的样本
p = torch.softmax(model(x), dim=1)
entropies = -torch.sum(p * torch.log(p + 1e-10), dim=1)
# 2. 使用 K-Means 对梯度进行聚类,确保选出的样本分布均匀
kmeans = KMeans(n_clusters=budget).fit(grads)
# 选取每个簇中心最近的样本,避免冗余查询
5.线性插值提取:将两张图片按比例混合,同时按比例混合它们的标签。
# 核心逻辑:将硬标签转化为“伪软标签”
def mixup_data(x, y, alpha=1.0):
lam = np.random.beta(alpha, alpha)
index = torch.randperm(x.size(0))
# 图像混合
mixed_x = lam * x + (1 - lam) * x[index]
return mixed_x, y, y[index], lam
# 训练时,损失函数同时兼顾两个原始标签
loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)
实验结果:
基准硬标签提取:准确率69.39%,一致率70.27%,1,250,000次查询 数据增强提取:准确率77.04%,一致率78.23%,1,250,000查询 边界探测提取:准确率72.12%,一致率72.97%,1,250,000查询 主动学习提取:准确率47.50%,一致率47.62%,3,000查询 线性插值提取:准确率86.30%,一致率87.43%,1,250,000查询 表明线性插值查询优于其他方法,主动学习提取能够极大减少攻击开销,但是,主动学习提取硬标签真的很慢,3000次查询也耗费了3~4小时的时间,实在是无法长时间训练
Stage 3: 防御干扰下的窃取
防御措施采用Honey-labels:API 在返回预测结果前,会检查 Top-1 和 Top-2 类别之间的置信度差异,如果两者的差距小于预设阈值(margin_threshold=1.5),意味着模型处于“犹豫不决”的决策边界附近,此时 API 会故意返回**错误(Top-2)**的标签来误导攻击者。
# 核心逻辑:在决策边界附近返回错误标签
def query(self, x):
with torch.no_grad():
logits = self.model(x.to(self.device))
# 获取前两名的置信度及其索引
values, indices = torch.topk(logits, 2, dim=1)
predicted = indices[:, 0].clone()
# Honey-label 逻辑:如果第一名和第二名差距(Margin)太小,说明在边界附近
# 此时 API 故意返回第二名(错误标签)来干扰攻击者
mask = (values[:, 0] - values[:, 1]) < self.margin_threshold
predicted[mask] = indices[mask, 1]
return predicted
此外还添加了水印,它通过在图像右下角添加 $3 \times 3$ 的白色像素块来观察替代模型是否学会了某种“后门”行为。
# 核心逻辑:给图像添加特定像素块作为“后门”水印
def apply_watermark(x):
x_wm = x.clone()
# 在右下角添加 3x3 的白色像素块
x_wm[:, 29:32, 29:32] = 1.0
return x_wm
# 检测影子模型是否对带水印的图片给出了特定的(错误)响应
wm_hits = sub_model(test_img).max(1)[1].eq(0).sum().item()
采用下述四种方式进行攻击,实现逻辑不再赘述:
1.朴素硬标签攻击
2.线性插值增强
3.边界探测
4.传统数据增强
结果:
Model Strategy (Stage 3) | Test Acc | Fidelity | WSR
Basic (Honey-Def) | 10.00% | 10.18% | 0.00%
Augmented (Honey-Def) | 79.85% | 80.61% | 10.53%
Boundary (Honey-Def) | 71.79% | 72.61% | 13.54%
Mixup (Honey-Def) | 85.23% | 86.20% | 11.68%
Mixup 通过线性约束迫使模型学习全局连续性,去噪能力能力强,达到了较好的水平,而边界探测由于Honey-Def,即使能够捕捉几何特征,准确率仍然下滑.传统数据增强一张图会被多次变体后查询,虽然单次查询可能拿到错误标签,但多样的视角也提升了模型的泛化能力.
P.S.为什么水印这么不尽人意,因为查询集中不包含水印样本,难绷,攻击者没见过触发器,自然无法学习到这个后门逻辑.