Dynamically reduce DASH worker pool on connection errors

This commit is contained in:
rlaphoenix 2023-03-12 00:06:26 +00:00
parent bf3219b4e8
commit 1ef7419966
1 changed files with 18 additions and 3 deletions

View File

@ -13,6 +13,7 @@ from copy import copy
from functools import partial
from hashlib import md5
from pathlib import Path
from queue import Queue
from threading import Event
from typing import Any, Callable, Optional, Union
from urllib.parse import urljoin, urlparse
@ -454,7 +455,7 @@ class DASH:
else:
drm = None
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
def download_segment(filename: str, segment: tuple[str, Optional[str]], workers: Queue) -> int:
if stop_event.is_set():
# the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
@ -465,6 +466,7 @@ class DASH:
attempts = 1
while True:
workers.put(None)
try:
if segment_range:
# aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
@ -486,12 +488,23 @@ class DASH:
silent=attempts != 5,
segmented=True
))
break
except ConnectionRefusedError:
# server likely blocking more than a few connections
# reduce max workers but let this thread continue
with workers.mutex:
_ = workers._get() # take back this threads queue item
if workers.maxsize > 1:
workers.maxsize -= 1
print(f"REDUCED TRAIN SIZE TO {workers.maxsize}")
except Exception as ee:
_ = workers.get()
if stop_event.is_set() or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
else:
_ = workers.get()
break
data_size = segment_save_path.stat().st_size
@ -523,6 +536,7 @@ class DASH:
progress(total=len(segments))
pool_workers = Queue(maxsize=16)
finished_threads = 0
download_sizes = []
last_speed_refresh = time.time()
@ -532,7 +546,8 @@ class DASH:
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
segment=segment,
workers=pool_workers
)
for i, segment in enumerate(segments)
)):