# 在 pytorch 中,我们可以通过设置 num_worker 的数量来提高数据加载的速度,从而减少将数据
从 CPU 加载到 GPU 的时间开销,以提高 GPU 的利用率,进而加快模型的训练速度。

# 在 linux 环境下设置 dataLoader 的 num_worker 数量大于 0 是可以正常运行的,但是在
Windows 环境下会报错,只能设置 num_worker=0 才可以正常运行,但是这样会使得模型
的训练速度极其漫长…

# 如果还是想在 Windows 环境下在 pytorch 中启用多线程加载数据,那么应该怎么办呢?
这个问题我找了很久很久很久… 才找到解决方案!!!

torchnet + dataloader


  1. 安装 torchnet
1
pip install torchnet
  1. 自定义 dataLoader
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class DataLoader:
def __init__(self, ...):
pass

def get_task_batch(self, ...):
pass

def get_iterator(self):
tnt_dataset = torchnet.dataset.ListDataset(
elem_list=range(self.task_num), load=self.get_task_batch)
data_loader = tnt_dataset.parallel(
batch_size=self.batch_size,
num_workers=self.num_workers,
drop_last=True,
# 此函数可以使每个worker使用不同的随机种子
worker_init_fn=self.worker_init_fn_seed,
shuffle=(False if self.test else True))
return data_loader

def worker_init_fn_seed(self, worker_id):
seed = 10 + 5 * worker_id
np.random.seed(seed)

def __call__(self):
return self.get_iterator()
  1. 在主程序中使用方法
1
2
3
4
5
# MiniImagenet为预处理数据的类
dataset = MiniImagenet(data_path, ...)
loader = DataLoader(dataset, num_workers=2, ...)
for step, batch in enumerate(loader()):
x, y = batch
更新于

请我喝[茶]~( ̄▽ ̄)~*

Revincent 微信

微信

Revincent 支付宝

支付宝