# Code adapted from FSCE.
"""Generate few-shot image-list splits from a VOC-style dataset.

For every seed in ``range(--seeds[0], --seeds[1])`` and every class in
``VOC_CLASSES``, annotation files containing that class are sampled into
nested 1/2/3/5/10-shot image lists (each larger split extends the smaller
one). The lists are written to ``datasets/vocsplit/seed<N>/``.
"""
import argparse
import copy
import os
import random
import xml.etree.ElementTree as ET

import numpy as np

try:
    from fvcore.common.file_io import PathManager
except ImportError:
    # fvcore is only used to open one local text file; fall back to the
    # builtin open() so the script also runs where fvcore is not installed.
    PathManager = None

# Class names; they must match the <name> tags used in the annotation files.
# Alternative class list for another dataset:
# VOC_CLASSES = ['air-hole', 'bite-edge', 'broken-arc', 'crack', 'hollow-bead', 'overlap','slag-inclusion', 'unfused']
VOC_CLASSES = ['crazing', 'inclusion', 'patches', 'pitted_surface', 'rolled-in_scale', 'scratches']

# Dataset directory containing ImageSets/Main/trainval.txt and Annotations/.
DATA_ROOT = os.path.join("./", "VOC2007")
# Folder name used when building the image paths written to the split files.
IMAGE_FOLDER = 'VOC2007'
# Nested shot counts; each split extends the previous one.
SHOTS = [1, 2, 3, 5, 10]
# Output directory pattern, formatted with the seed number.
SAVE_DIR = 'datasets/vocsplit/seed{}'


def parse_args():
    """Parse command-line options.

    Returns:
        argparse.Namespace whose ``seeds`` holds the [start, stop) seed range
        (default ``[1, 30]``); only its first two values are used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seeds", type=int, nargs="+", default=[1, 30],
                        help="Range of seeds")
    args = parser.parse_args()
    return args


def _open_text(path):
    """Open *path* for reading, through fvcore's PathManager when available."""
    if PathManager is not None:
        return PathManager.open(path)
    return open(path)


def _read_image_ids(data_file):
    """Return the whitespace-separated image ids in *data_file* as a list of str."""
    with _open_text(data_file) as f:
        # ndmin=1 keeps a one-line file as a 1-element array. Without it,
        # loadtxt returns a 0-d array whose .tolist() is a bare string, which
        # list.extend() would then split into single characters.
        return np.loadtxt(f, dtype=np.str_, ndmin=1).tolist()


def _index_by_class(fileids):
    """Map each class in VOC_CLASSES to the annotation files that contain it.

    Raises:
        KeyError: an annotation uses a class name missing from VOC_CLASSES.
    """
    data_per_cat = {c: [] for c in VOC_CLASSES}
    for fileid in fileids:
        anno_file = os.path.join(DATA_ROOT, "Annotations", fileid + ".xml")
        tree = ET.parse(anno_file)
        # set(): an image with several objects of one class is listed once.
        for cls in set(obj.find("name").text for obj in tree.findall("object")):
            data_per_cat[cls].append(anno_file)
    return data_per_cat


def _count_objects(tree, cls):
    """Return how many <object> entries of *tree* are labelled *cls*."""
    return sum(1 for obj in tree.findall("object") if obj.find("name").text == cls)


def _sample_class_splits(rng, anno_files, cls):
    """Draw nested few-shot image lists for one class.

    Args:
        rng: random.Random instance shared by all classes of one seed.
        anno_files: annotation files that contain at least one *cls* object.
        cls: class name.

    Returns:
        dict mapping each shot in SHOTS to the list of image paths selected
        up to and including that shot.

    Raises:
        ValueError: *anno_files* is smaller than a shot increment.
    """
    splits = {}
    c_data = []    # image paths selected so far; grows across shots
    chosen = set()  # annotation files already selected for this class
    prev_shot = 0
    for shot in SHOTS:
        diff_shot = shot - prev_shot
        prev_shot = shot
        if diff_shot > len(anno_files):
            raise ValueError(
                "class {!r} has only {} annotated image(s), cannot sample {}"
                .format(cls, len(anno_files), diff_shot))
        shots_c = rng.sample(anno_files, diff_shot)
        num_objs = 0
        for s in shots_c:
            # Skip images already selected for this class so that no split
            # lists the same image twice.
            if s in chosen:
                continue
            chosen.add(s)
            tree = ET.parse(s)
            name = 'datasets/{}/JPEGImages/{}'.format(
                IMAGE_FOLDER, tree.find("filename").text)
            print(name)
            c_data.append(name)
            num_objs += _count_objects(tree, cls)
            # Stop once this increment contributes enough instances of cls.
            if num_objs >= diff_shot:
                break
        splits[shot] = copy.deepcopy(c_data)
    return splits


def _write_splits(save_path, result):
    """Write one ``box_<shot>shot_<cls>_train.txt`` file per class and shot."""
    os.makedirs(save_path, exist_ok=True)
    for c, per_shot in result.items():
        for shot, names in per_shot.items():
            filename = 'box_{}shot_{}_train.txt'.format(shot, c)
            with open(os.path.join(save_path, filename), 'w') as fp:
                fp.write('\n'.join(names) + '\n')


def generate_seeds(args):
    """Write few-shot split files for every seed in ``range(args.seeds[0], args.seeds[1])``.

    Args:
        args: object with a ``seeds`` sequence; only the first two entries are used.

    Raises:
        KeyError: an annotation uses a class name missing from VOC_CLASSES.
        ValueError: a class has too few annotated images for a shot increment.
    """
    data_file = os.path.join(DATA_ROOT, "ImageSets", "Main", "trainval.txt")
    data_per_cat = _index_by_class(_read_image_ids(data_file))
    for i in range(args.seeds[0], args.seeds[1]):
        # A dedicated generator per seed yields the same draws as
        # random.seed(i), without touching the global random state.
        rng = random.Random(i)
        result = {c: _sample_class_splits(rng, files, c)
                  for c, files in data_per_cat.items()}
        _write_splits(SAVE_DIR.format(i), result)


if __name__ == '__main__':
    args = parse_args()
    generate_seeds(args)