Skip to content

Commit

Permalink
fix dataloader exit hang when join re-enter (#32835)
Browse files Browse the repository at this point in the history
* fix dataloader exit hang when join re-enter. test=develop

* double check _shutdown. test=develop
  • Loading branch information
heavengate authored May 12, 2021
1 parent 0251320 commit 4ccd9a0
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions python/paddle/fluid/dataloader/dataloader_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,14 @@ def __init__(self, loader):

# if the user exits the Python program while the dataloader is still
# iterating, resources may not be released safely, so we
# add __del__ function to to CleanupFuncRegistrar
# to make sure __del__ is always called when program
# add the _shutdown_on_exit function to CleanupFuncRegistrar
# to make sure _try_shutdown_all is always called when the program
# exits, so that resources are released safely
CleanupFuncRegistrar.register(self.__del__)
# worker join may hang in the _try_shutdown_all call made from atexit,
# because the main process is in the atexit state on some OSes, so we pass
# timeout=1 for the shutdown function call made from atexit; for the shutdown
# function call made from __del__, we keep it as it is
CleanupFuncRegistrar.register(self._shutdown_on_exit)

def _init_workers(self):
# multiprocess worker and indice queue list initial as empty
Expand Down Expand Up @@ -363,7 +367,7 @@ def _shutdown_worker(self, worker_id):
self._indices_queues[worker_id].put(None)
self._worker_status[worker_id] = False

def _try_shutdown_all(self):
def _try_shutdown_all(self, timeout=None):
if not self._shutdown:
try:
self._exit_thread_expectedly()
Expand All @@ -376,11 +380,12 @@ def _try_shutdown_all(self):
for i in range(self._num_workers):
self._shutdown_worker(i)

for w in self._workers:
w.join()
for q in self._indices_queues:
q.cancel_join_thread()
q.close()
if not self._shutdown:
for w in self._workers:
w.join(timeout)
for q in self._indices_queues:
q.cancel_join_thread()
q.close()
finally:
core._erase_process_pids(id(self))
self._shutdown = True
Expand Down Expand Up @@ -560,6 +565,9 @@ def _try_put_indices(self):
def __del__(self):
self._try_shutdown_all()

def _shutdown_on_exit(self):
self._try_shutdown_all(1)

def __next__(self):
try:
# _batches_outstanding here record the total batch data number
Expand Down

0 comments on commit 4ccd9a0

Please sign in to comment.