Constructing a dataset with open-set and closed-set noise from CIFAR

Source: https://www.cnblogs.com/zh-jp/p/18274899

Preface

One task in learning with noisy labels is the following: the training set contains both open-set and closed-set noise, and the goal is to classify the closed-set samples on the test set.

Open-set samples injected into the training set are uniformly assigned closed-set labels and act as open-set noise. Closed-set noise follows the usual noisy-label setup and comes in two forms: symmetric noise, where a closed-set sample's label is replaced with a uniformly random class, and asymmetric noise, where the label is replaced with one specific class.
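As an illustration (not taken from any paper's code), the two closed-set noise types can be sketched on a toy label array; the circular mapping used for the asymmetric case is just one common choice:

```python
import numpy as np

rng = np.random.default_rng(0)
nb_classes = 10
targets = rng.integers(0, nb_classes, size=1000)

# Select a fraction r of samples to corrupt.
r = 0.2
flip = rng.random(len(targets)) < r

# Symmetric noise: replace selected labels with a uniformly random class.
noisy_sym = targets.copy()
noisy_sym[flip] = rng.integers(0, nb_classes, size=flip.sum())

# Asymmetric noise: flip each selected label to a fixed target class,
# here a circular mapping c -> (c + 1) % nb_classes.
noisy_asym = targets.copy()
noisy_asym[flip] = (targets[flip] + 1) % nb_classes
```

Unselected samples keep their original labels in both variants.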

In paper experiments, the CIFAR datasets are commonly used to simulate this task. Two approaches are known so far:

• The first is based on cifar100: a subset of the 100 classes, typically 20, is taken as open-set samples, whose labels are replaced with those of the remaining 80 classes to form open-set noise; for those 80 classes, part of the samples are then set up as symmetric/asymmetric closed-set noise. The code released with the CVPR 2022 paper PNP: Robust Learning From Noisy Labels by Probabilistic Noise Prediction uses this method. However, asymmetric noise is hard to realize on cifar10 this way: the class ordering of cifar10 is not as regular as cifar100's, which makes the closed-set noise awkward to set up.

• The second works for both cifar10 and cifar100: the number of samples in the original dataset is kept unchanged, and an auxiliary dataset (typically imagenet32 or places365) replaces part of the samples as open-set noise; closed-set noise is then applied to the remaining, non-open-set samples. The code released with the ECCV 2022 paper Embedding contrastive unsupervised features to cluster in- and out-of-distribution noise in corrupted image datasets uses this approach.
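The class splitting in the first approach can be sketched on synthetic labels. This is a minimal sketch, assuming the first 20 classes are treated as open-set and the remaining 80 are remapped down to 0..79; the actual split in the PNP code may differ:

```python
import numpy as np

rng = np.random.default_rng(0)
labels = rng.integers(0, 100, size=5000)  # stand-in for cifar100 labels

ood_classes = 20                 # classes 0..19 become open-set samples
is_ood = labels < ood_classes
noisy = labels.copy()
# Open-set samples receive a uniformly random in-distribution label.
noisy[is_ood] = rng.integers(ood_classes, 100, size=is_ood.sum())
# Remap the surviving classes 20..99 down to 0..79 for training.
noisy -= ood_classes
```

After the remap, every label lies in the 80-class in-distribution range, and non-open-set samples keep their original class (shifted by 20).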

places365 can be downloaded with torchvision.datasets.Places365; since its training set is large, its validation set is usually used as the auxiliary dataset.

imagenet32 is the 32x32 version of ImageNet, also with 1k classes, but the ordering of the class meanings differs from ImageNet; the concrete class meanings of imagenet32 can be found here. The imagenet32 download links are provided by the corresponding paper, A downsampled variant of imagenet as an alternative to the cifar datasets.

Trying out the construction

Using the second method, we construct a training set containing both open-set and closed-set noise, with open-set noise rate \(r_{ood}=0.2\) and closed-set noise rate \(r_{id}=0.2\); imagenet32 serves as the auxiliary dataset, and the training set is built on top of CIFAR.
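Before writing any dataset code, it helps to check the expected sample counts. Note that the closed-set rate must be rescaled by \(1/(1-r_{ood})\) when it is applied only to the non-open-set samples, so that \(r_{id}\) refers to a fraction of the whole training set:

```python
n = 50_000                  # CIFAR training set size
r_ood, r_id = 0.2, 0.2

n_ood = n * r_ood                            # expected open-set samples
n_id = (n - n_ood) * (r_id / (1 - r_ood))    # closed-set noise among the rest
n_clean = n - n_ood - n_id
print(n_ood, n_id, n_clean)  # 10000.0 10000.0 30000.0
```

So in expectation 20% of samples are open-set noise, 20% are closed-set noise, and 60% keep clean labels.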

Designing the imagenet32 dataset

    import os
    import pickle
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset
    
    _train_list = ['train_data_batch_1',
                   'train_data_batch_2',
                   'train_data_batch_3',
                   'train_data_batch_4',
                   'train_data_batch_5',
                   'train_data_batch_6',
                   'train_data_batch_7',
                   'train_data_batch_8',
                   'train_data_batch_9',
                   'train_data_batch_10']
    _val_list = ['val_data']
    
    
    def get_dataset(transform_train, transform_test):
        # prepare datasets
    
        # Train set
        train = Imagenet32(train=True, transform=transform_train)  # Load all 1000 classes in memory
    
        # Test set
        test = Imagenet32(train=False, transform=transform_test)  # Load the validation split in memory
    
        return train, test
    
    
    class Imagenet32(Dataset):
        def __init__(self, root='~/data/imagenet32', train=True, transform=None):
            # expanduser is a no-op for paths that do not start with '~'
            root = os.path.expanduser(root)
            self.transform = transform
            size = 32
            # Now load the pickled numpy arrays
    
            if train:
                data, labels = [], []
    
                for f in _train_list:
                    file = os.path.join(root, f)
    
                    with open(file, 'rb') as fo:
                        entry = pickle.load(fo, encoding='latin1')
                        data.append(entry['data'])
                        labels += entry['labels']
                data = np.concatenate(data)
    
            else:
                f = _val_list[0]
                file = os.path.join(root, f)
                with open(file, 'rb') as fo:
                    entry = pickle.load(fo, encoding='latin1')
                    data = entry['data']
                    labels = entry['labels']
    
            data = data.reshape((-1, 3, size, size))
            self.data = data.transpose((0, 2, 3, 1))  # Convert to HWC
            labels = np.array(labels) - 1  # pickled labels are 1-based; shift to 0..999
            self.labels = labels.tolist()
    
        def __getitem__(self, index):
    
            img, target = self.data[index], self.labels[index]
            img = Image.fromarray(img)
    
            if self.transform is not None:
                img = self.transform(img)
    
            return img, target, index
    
        def __len__(self):
            return len(self.data)
    

Directory structure:

    imagenet32
    ├─ train_data_batch_1
    ├─ train_data_batch_10
    ├─ train_data_batch_2
    ├─ train_data_batch_3
    ├─ train_data_batch_4
    ├─ train_data_batch_5
    ├─ train_data_batch_6
    ├─ train_data_batch_7
    ├─ train_data_batch_8
    ├─ train_data_batch_9
    └─ val_data
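As a sanity check on the layout assumed in `Imagenet32.__init__`: each pickled row is a flat 3x32x32 vector in channel-first order (all R values, then G, then B, like the CIFAR batches), which the reshape/transpose converts to HWC. On synthetic data:

```python
import numpy as np

# Two fake flat images, channel-first: all R values, then G, then B.
flat = np.arange(2 * 3 * 32 * 32).reshape(2, 3072)
images = flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)  # -> NHWC

print(images.shape)        # (2, 32, 32, 3)
print(images[0, 0, 0, 1])  # 1024: first green value of image 0
```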
    

Designing the CIFAR dataset

    import torchvision
    import numpy as np
    from dataset.imagenet32 import Imagenet32
    
    
    class CIFAR10(torchvision.datasets.CIFAR10):
        nb_classes = 10
    
        def __init__(self, root='~/data', train=True, transform=None,
                     r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet'):

            super().__init__(root, train=train, transform=transform)
            if train is False:
                return
            np.random.seed(seed)
            # Keep the index lists around even when a noise rate is zero.
            self.ids_ood, self.ids_id = [], []
            if r_ood > 0.:
                # Mark each sample as open-set noise with probability r_ood.
                ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]
                if corruption == 'imagenet':
                    imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)
                    img_ood = imagenet32.data[np.random.permutation(len(imagenet32))[:len(ids_ood)]]
                else:
                    raise ValueError(f'Unknown corruption: {corruption}')
                self.ids_ood = ids_ood
                # Swap in out-of-distribution images; the original labels are kept,
                # so the OOD samples inherit (roughly uniform) closed-set labels.
                self.data[ids_ood] = img_ood

            if r_id > 0.:
                ids_ood_set = set(self.ids_ood)  # O(1) membership tests
                ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood_set]
                # Rescale so that r_id is a fraction of the whole training set.
                ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]
                for i in ids_id:
                    # Symmetric noise: a uniformly random label (may equal the original).
                    self.targets[i] = int(np.random.random() * self.nb_classes)
                self.ids_id = ids_id
    
    
    class CIFAR100(CIFAR10):
        base_folder = "cifar-100-python"
        url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
        filename = "cifar-100-python.tar.gz"
        tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
        train_list = [
            ["train", "16019d7e3df5f24257cddd939b257f8d"],
        ]
    
        test_list = [
            ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
        ]
        meta = {
            "filename": "meta",
            "key": "fine_label_names",
            "md5": "7973b15100ade9c7d40fb424638fde48",
        }
    
        nb_classes = 100
    
        def __init__(self, root='~/data', train=True, transform=None,
                     r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet'):
            super().__init__(root=root, train=train, transform=transform, r_ood=r_ood, r_id=r_id, seed=seed,
                             corruption=corruption)
    
    

Inspecting the statistics

    import pandas as pd
    import altair as alt
    from dataset.cifar import CIFAR10, CIFAR100
    
    # Initialize CIFAR10 dataset
    cifar10 = CIFAR10()
    cifar100 = CIFAR100()
    
    
    def statistics_samples(dataset):
        ids_ood = set(dataset.ids_ood)  # sets give O(1) membership tests below
        ids_id = set(dataset.ids_id)
    
        # Collect statistics
        statistics = []
        for i in range(dataset.nb_classes):
            statistics.append({
                'class': i,
                'id': 0,
                'ood': 0,
                'clear': 0
            })
    
        for i, t in enumerate(dataset.targets):
            if i in ids_ood:
                statistics[t]['ood'] += 1
            elif i in ids_id:
                statistics[t]['id'] += 1
            else:
                statistics[t]['clear'] += 1
    
        df = pd.DataFrame(statistics)
    
        # Melt the DataFrame for Altair
        df_melt = df.melt(id_vars='class', var_name='type', value_name='count')
    
        # Create the bar chart
        chart = alt.Chart(df_melt).mark_bar().encode(
            x=alt.X('class:O', title='Classes'),
            y=alt.Y('count:Q', title='Sample Count'),
            color='type:N'
        )
        return chart
    
    
    chart1 = statistics_samples(cifar10)
    chart2 = statistics_samples(cifar100)
    chart1 = chart1.properties(
        title='cifar10',
        width=100,  # Adjust width to fit both charts side by side
        height=400
    )
    chart2 = chart2.properties(
        title='cifar100',
        width=800,
        height=400
    )
    combined_chart = alt.hconcat(chart1, chart2).configure_axis(
        labelFontSize=12,
        titleFontSize=14
    ).configure_legend(
        titleFontSize=14,
        labelFontSize=12
    )
    combined_chart
    

Environment

    # Name                    Version                   Build  Channel
    altair                    5.3.0                    pypi_0    pypi
    pytorch                   2.3.1           py3.12_cuda12.1_cudnn8_0    pytorch
    pandas                    2.2.2                    pypi_0    pypi