diff --git a/pywidevine/serve.py b/pywidevine/serve.py index de7ee39..3c208a3 100644 --- a/pywidevine/serve.py +++ b/pywidevine/serve.py @@ -23,7 +23,7 @@ routes = web.RouteTableDef() async def _startup(app: web.Application): - app["sessions"]: dict[bytes, Cdm] = {} + app["cdms"]: dict[str, Cdm] = {} app["config"]["devices"] = { path.stem: path for x in app["config"]["devices"] @@ -32,8 +32,8 @@ async def _startup(app: web.Application): async def _cleanup(app: web.Application): - app["sessions"].clear() - del app["sessions"] + app["cdms"].clear() + del app["cdms"] app["config"].clear() del app["config"] @@ -48,7 +48,8 @@ async def ping(_) -> web.Response: @routes.get("/open/{device}") async def open(request: web.Request) -> web.Response: - user = request.app["config"]["users"][request.headers["X-Secret-Key"]] + secret_key = request.headers["X-Secret-Key"] + user = request.app["config"]["users"][secret_key] device = request.match_info["device"] if device not in user["devices"] or device not in request.app["config"]["devices"]: @@ -59,9 +60,11 @@ async def open(request: web.Request) -> web.Response: "message": f"Device '{device}' is not found or you are not authorized to use it." }, status=403) - device = Device.load(request.app["config"]["devices"][device]) + cdm = request.app["cdms"].get(secret_key) + if not cdm: + device = Device.load(request.app["config"]["devices"][device]) + cdm = request.app["cdms"][secret_key] = Cdm(device) - cdm = Cdm(device) try: session_id = cdm.open() except TooManySessions as e: @@ -70,16 +73,14 @@ async def open(request: web.Request) -> web.Response: "message": str(e) }, status=400) - request.app["sessions"][session_id] = cdm - return web.json_response({ "status": 200, "message": "Success", "data": { "session_id": session_id.hex(), "device": { - "system_id": device.system_id, - "security_level": device.security_level + "system_id": cdm.device.system_id, + "security_level": cdm.device.security_level } } }) @@ -87,6 +88,8 @@ async def open(request: web.Request) -> web.Response: @routes.post("/challenge/{license_type}") async def challenge(request: web.Request) -> web.Response: + secret_key = request.headers["X-Secret-Key"] + body = await request.json() for required_field in ("session_id", "init_data"): if not body.get(required_field): @@ -99,14 +102,17 @@ async def challenge(request: web.Request) -> web.Response: session_id = bytes.fromhex(body["session_id"]) # get cdm - if session_id not in request.app["sessions"]: - # e.g., app["sessions"] being cleared on server crash, reboot, and such - # or, the license message was from a challenge that was not made by our Cdm + cdm = request.app["cdms"].get(secret_key) + if not cdm or session_id not in cdm._sessions: + # This can happen if: + # - API server gets shutdown/restarted, + # - The user calls /challenge before /open, + # - The user called /open on a different IP Address + # - The user closed the session return web.json_response({ "status": 400, "message": "Invalid Session ID. Session ID may have Expired." }, status=400) - cdm = request.app["sessions"][session_id] # set service certificate service_certificate = body.get("service_certificate") @@ -137,6 +143,8 @@ async def challenge(request: web.Request) -> web.Response: @routes.post("/keys/{key_type}") async def keys(request: web.Request) -> web.Response: + secret_key = request.headers["X-Secret-Key"] + body = await request.json() for required_field in ("session_id", "license_message"): if not body.get(required_field): @@ -165,14 +173,17 @@ async def keys(request: web.Request) -> web.Response: }, status=400) # get cdm - if session_id not in request.app["sessions"]: - # e.g., app["sessions"] being cleared on server crash, reboot, and such - # or, the license message was from a challenge that was not made by our Cdm + cdm = request.app["cdms"].get(secret_key) + if not cdm or session_id not in cdm._sessions: + # This can happen if: + # - API server gets shutdown/restarted, + # - The user calls /challenge before /open, + # - The user called /open on a different IP Address + # - The user closed the session return web.json_response({ "status": 400, "message": "Invalid Session ID. Session ID may have Expired." }, status=400) - cdm = request.app["sessions"][session_id] # parse the license message cdm.parse_license(session_id, body["license_message"])