First, some background: I am currently hacking on the code of this paper:
https://github.com/QipengGuo/GraphWriter-DGL
Each baseline run takes about 5 hours, and my own model needs roughly twice that, which makes tuning hyperparameters and spotting problems painful, so I started looking into multi-GPU acceleration.
torch.nn.DataParallel ==> DP for short
torch.nn.parallel.DistributedDataParallel ==> DDP for short
I first tried DP for the speed-up, but because of how DGL works (all graphs of a batch are packed into one batched graph, which cannot be split), while torch.nn.DataParallel works precisely by splitting a batch into smaller chunks, that was a dead end; on top of that DP scales worse than DDP anyway, so I started rewriting the code for DDP.
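For orientation, the two wrappers are used roughly like this; a minimal sketch with a stand-in Linear model (not the GraphWriter model), assuming two visible GPUs and, for the DDP branch, the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE/LOCAL_RANK environment variables set by a launcher:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

model = torch.nn.Linear(16, 4)       # stand-in for the real GraphWriter model
use_ddp = 'RANK' in os.environ       # e.g. set by torch.distributed.launch

if not use_ddp:
    # DP: one process drives all GPUs; every forward() scatters the input
    # batch along dim 0 -- exactly what a dgl-batched graph cannot survive.
    model = torch.nn.DataParallel(model.cuda(), device_ids=[0, 1])
else:
    # DDP: one process per GPU; each process feeds its own per-GPU batch,
    # so the batched graph never needs to be split.
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(backend='nccl', init_method='env://')
    model = DDP(model.cuda(local_rank), device_ids=[local_rank])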
Also, the author implements the sampler by subclassing torch.utils.data.Sampler, because the text lengths in the AGENDA dataset are severely imbalanced.
To make training finish faster, texts of similar length are packed into the same batch (friendly reminder: torchtext has a related class, BucketIterator [1]). The sampler looks roughly like this:
import torch

class BucketSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size=32):
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        # basesampler is the author's helper that pre-buckets the indices by length
        # (a DDP-adapted version of it appears later in this post)
        idxs, lens, batch, middle_batch_size, long_batch_size = basesampler(self.data_source, self.batch_size)
        for idx in idxs:
            batch.append(idx)
            mlen = max([0] + [lens[x] for x in batch])
            #if (mlen<100 and len(batch) == 32) or (mlen>100 and mlen<220 and len(batch) >= 24) or (mlen>220 and len(batch)>=8) or len(batch)==32:
            if (mlen<100 and len(batch) == self.batch_size) or (mlen>100 and mlen<220 and len(batch) >= middle_batch_size) or (mlen>220 and len(batch)>=long_batch_size) or len(batch)==self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def __len__(self):
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size
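Since __iter__ here yields whole lists of indices rather than single indices, this sampler goes into the DataLoader through the batch_sampler argument. A minimal usage sketch (dataset and collate_fn are placeholders, not the actual GraphWriter objects):

from torch.utils.data import DataLoader

# BucketSampler yields a list of indices per step, so it must be passed as
# batch_sampler; batch_size/shuffle/drop_last are then left at their defaults.
loader = DataLoader(dataset,
                    batch_sampler=BucketSampler(dataset, batch_size=32),
                    collate_fn=collate_fn,   # e.g. a collate that calls dgl.batch
                    num_workers=2)

for batch in loader:
    pass  # training step goes here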
That's the background.
Step one of writing bugs: subclassing DistributedSampler and getting it all wrong
My first, naive move was to ctrl-C/ctrl-V the author's sampler code and change only this line:
class DDPBaseBucketSampler(torch.utils.data.distributed.DistributedSampler):
and I immediately ran into several problems:
- the dataloader would not dispatch any batches;
- the dataloader handed every process the complete dataset, when by rights each process should get 1/n of it, n being the number of GPUs you configured;
So I went to read the source code [2], and very quickly found this:
def __iter__(self) -> Iterator[T_co]:
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore
    else:
        indices = list(range(len(self.dataset)))  # type: ignore

    if not self.drop_last:
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    else:
        # remove tail of data to make it evenly divisible.
        indices = indices[:self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank:self.total_size:self.num_replicas]  # this slice guarantees that each process gets a different subset
    assert len(indices) == self.num_samples

    return iter(indices)
So what is the key issue here? First, inside torch.utils.data.distributed.DistributedSampler the dataset attribute is called self.dataset, not data_source; second, unlike torch.utils.data.Sampler, which requires you to override __iter__ yourself:
def __iter__(self) -> Iterator[T_co]:
    raise NotImplementedError
the parent class DistributedSampler already provides part of the implementation, and if you do not account for that, you naturally end up with every process receiving all of the data.
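(As an aside: one way to avoid copy-pasting the parent's logic at all is to let DistributedSampler do the shuffling, padding and rank-subsampling, and only add the bucketing on top by reusing super().__iter__(). A minimal sketch of that idea, assuming every dataset item supports len(); it uses a simpler sort-and-chunk bucketing than the author's three-bucket scheme, and is not what I did below:)

import torch

class BucketedDistributedSampler(torch.utils.data.distributed.DistributedSampler):
    '''Hypothetical variant that reuses the parent's per-rank index selection.'''
    def __init__(self, dataset, batch_size=32, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        # super().__iter__() already yields only this rank's shuffled, padded indices
        indices = list(super().__iter__())
        indices.sort(key=lambda i: len(self.dataset[i]))  # group similar lengths together
        for start in range(0, len(indices), self.batch_size):
            yield indices[start:start + self.batch_size]

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size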
So I rewrote my DDPBaseBucketSampler class:
import random
import torch

def basesampler(lens, indices, batch_size):
    # the magic numbers come from the author's code
    t1 = []
    t2 = []
    t3 = []
    for i, l in enumerate(lens):
        if (l < 100):
            t1.append(indices[i])
        elif (l > 100 and l < 220):
            t2.append(indices[i])
        else:
            t3.append(indices[i])
    datas = [t1, t2, t3]
    random.shuffle(datas)
    idxs = sum(datas, [])
    batch = []
    # to avoid blowing up GPU memory, cap the batch size for longer sequences
    middle_batch_size = min(int(batch_size * 0.75), 32)
    long_batch_size = min(int(batch_size * 0.5), 24)
    return idxs, batch, middle_batch_size, long_batch_size

class DDPBaseBucketSampler(torch.utils.data.distributed.DistributedSampler):
    '''
    Be careful to keep this in sync with the single-GPU sampler class.
    '''
    def __init__(self, dataset, num_replicas, rank, shuffle=True, batch_size=32):
        super(DDPBaseBucketSampler, self).__init__(dataset, num_replicas, rank, shuffle)
        self.batch_size = batch_size

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        #print('here is pytorch code and you can delete it in the /home/lzk/anaconda3/lib/python3.7/site-packages/torch/utils/data')
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        # I also need the length of every example (different on every rank)
        lens = torch.Tensor([len(x) for x in self.dataset])
        idxs, batch, middle_batch_size, long_batch_size = basesampler(lens[indices], indices, self.batch_size)
        for idx in idxs:
            batch.append(idx)
            mlen = max([0] + [lens[x] for x in batch])
            #if (mlen<100 and len(batch) == 32) or (mlen>100 and mlen<220 and len(batch) >= 24) or (mlen>220 and len(batch)>=8) or len(batch)==32:
            if (mlen<100 and len(batch) == self.batch_size) or (mlen>100 and mlen<220 and len(batch) >= middle_batch_size) or (mlen>220 and len(batch)>=long_batch_size) or len(batch)==self.batch_size:
                yield batch
                batch = []
        # print('should appear twice if there are 2 processes')
        if len(batch) > 0:
            yield batch

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size
After that, every process finally trained on its own share of the data (1/n, where n = number of processes = number of GPUs on a single machine).
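For completeness, in each worker process the sampler is hooked up roughly like this (a sketch; dataset, collate_fn, args, rank and the loop body are placeholders for my actual training code):

from torch.utils.data import DataLoader

sampler = DDPBaseBucketSampler(dataset,
                               num_replicas=args.world_size,
                               rank=rank,
                               shuffle=True,
                               batch_size=32)
loader = DataLoader(dataset,
                    batch_sampler=sampler,
                    collate_fn=collate_fn,     # the dgl.batch-based collate
                    num_workers=args.num_workers)

for epoch in range(args.epochs):
    sampler.set_epoch(epoch)   # keeps the epoch-seeded shuffle consistent across ranks
    for batch in loader:
        pass  # forward / backward / optimizer step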
Then the next problem arrived: after training finished normally, the master process could not get out of mp.spawn().
Step two of writing bugs: the master process will not exit
num_workers + DDP + PyTorch: no clean shutdown. Concretely, the function passed to mp.spawn ran to completion, but the master process kept holding its GPU and never exited. At first I suspected the way the sampler distributes batches: since every process gets different data, and I force texts of similar length into the same batch, the master might end up with, say, a hundred iterations while a slave only has eighty. So I immediately checked, and very quickly:
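(The check was essentially of this shape; this is a reconstruction of the idea rather than my original print statements:)

import torch.distributed as dist

# each rank counts how many batches its sampler will yield this epoch
n_batches = sum(1 for _ in iter(sampler))
print('rank {}: {} batches this epoch'.format(dist.get_rank(), n_batches))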
I found only tiny differences, and both processes got past these prints in any case, so a mismatch in the number of batches should not be the cause. (Incidentally, the sampler packs the batches very early on.)
Adding an explicit call to destroy the process group did not help either:
if args.is_ddp:
    dist.destroy_process_group()
    print('rank destroy_process_group: ', rank)
In the end I could only kill it by force:
File "train.py", line 322, in
main(args.gpu, args)
File "/home/lzk/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
while not spawn_context.join():
File "/home/lzk/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 77, in join
timeout=timeout,
File "/home/lzk/anaconda3/lib/python3.7/multiprocessing/connection.py", line 920, in wait
ready = selector.select(timeout)
File "/home/lzk/anaconda3/lib/python3.7/selectors.py", line 415, in select
fd_event_list = self._selector.poll(timeout)
TypeError: keyboard_interrupt_handler() takes 1 positional argument but 2 were given
^CError in atexit._run_exitfuncs:
Traceback (most recent call last):
File "/home/lzk/anaconda3/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
pid, sts = os.waitpid(self.pid, flag)
TypeError: keyboard_interrupt_handler() takes 1 positional argument but 2 were given
Code references: 基于Python初探Linux下的僵尸進(jìn)程和孤兒進(jìn)程(三) [3], Multiprocessing in python blocked [4].
Clearly the PyTorch master process had deadlocked and turned into a zombie process.
Digging further, I found that the program exits normally once I set the dataloader's num_workers to 0. And after liberal use of commenting-out, I found that even when I replaced the entire body of for _i, batch in enumerate(dataloader) with pass, the master still failed to exit. So the problem was pinned on the dataloader. Reference: nero: PyTorch DataLoader初探 [5].
Another thought was that mp.spawn itself was the problem. A process started this way only executes the code related to the target argument (or the run() method). It is the only start method available on Windows, and indeed the default one there, but compared with the other two start methods it is also the slowest. Reference: Python設(shè)置進(jìn)程啟動(dòng)的3種方式 [6].
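For reference, the mp.spawn path I had been using looks roughly like this (a sketch; main, parse_args and world_size are simplified stand-ins for my script):

import torch.multiprocessing as mp

def main(rank, args):
    # rank is injected by mp.spawn as the first positional argument
    pass

if __name__ == '__main__':
    args = parse_args()      # hypothetical argument parsing
    mp.spawn(main, args=(args,), nprocs=args.world_size, join=True)
    # join=True makes the parent block here until every child has exited --
    # exactly the place where my master process was getting stuck.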
So let's try bypassing mp.spawn altogether and launching DDP from a shell command instead, and see whether the error goes away:
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.1.201" --master_port=23456 我的文件.py
Parameter notes (see the sketch after this list for the receiving end inside the training script):
- nnodes: this is single-machine multi-GPU, so it is 1, and node_rank can then only be 0;
- local_rank: at run time the launcher injects a local_rank argument into each process's args to identify the process index.
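On the training-script side, the launcher only sets the environment variables and (without --use_env) appends --local_rank, so the script has to pick it up itself. A minimal sketch of the receiving end (only --local_rank is dictated by the launcher; the rest is generic boilerplate):

import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
# MASTER_ADDR / MASTER_PORT / WORLD_SIZE / RANK all come from the launcher's env
dist.init_process_group(backend='nccl', init_method='env://')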
After a round of changes things improved; the most obvious difference is that training got much faster!! The orphaned-parent problem was gone, but the run still would not terminate after all the work had finished.
At that point the code had reached the end of the main function: both processes (master and slave) made it past the barrier, the slave exited cleanly, but the master was nowhere to be seen.
Hitting ctrl+c at that point produced a traceback, and following it to torch/distributed/launch.py, line 239, I found this code:
def main():
    args = parse_args()

    # world size in terms of number of processes
    dist_world_size = args.nproc_per_node * args.nnodes

    # set PyTorch distributed related environmental variables
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)

    processes = []

    if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
        current_env["OMP_NUM_THREADS"] = str(1)
        print("*****************************************\n"
              "Setting OMP_NUM_THREADS environment variable for each process "
              "to be {} in default, to avoid your system being overloaded, "
              "please further tune the variable for optimal performance in "
              "your application as needed. \n"
              "*****************************************".format(current_env["OMP_NUM_THREADS"]))

    for local_rank in range(0, args.nproc_per_node):
        # each process's rank
        dist_rank = args.nproc_per_node * args.node_rank + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)

        # spawn the processes
        if args.use_env:
            cmd = [sys.executable, "-u",
                   args.training_script] + args.training_script_args
        else:
            cmd = [sys.executable,
                   "-u",
                   args.training_script,
                   "--local_rank={}".format(local_rank)] + args.training_script_args

        process = subprocess.Popen(cmd, env=current_env)
        processes.append(process)

    for process in processes:
        process.wait()  # block until this worker finishes
        if process.returncode != 0:
            raise subprocess.CalledProcessError(returncode=process.returncode,
                                                cmd=cmd)
Damn it, what does the master process have to do with the dataloader anyway...
This problem was finally solved yesterday (2020/12/22), and the way it happened is almost funny: on one side I had the GraphWriter DDP implementation that would not exit, on the other an MNIST DDP minimal example that would, so I started the great deletion game. I swapped the dataset, swapped the model, let the dataloader spin empty, and found nothing; then, closing in step by step, the moment I commented out this one line of my code, everything finally terminated normally:
def main(args):
    ############################################################
    print('local_rank : ', args.local_rank)
    if args.is_ddp:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.world_size,
            rank=args.local_rank
        )
    ############################################################
    # torch.multiprocessing.set_sharing_strategy('file_system')   # <-- the root of all evil
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"].split(',')[args.local_rank]
    args.device = torch.device(0)
    ...
Why had I added that line in the first place? Back when I was tuning num_workers (young and naive, I thought bigger was better, so I set num_workers = cpu_count()), the system complained that I had exceeded the maximum number of open files. In torch.multiprocessing, the default sharing strategy (see the PyTorch docs [7]) is file_descriptor, which uses file descriptors as shared-memory handles: whenever a storage is moved to shared memory, a file descriptor obtained from shm_open is cached. The docs also say:
If your system has a limit on the number of open file descriptors that cannot be raised, you should use the file_system strategy.
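(As an aside: when the limit can in fact be raised, an alternative to switching strategies is to bump the soft open-file limit from inside the program; a sketch, assuming the hard limit is large enough:)

import resource

# raise the soft RLIMIT_NOFILE up to the hard limit before creating DataLoader
# workers, so the default file_descriptor strategy keeps working
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))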
So I switched to torch.multiprocessing.set_sharing_strategy('file_system'), while ignoring the shared-memory-leak warning the docs attach to that strategy; the docs also mention that a torch_shm_manager daemon is responsible for cleaning the shared memory up once all processes have disconnected.
It may even be that the "master process" I kept fighting was really this torch_shm_manager, since destroy_process_group could never bring down process 0.
That bug is finally closed. What a relief. Now eagerly awaiting the next one.
Original title: Pytorch翻車記錄:?jiǎn)慰ǜ亩嗫ú瓤佑洠?Source: the WeChat public account 深度學(xué)習(xí)自然語(yǔ)言處理.