以cifar100作为闭集(closed-set)数据集,使用resnet18模型进行训练,然后在常见的开集(out-of-distribution)数据集上进行OOD检测。使用MSP(Maximum Softmax Probability)作为OOD检测的依据。
开集噪声数据集使用gaussian, rademacher, blob, svhn四种类型。其中gaussian、rademacher、blob是生成的随机噪声,svhn是额外引入的噪声数据集。
输出结果
Error Rate 46.3000
AUROC: 81.9790, AUPR: 85.7377, FPR95: 73.3909
ood type: gaussian
AUROC: 68.1596, AUPR: 92.9277, FPR95: 99.4000
ood type: rademacher
AUROC: 69.9099, AUPR: 93.1788, FPR95: 96.1500
ood type: blob
AUROC: 68.0615, AUPR: 92.7477, FPR95: 97.5500
ood type: svhn
AUROC: 66.9684, AUPR: 91.6508, FPR95: 89.0500
可以看到,仅使用简单的交叉熵损失训练、且不经过其他处理的resnet18,在开集检测上的表现并不好。
闭集数据集上训练一个resnet18
# train.py
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets.cifar import CIFAR100
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
def get_transform(train=True):
    """Build the torchvision preprocessing pipeline.

    Args:
        train: When true, a padded random 32x32 crop and a random horizontal
            flip are applied before tensor conversion; otherwise images are
            only converted to tensors and normalized.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` + ``Normalize``.
    """
    # NOTE(review): these per-channel statistics look like the commonly quoted
    # CIFAR-10 values, although the loader below reads CIFAR-100 -- confirm.
    normalize = transforms.Normalize(
        [0.4914, 0.4822, 0.4465],
        [0.2023, 0.1994, 0.2010],
    )
    steps = [transforms.ToTensor(), normalize]
    if train:
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        steps = augmentation + steps
    return transforms.Compose(steps)
def get_loader(train=True):
    """Return a DataLoader over the CIFAR-100 train or test split.

    Args:
        train: Selects the training split (shuffled, with the training
            transform) when true, the test split (unshuffled) otherwise.

    Returns:
        A ``DataLoader`` yielding batches of 128 samples read from ``~/data``.
    """
    split = CIFAR100(root='~/data', train=train, transform=get_transform(train))
    return DataLoader(
        split,
        batch_size=128,
        shuffle=train,
        num_workers=8,
        pin_memory=True,
    )
def train_model():
    """Train a ResNet-18 on CIFAR-100 and save its weights.

    Runs 100 epochs of SGD with a step-wise learning-rate decay at epochs 50
    and 75. After every epoch the model is scored on the test split and that
    epoch's top-1 accuracy is printed. The final ``state_dict`` is written to
    ``cifar100_resnet18.pth`` in the current working directory.

    Requires a CUDA device: the model and every batch are moved with
    ``.cuda()``.
    """
    loader = get_loader(train=True)
    test_loader = get_loader(train=False)
    model = resnet18(num_classes=100)
    model = model.cuda()
    epochs = 100
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)

    for epoch in range(epochs):
        model.train()
        print('Training')
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f'Epoch[{epoch}] Iter: {i}/{len(loader)} Loss: {loss.item()}')
        scheduler.step()

        print('Testing')
        # Switch to eval mode for the test pass so BatchNorm uses (and does
        # not update) its running statistics; model.train() at the top of the
        # loop switches back for the next epoch.
        model.eval()
        # Fresh lists each epoch so the reported accuracy covers this epoch
        # only rather than accumulating predictions from earlier epochs.
        all_preds = []
        all_labels = []
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for inputs, labels in test_loader:
                inputs = inputs.cuda()
                outputs = model(inputs)
                predicted = outputs.argmax(dim=1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())
        accuracy = accuracy_score(all_labels, all_preds)
        print(f'Epoch[{epoch}] acc@1 {accuracy:.4f}')

    torch.save(model.state_dict(), 'cifar100_resnet18.pth')
# Script entry point: run the full training loop when executed directly.
if __name__ == "__main__":
    train_model()
构建常用的开集数据集
# ood_data.py
import torch
import numpy as np
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.dataloader import DataLoader
from skimage.filters import gaussian
from torchvision.datasets import SVHN
from train import get_transform
def build_ood_loader(noise_type, ood_num_examples, batch_size, worker):
dummy_targets = torch.ones(ood_num_examples)
if noise_type in ['gaussian', 'rademacher', 'blob']:
if noise_type == 'gaussian':
ood_data = torch.from_numpy(np.float32(np.clip(
np.random.normal(size=(ood_num_examples, 3, 32, 32), scale=0.5), -1, 1)))
elif noise_type == 'rademacher':
ood_data = torch.from_numpy(np.random.binomial(
n=1, p=0.5, size=(ood_num_examples, 3, 32, 32)).astype(np.float32)) * 2 - 1
else:
ood_data = np.float32(np.random.binomial(n=1, p=0.7, size=(ood_num_examples, 32, 32, 3)))
for i in range(ood_num_examples):
ood_data[i] = gaussian(ood_data[i], sigma=1.5)
ood_data[i][ood_data[i]
使用常见的OOD检测评估指标
# ood_utils.py
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
@torch.no_grad()
def get_ood_scores(model, dataloader, closed_set=False):
    """Compute the maximum softmax probability (MSP) for every sample.

    The model is put in eval mode and each batch is moved to the device of
    the model's first parameter, so the function works for CPU and GPU models
    alike.

    Args:
        model: A ``torch.nn.Module`` returning class logits of shape
            ``(batch, num_classes)``.
        dataloader: Iterable of ``(data, targets)`` batches. ``targets`` is
            only read when ``closed_set`` is true.
        closed_set: When true, also split the scores by whether the argmax
            prediction matches the target.

    Returns:
        ``scores`` (1-D array of MSP values) when ``closed_set`` is false;
        otherwise the tuple ``(scores, right_scores, wrong_scores)`` where the
        latter two hold the scores of correctly and incorrectly classified
        samples. Empty arrays are returned for an empty dataloader.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA; a parameter-free
    # model leaves the input where it already is.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    scores = []
    right_scores = []
    wrong_scores = []
    for data, targets in dataloader:
        if device is not None:
            data = data.to(device)
        smax = F.softmax(model(data), dim=1).cpu().numpy()
        msp = smax.max(axis=1)  # computed once, reused for both splits below
        scores.append(msp)
        if closed_set:
            pred = smax.argmax(axis=1)
            # reshape(-1) keeps a 1-D array even for a single-sample batch,
            # where squeeze() would collapse it to 0-d.
            labels = targets.cpu().numpy().reshape(-1)
            right_indices = pred == labels
            right_scores.append(msp[right_indices])
            wrong_scores.append(msp[~right_indices])

    def _concat(chunks):
        # np.concatenate raises on an empty list; return an empty array instead.
        return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)

    if closed_set:
        return _concat(scores), _concat(right_scores), _concat(wrong_scores)
    return _concat(scores)
def get_performance(pos, neg):
    """Score how well the given values separate positives from negatives.

    Args:
        pos: Scores of the positive class (labelled 1).
        neg: Scores of the negative class (labelled 0).

    Returns:
        ``(auroc, aupr, fpr95)`` as fractions in ``[0, 1]``, where ``fpr95``
        is the false-positive rate at the first ROC point whose true-positive
        rate reaches 0.95.
    """
    positives = np.array(pos).reshape(-1)
    negatives = np.array(neg).reshape(-1)
    all_scores = np.concatenate([positives, negatives])
    truth = [1] * len(positives) + [0] * len(negatives)

    fpr, tpr, _ = roc_curve(truth, all_scores)
    first_at_95 = np.argmax(tpr >= 0.95)
    return (
        roc_auc_score(truth, all_scores),
        average_precision_score(truth, all_scores),
        fpr[first_at_95],
    )
def show_performance(pos, neg):
    """Print AUROC, AUPR and FPR95 (as percentages) for the given scores.

    Args:
        pos: Scores of the positive class.
        neg: Scores of the negative class.
    """
    metrics = get_performance(pos, neg)
    auroc, aupr, fpr95 = metrics
    print(f"AUROC: {auroc * 100:.4f}, AUPR: {aupr * 100:.4f}, FPR95: {fpr95 * 100:.4f}")
测试模型的OOD检测性能
# test.py
import torch
from torchvision.models import resnet18
from train import get_loader
from ood_utils import get_ood_scores, show_performance
from ood_data import build_ood_loader
def evaluate():
    """Report closed-set and open-set MSP detection metrics for the trained model.

    Loads the ResNet-18 weights from ``cifar100_resnet18.pth``, prints the
    test error rate and how well MSP separates correctly from incorrectly
    classified test samples, then prints the in-distribution vs.
    out-of-distribution separation for each OOD type.

    Requires a CUDA device: the model is moved with ``.cuda()``.
    """
    model = resnet18(num_classes=100)
    # weights_only=True restricts unpickling to tensors/containers (a plain
    # state_dict needs nothing more); map_location='cpu' makes loading
    # independent of the device the checkpoint was saved from.
    state_dict = torch.load('cifar100_resnet18.pth', map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model = model.cuda()

    # closed-set test
    test_loader = get_loader(train=False)
    in_score, right_score, wrong_score = get_ood_scores(model, test_loader, True)
    num_right, num_wrong = len(right_score), len(wrong_score)
    print(f'Error Rate {100 * num_wrong / (num_right + num_wrong):.4f}')
    show_performance(right_score, wrong_score)

    # open-set test: each OOD set is one fifth the size of the test set
    ood_num_examples = len(test_loader.dataset) // 5
    ood_types = ['gaussian', 'rademacher', 'blob', 'svhn']
    for ood_type in ood_types:
        print(f'ood type: {ood_type}')
        ood_loader = build_ood_loader(ood_type, ood_num_examples, batch_size=128, worker=8)
        out_score = get_ood_scores(model, ood_loader)
        show_performance(in_score, out_score)
# Script entry point: run the OOD evaluation when executed directly.
if __name__ == "__main__":
    evaluate()
依赖
scikit-learn 1.5.2
scipy 1.14.1
torch 2.4.1
1.本站内容仅供参考,不作为任何法律依据。用户在使用本站内容时,应自行判断其真实性、准确性和完整性,并承担相应风险。
2.本站部分内容来源于互联网,仅用于交流学习研究知识,若侵犯了您的合法权益,请及时邮件或站内私信与本站联系,我们将尽快予以处理。
3.本文采用知识共享 署名4.0国际许可协议 [BY-NC-SA] 进行授权
4.根据《计算机软件保护条例》第十七条规定“为了学习和研究软件内含的设计思想和原理,通过安装、显示、传输或者存储软件等方式使用软件的,可以不经软件著作权人许可,不向其支付报酬。”您需知晓本站所有内容资源均来源于网络,仅供用户交流学习与研究使用,版权归属原版权方所有,版权争议与本站无关,用户本人下载后不能用作商业或非法用途,需在24个小时之内从您的电脑中彻底删除上述内容,否则后果均由用户承担责任;如果您访问和下载此文件,表示您同意只将此文件用于参考、学习而非其他用途,否则一切后果请您自行承担,如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。
5.本站是非经营性个人站点,所有软件信息均来自网络,所有资源仅供学习参考研究目的,并不贩卖软件,不存在任何商业目的及用途
暂无评论内容