使用在 CIFAR-100 上训练的 ResNet-18 进行 OOD 检测

以 CIFAR-100 作为闭集(closed-set)数据集训练 ResNet-18 模型,然后在常见的分布外(out-of-distribution, OOD)数据集上进行 OOD 检测。检测依据是最大 Softmax 概率(MSP, Maximum Softmax Probability):MSP 越高,样本越倾向于被判为分布内样本。

OOD 数据集使用 gaussian、rademacher、blob、svhn 四种类型。其中 gaussian、rademacher、blob 是程序生成的随机噪声图像;svhn 则是额外引入的真实图像数据集(街景门牌号数字),并非噪声。

输出结果

Error Rate 46.3000
AUROC: 81.9790, AUPR: 85.7377, FPR95: 73.3909
ood type: gaussian
AUROC: 68.1596, AUPR: 92.9277, FPR95: 99.4000
ood type: rademacher
AUROC: 69.9099, AUPR: 93.1788, FPR95: 96.1500
ood type: blob
AUROC: 68.0615, AUPR: 92.7477, FPR95: 97.5500
ood type: svhn
AUROC: 66.9684, AUPR: 91.6508, FPR95: 89.0500

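可以看到,仅使用简单交叉熵损失训练、未做其他处理的 ResNet-18,在开集检测上的表现并不好。

需要说明两点:
- 输出中紧跟 Error Rate 的第一行 AUROC/AUPR/FPR95,衡量的是 MSP 在闭集测试集上区分"预测正确"与"预测错误"样本的能力。
- 其余各组以分布内样本为正类,而 OOD 样本数只有测试集的 1/5(2000 对 10000)。因此即使随机打分,AUPR 的基线也约为 83%,较高的 AUPR 并不代表检测效果好。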

在闭集数据集上训练一个 ResNet-18

# train.py
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets.cifar import CIFAR100
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score
import torch.nn.functional as F


def get_transform(train=True):
    """Build the torchvision preprocessing pipeline for 32x32 CIFAR images.

    Args:
        train: when truthy, prepend augmentation (random 32px crop with
            4px padding, then random horizontal flip). Otherwise images
            are only converted to tensors and normalized.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    # Per-channel normalization statistics.
    # NOTE(review): these look like the commonly quoted CIFAR-10 statistics,
    # while the loaders in this file use CIFAR-100 -- confirm this is intended.
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ] if train else []

    return transforms.Compose(augmentation + [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])


def get_loader(train=True):
    """Create a DataLoader over the CIFAR-100 train or test split.

    The dataset is read from ``~/data`` (``download`` is not requested). It uses
    ``get_transform(train)`` for preprocessing and yields batches of 128
    samples. Batches are shuffled only when ``train`` is true.
    """
    cifar = CIFAR100(root='~/data', train=train, transform=get_transform(train))
    return DataLoader(
        cifar,
        batch_size=128,
        shuffle=train,
        num_workers=8,
        pin_memory=True,
    )


@torch.no_grad()
def _evaluate_accuracy(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader``.

    The model is switched to eval mode so BatchNorm layers normalize with
    their running statistics and do not update them from test batches.
    Autograd is disabled for the whole pass.
    """
    model.eval()
    all_preds = []
    all_labels = []
    for inputs, labels in test_loader:
        outputs = model(inputs.cuda())
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.numpy())
    return accuracy_score(all_labels, all_preds)


def train_model():
    """Train a torchvision ResNet-18 on CIFAR-100 and save its weights.

    Training setup:
    - 100 epochs of SGD (lr 0.1, momentum 0.9, weight decay 1e-4).
    - The learning rate is decayed 10x at epochs 50 and 75.
    - Top-1 test accuracy is printed after every epoch.
    - The final ``state_dict`` is written to ``cifar100_resnet18.pth`` in the
      working directory.

    Requires a CUDA device.
    """
    loader = get_loader(train=True)
    test_loader = get_loader(train=False)

    # NOTE(review): torchvision's resnet18 keeps the ImageNet stem (7x7 stride-2
    # conv + max-pool), which downsamples 32x32 inputs aggressively; a CIFAR-style
    # 3x3 stem usually trains much better. Left as-is because evaluate() in
    # test.py loads this exact architecture.
    model = resnet18(num_classes=100)
    model = model.cuda()

    epochs = 100
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)

    for epoch in range(epochs):
        model.train()
        print('Training')
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)

            # Backpropagation and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f'Epoch[{epoch}] Iter: {i}/{len(loader)} Loss: {loss.item()}')
        scheduler.step()

        print('Testing')
        # Fresh prediction buffers per epoch (inside the helper), so the printed
        # accuracy reflects only the current epoch's model.
        accuracy = _evaluate_accuracy(model, test_loader)
        print(f'Epoch[{epoch}] acc@1 {accuracy:.4f}')

    torch.save(model.state_dict(), 'cifar100_resnet18.pth')


if __name__ == '__main__':
    train_model()

构建常用的开集数据集

# ood_data.py
import torch
import numpy as np
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.dataloader import DataLoader
from skimage.filters import gaussian
from torchvision.datasets import SVHN

from train import get_transform


def build_ood_loader(noise_type, ood_num_examples, batch_size, worker):
    dummy_targets = torch.ones(ood_num_examples)
    if noise_type in ['gaussian', 'rademacher', 'blob']:
        if noise_type == 'gaussian':
            ood_data = torch.from_numpy(np.float32(np.clip(
                np.random.normal(size=(ood_num_examples, 3, 32, 32), scale=0.5), -1, 1)))
        elif noise_type == 'rademacher':
            ood_data = torch.from_numpy(np.random.binomial(
                n=1, p=0.5, size=(ood_num_examples, 3, 32, 32)).astype(np.float32)) * 2 - 1
        else:
            ood_data = np.float32(np.random.binomial(n=1, p=0.7, size=(ood_num_examples, 32, 32, 3)))
            for i in range(ood_num_examples):
                ood_data[i] = gaussian(ood_data[i], sigma=1.5)
                ood_data[i][ood_data[i] 

使用常见的OOD检测评估指标

# ood_utils.py
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve


def _concat_or_empty(parts):
    """Concatenate 1-D score arrays, returning an empty float32 array if none."""
    return np.concatenate(parts) if parts else np.empty(0, dtype=np.float32)


@torch.no_grad()
def get_ood_scores(model, dataloader, closed_set=False, device='cuda'):
    """Compute the maximum softmax probability (MSP) for every sample.

    Args:
        model: classifier whose output is treated as logits of shape
            (batch, num_classes). This is assumed; confirm for custom models.
            The model is put into eval mode and left there.
        dataloader: iterable of ``(data, targets)`` batches.
        closed_set: when true, also split the scores by whether the argmax
            prediction matches the target.
        device: device each input batch is moved to. Defaults to ``'cuda'``,
            matching the previous hard-coded ``.cuda()``.

    Returns:
        When ``closed_set`` is false, a 1-D array of MSP scores (one per
        sample). When it is true, ``(scores, right_scores, wrong_scores)``.
        An empty dataloader yields empty arrays instead of raising.
    """
    model.eval()
    scores, right_scores, wrong_scores = [], [], []
    for data, targets in dataloader:
        output = model(data.to(device))
        smax = F.softmax(output, dim=1).cpu().numpy()
        msp = np.max(smax, axis=1)
        scores.append(msp)

        if closed_set:
            pred = np.argmax(smax, axis=1)
            right = pred == np.asarray(targets).reshape(-1)
            # Reuse the per-sample MSP instead of re-reducing masked rows.
            right_scores.append(msp[right])
            wrong_scores.append(msp[~right])

    if closed_set:
        return (_concat_or_empty(scores),
                _concat_or_empty(right_scores),
                _concat_or_empty(wrong_scores))
    return _concat_or_empty(scores)


def get_performance(pos, neg):
    """Compute AUROC, AUPR and FPR at 95% TPR for a score-based detector.

    ``pos`` holds scores of the positive class (label 1) and ``neg`` those of
    the negative class (label 0). Higher scores are taken to mean "more
    positive". In this project the positive class is the in-distribution /
    correctly classified samples.

    Returns:
        ``(auroc, aupr, fpr95)`` as fractions in [0, 1].
        - ``aupr`` is sklearn's average precision with ``pos`` as the
          positive class.
        - ``fpr95`` is the false-positive rate at the first ROC threshold
          where TPR >= 0.95.

    Raises:
        ValueError: raised by sklearn when ``pos`` or ``neg`` is empty, since
            only one class is then present.
    """
    pos = np.asarray(pos).reshape(-1)
    neg = np.asarray(neg).reshape(-1)
    scores = np.concatenate([pos, neg])
    labels = np.concatenate([np.ones(len(pos), dtype=int),
                             np.zeros(len(neg), dtype=int)])
    auroc = roc_auc_score(labels, scores)
    aupr = average_precision_score(labels, scores)

    # Keep every threshold. With the default drop_intermediate=True, sklearn
    # discards collinear ROC points, so the first *kept* point with TPR >= 0.95
    # can lie past the true crossing and overstate FPR95.
    fpr, tpr, _ = roc_curve(labels, scores, drop_intermediate=False)
    fpr95 = fpr[np.argmax(tpr >= 0.95)]
    return auroc, aupr, fpr95


def show_performance(pos, neg):
    """Print AUROC, AUPR and FPR95 for ``pos`` vs ``neg`` scores, in percent."""
    auroc, aupr, fpr95 = (100 * metric for metric in get_performance(pos, neg))
    print(f"AUROC: {auroc:.4f}, AUPR: {aupr:.4f}, FPR95: {fpr95:.4f}")

测试模型的OOD检测性能

# test.py
import torch
from torchvision.models import resnet18


from train import get_loader
from ood_utils import get_ood_scores, show_performance
from ood_data import build_ood_loader


def evaluate():
    """Report closed-set error/MSP metrics and open-set OOD metrics.

    Steps:
    1. Load ``cifar100_resnet18.pth`` into a torchvision ResNet-18 on CUDA.
    2. Score the CIFAR-100 test split with MSP.
    3. Print the error rate and how well MSP separates correct from incorrect
       predictions.
    4. For each OOD type, build a loader with one fifth as many samples as the
       test split, and print how well MSP separates in-distribution
       (positive class) from OOD samples.
    """
    model = resnet18(num_classes=100)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. map_location='cpu' makes loading
    # independent of the device the checkpoint was saved from; the model is
    # moved to CUDA right after.
    state_dict = torch.load('cifar100_resnet18.pth', map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model = model.cuda()

    # Closed-set test: MSP on the CIFAR-100 test split, split by correctness.
    test_loader = get_loader(train=False)
    in_score, right_score, wrong_score = get_ood_scores(model, test_loader, True)
    num_right, num_wrong = len(right_score), len(wrong_score)

    print(f'Error Rate {100 * num_wrong / (num_right + num_wrong):.4f}')
    show_performance(right_score, wrong_score)

    # Open-set test: in-distribution scores are passed as the positive class.
    ood_num_examples = len(test_loader.dataset) // 5
    ood_types = ['gaussian', 'rademacher', 'blob', 'svhn']
    for ood_type in ood_types:
        print(f'ood type: {ood_type}')
        ood_loader = build_ood_loader(ood_type, ood_num_examples, batch_size=128, worker=8)
        out_score = get_ood_scores(model, ood_loader)
        show_performance(in_score, out_score)


if __name__ == '__main__':
    evaluate()

依赖

scikit-learn       1.5.2
scipy              1.14.1
torch              2.4.1
(另需 torchvision、numpy 和 scikit-image:代码中分别导入了 torchvision、numpy 和 skimage.filters,原文未注明其版本)
玄机博客
© 版权声明
THE END
喜欢就支持一下吧
点赞8 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片快捷回复

    暂无评论内容