1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
class DataLoader:
    """Skeleton of a task-batch loader built on ``torchnet``'s ``ListDataset``.

    ``get_iterator`` reads the instance attributes ``task_num``,
    ``batch_size``, ``num_workers`` and ``test``; the constructor shown here
    is an elided placeholder, so a real implementation must set them.
    """

    def __init__(self, *args, **kwargs):
        """Placeholder constructor; the original parameter list was elided.

        NOTE(review): a real implementation is expected to assign
        ``task_num``, ``batch_size``, ``num_workers`` and ``test`` here,
        because ``get_iterator`` reads them.
        """

    def get_task_batch(self, *args, **kwargs):
        """Placeholder loader for one task; the original body was elided.

        This method is handed to ``ListDataset`` as its ``load`` callable.
        Presumably it is invoked with one element of ``range(task_num)``
        (a task index) -- confirm against the torchnet version in use.
        """

    def get_iterator(self):
        """Build and return the parallel iterator over all tasks.

        Returns:
            The object returned by ``ListDataset.parallel(...)``, configured
            with this loader's batch size and worker count. Incomplete
            trailing batches are dropped, and shuffling is enabled only
            when ``self.test`` is falsy.
        """
        tnt_dataset = torchnet.dataset.ListDataset(
            elem_list=range(self.task_num),
            load=self.get_task_batch,
        )
        data_loader = tnt_dataset.parallel(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=True,
            # This hook lets every worker use a different random seed.
            worker_init_fn=self.worker_init_fn_seed,
            shuffle=not self.test,
        )
        return data_loader

    def worker_init_fn_seed(self, worker_id):
        """Seed NumPy's global RNG with a value derived from ``worker_id``.

        Args:
            worker_id: Integer index of the worker being initialised.
        """
        # Distinct worker ids map to distinct seeds (10, 15, 20, ...).
        seed = 10 + 5 * worker_id
        np.random.seed(seed)

    def __call__(self):
        """Return a fresh iterator; shorthand for ``get_iterator()``."""
        return self.get_iterator()
|