forked from DRMTalks/devine
Replace Threading Pool system with Thread Storage in Vaults
This fixes the usage of vaults across different threads. It now makes a truly unique connection for each thread. The previous code did this as well, but put back the connection from x thread, to re-use in y thread. Now it simply creates and reuses the connection on their own thread. Once the thread is closed, the data is now also garbage collected. This now reduces the risk of filling up memory over time.
This commit is contained in:
parent
4e875f5ffc
commit
1443dfaecc
|
@ -1,8 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Empty, Queue
|
||||
from threading import Lock
|
||||
import threading
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
|
@ -23,13 +21,13 @@ class MySQL(Vault):
|
|||
"""
|
||||
super().__init__(name)
|
||||
self.slug = f"{host}:{database}:{username}"
|
||||
self.con_pool = ConnectionPool(dict(
|
||||
self.conn_factory = ConnectionFactory(dict(
|
||||
host=host,
|
||||
db=database,
|
||||
user=username,
|
||||
cursorclass=DictCursor,
|
||||
**kwargs
|
||||
), 5)
|
||||
))
|
||||
|
||||
self.permissions = self.get_permissions()
|
||||
if not self.has_permission("SELECT"):
|
||||
|
@ -43,7 +41,7 @@ class MySQL(Vault):
|
|||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -58,14 +56,13 @@ class MySQL(Vault):
|
|||
return cek["key_"]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if not self.has_table(service):
|
||||
# no table, no keys, simple
|
||||
return None
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -78,7 +75,6 @@ class MySQL(Vault):
|
|||
yield row["kid"], row["key_"]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
|
@ -96,7 +92,7 @@ class MySQL(Vault):
|
|||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -116,7 +112,6 @@ class MySQL(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -145,7 +140,7 @@ class MySQL(Vault):
|
|||
for kid, key_ in kid_keys.items()
|
||||
}
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -158,10 +153,9 @@ class MySQL(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -171,11 +165,10 @@ class MySQL(Vault):
|
|||
yield Services.get_tag(list(table.values())[0])
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def has_table(self, name: str) -> bool:
|
||||
"""Check if the Vault has a Table with the specified name."""
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -186,7 +179,6 @@ class MySQL(Vault):
|
|||
return list(cursor.fetchone().values())[0] == 1
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def create_table(self, name: str):
|
||||
"""Create a Table with the specified name if not yet created."""
|
||||
|
@ -196,7 +188,7 @@ class MySQL(Vault):
|
|||
if not self.has_permission("CREATE"):
|
||||
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -214,11 +206,10 @@ class MySQL(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_permissions(self) -> list:
|
||||
"""Get and parse Grants to a more easily usable list tuple array."""
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -234,7 +225,6 @@ class MySQL(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
||||
"""Check if the current connection has a specific permission."""
|
||||
|
@ -246,27 +236,15 @@ class MySQL(Vault):
|
|||
return bool(grants)
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
def __init__(self, con: dict, size: int):
|
||||
class ConnectionFactory:
|
||||
def __init__(self, con: dict):
|
||||
self._con = con
|
||||
self._size = size
|
||||
self._pool = Queue(self._size)
|
||||
self._lock = Lock()
|
||||
self._store = threading.local()
|
||||
|
||||
def _create_connection(self):
|
||||
def _create_connection(self) -> pymysql.Connection:
|
||||
return pymysql.connect(**self._con)
|
||||
|
||||
def get(self) -> pymysql.Connection:
|
||||
while True:
|
||||
try:
|
||||
return self._pool.get(block=False)
|
||||
except Empty:
|
||||
with self._lock:
|
||||
if self._pool.qsize() < self._size:
|
||||
return self._create_connection()
|
||||
else:
|
||||
# pool full, wait before retrying
|
||||
time.sleep(0.1)
|
||||
|
||||
def put(self, conn: pymysql.Connection):
|
||||
self._pool.put(conn)
|
||||
if not hasattr(self._store, "conn"):
|
||||
self._store.conn = self._create_connection()
|
||||
return self._store.conn
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from sqlite3 import Connection
|
||||
from threading import Lock
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
|
@ -20,7 +18,7 @@ class SQLite(Vault):
|
|||
super().__init__(name)
|
||||
self.path = Path(path).expanduser()
|
||||
# TODO: Use a DictCursor or such to get fetches as dict?
|
||||
self.con_pool = ConnectionPool(self.path, 5)
|
||||
self.conn_factory = ConnectionFactory(self.path)
|
||||
|
||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||
if not self.has_table(service):
|
||||
|
@ -30,7 +28,7 @@ class SQLite(Vault):
|
|||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -44,14 +42,13 @@ class SQLite(Vault):
|
|||
return cek[1]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if not self.has_table(service):
|
||||
# no table, no keys, simple
|
||||
return None
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -63,7 +60,6 @@ class SQLite(Vault):
|
|||
yield kid, key_
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
|
@ -75,7 +71,7 @@ class SQLite(Vault):
|
|||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -95,7 +91,6 @@ class SQLite(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -118,7 +113,7 @@ class SQLite(Vault):
|
|||
for kid, key_ in kid_keys.items()
|
||||
}
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -131,10 +126,9 @@ class SQLite(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -144,11 +138,10 @@ class SQLite(Vault):
|
|||
yield Services.get_tag(name)
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def has_table(self, name: str) -> bool:
|
||||
"""Check if the Vault has a Table with the specified name."""
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -159,14 +152,13 @@ class SQLite(Vault):
|
|||
return cursor.fetchone()[0] == 1
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def create_table(self, name: str):
|
||||
"""Create a Table with the specified name if not yet created."""
|
||||
if self.has_table(name):
|
||||
return
|
||||
|
||||
conn = self.con_pool.get()
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
|
@ -185,30 +177,17 @@ class SQLite(Vault):
|
|||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
def __init__(self, path: Union[str, Path], size: int):
|
||||
class ConnectionFactory:
|
||||
def __init__(self, path: Union[str, Path]):
|
||||
self._path = path
|
||||
self._size = size
|
||||
self._pool = Queue(self._size)
|
||||
self._lock = Lock()
|
||||
self._store = threading.local()
|
||||
|
||||
def _create_connection(self):
|
||||
def _create_connection(self) -> Connection:
|
||||
return sqlite3.connect(self._path)
|
||||
|
||||
def get(self) -> Connection:
|
||||
while True:
|
||||
try:
|
||||
return self._pool.get(block=False)
|
||||
except Empty:
|
||||
with self._lock:
|
||||
if self._pool.qsize() < self._size:
|
||||
return self._create_connection()
|
||||
else:
|
||||
# pool full, wait before retrying
|
||||
time.sleep(0.1)
|
||||
|
||||
def put(self, conn: Connection):
|
||||
self._pool.put(conn)
|
||||
if not hasattr(self._store, "conn"):
|
||||
self._store.conn = self._create_connection()
|
||||
return self._store.conn
|
||||
|
|
Loading…
Reference in New Issue