forked from DRMTalks/devine
Replace Threading Pool system with Thread Storage in Vaults
This fixes the usage of vaults across different threads. It now makes a truly unique connection for each thread. The previous code did this as well, but put back the connection from x thread, to re-use in y thread. Now it simply creates and reuses the connection on their own thread. Once the thread is closed, the data is now also garbage collected. This now reduces the risk of filling up memory over time.
This commit is contained in:
parent
4e875f5ffc
commit
1443dfaecc
|
@ -1,8 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import threading
|
||||||
from queue import Empty, Queue
|
|
||||||
from threading import Lock
|
|
||||||
from typing import Iterator, Optional, Union
|
from typing import Iterator, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
@ -23,13 +21,13 @@ class MySQL(Vault):
|
||||||
"""
|
"""
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.slug = f"{host}:{database}:{username}"
|
self.slug = f"{host}:{database}:{username}"
|
||||||
self.con_pool = ConnectionPool(dict(
|
self.conn_factory = ConnectionFactory(dict(
|
||||||
host=host,
|
host=host,
|
||||||
db=database,
|
db=database,
|
||||||
user=username,
|
user=username,
|
||||||
cursorclass=DictCursor,
|
cursorclass=DictCursor,
|
||||||
**kwargs
|
**kwargs
|
||||||
), 5)
|
))
|
||||||
|
|
||||||
self.permissions = self.get_permissions()
|
self.permissions = self.get_permissions()
|
||||||
if not self.has_permission("SELECT"):
|
if not self.has_permission("SELECT"):
|
||||||
|
@ -43,7 +41,7 @@ class MySQL(Vault):
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -58,14 +56,13 @@ class MySQL(Vault):
|
||||||
return cek["key_"]
|
return cek["key_"]
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
# no table, no keys, simple
|
# no table, no keys, simple
|
||||||
return None
|
return None
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -78,7 +75,6 @@ class MySQL(Vault):
|
||||||
yield row["kid"], row["key_"]
|
yield row["kid"], row["key_"]
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||||
if not key or key.count("0") == len(key):
|
if not key or key.count("0") == len(key):
|
||||||
|
@ -96,7 +92,7 @@ class MySQL(Vault):
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -116,7 +112,6 @@ class MySQL(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -145,7 +140,7 @@ class MySQL(Vault):
|
||||||
for kid, key_ in kid_keys.items()
|
for kid, key_ in kid_keys.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -158,10 +153,9 @@ class MySQL(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def get_services(self) -> Iterator[str]:
|
def get_services(self) -> Iterator[str]:
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -171,11 +165,10 @@ class MySQL(Vault):
|
||||||
yield Services.get_tag(list(table.values())[0])
|
yield Services.get_tag(list(table.values())[0])
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def has_table(self, name: str) -> bool:
|
def has_table(self, name: str) -> bool:
|
||||||
"""Check if the Vault has a Table with the specified name."""
|
"""Check if the Vault has a Table with the specified name."""
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -186,7 +179,6 @@ class MySQL(Vault):
|
||||||
return list(cursor.fetchone().values())[0] == 1
|
return list(cursor.fetchone().values())[0] == 1
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def create_table(self, name: str):
|
def create_table(self, name: str):
|
||||||
"""Create a Table with the specified name if not yet created."""
|
"""Create a Table with the specified name if not yet created."""
|
||||||
|
@ -196,7 +188,7 @@ class MySQL(Vault):
|
||||||
if not self.has_permission("CREATE"):
|
if not self.has_permission("CREATE"):
|
||||||
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -214,11 +206,10 @@ class MySQL(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def get_permissions(self) -> list:
|
def get_permissions(self) -> list:
|
||||||
"""Get and parse Grants to a more easily usable list tuple array."""
|
"""Get and parse Grants to a more easily usable list tuple array."""
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -234,7 +225,6 @@ class MySQL(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
||||||
"""Check if the current connection has a specific permission."""
|
"""Check if the current connection has a specific permission."""
|
||||||
|
@ -246,27 +236,15 @@ class MySQL(Vault):
|
||||||
return bool(grants)
|
return bool(grants)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPool:
|
class ConnectionFactory:
|
||||||
def __init__(self, con: dict, size: int):
|
def __init__(self, con: dict):
|
||||||
self._con = con
|
self._con = con
|
||||||
self._size = size
|
self._store = threading.local()
|
||||||
self._pool = Queue(self._size)
|
|
||||||
self._lock = Lock()
|
|
||||||
|
|
||||||
def _create_connection(self):
|
def _create_connection(self) -> pymysql.Connection:
|
||||||
return pymysql.connect(**self._con)
|
return pymysql.connect(**self._con)
|
||||||
|
|
||||||
def get(self) -> pymysql.Connection:
|
def get(self) -> pymysql.Connection:
|
||||||
while True:
|
if not hasattr(self._store, "conn"):
|
||||||
try:
|
self._store.conn = self._create_connection()
|
||||||
return self._pool.get(block=False)
|
return self._store.conn
|
||||||
except Empty:
|
|
||||||
with self._lock:
|
|
||||||
if self._pool.qsize() < self._size:
|
|
||||||
return self._create_connection()
|
|
||||||
else:
|
|
||||||
# pool full, wait before retrying
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
def put(self, conn: pymysql.Connection):
|
|
||||||
self._pool.put(conn)
|
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
|
||||||
from sqlite3 import Connection
|
from sqlite3 import Connection
|
||||||
from threading import Lock
|
|
||||||
from typing import Iterator, Optional, Union
|
from typing import Iterator, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
@ -20,7 +18,7 @@ class SQLite(Vault):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.path = Path(path).expanduser()
|
self.path = Path(path).expanduser()
|
||||||
# TODO: Use a DictCursor or such to get fetches as dict?
|
# TODO: Use a DictCursor or such to get fetches as dict?
|
||||||
self.con_pool = ConnectionPool(self.path, 5)
|
self.conn_factory = ConnectionFactory(self.path)
|
||||||
|
|
||||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
|
@ -30,7 +28,7 @@ class SQLite(Vault):
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -44,14 +42,13 @@ class SQLite(Vault):
|
||||||
return cek[1]
|
return cek[1]
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
# no table, no keys, simple
|
# no table, no keys, simple
|
||||||
return None
|
return None
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -63,7 +60,6 @@ class SQLite(Vault):
|
||||||
yield kid, key_
|
yield kid, key_
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||||
if not key or key.count("0") == len(key):
|
if not key or key.count("0") == len(key):
|
||||||
|
@ -75,7 +71,7 @@ class SQLite(Vault):
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -95,7 +91,6 @@ class SQLite(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -118,7 +113,7 @@ class SQLite(Vault):
|
||||||
for kid, key_ in kid_keys.items()
|
for kid, key_ in kid_keys.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -131,10 +126,9 @@ class SQLite(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def get_services(self) -> Iterator[str]:
|
def get_services(self) -> Iterator[str]:
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -144,11 +138,10 @@ class SQLite(Vault):
|
||||||
yield Services.get_tag(name)
|
yield Services.get_tag(name)
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def has_table(self, name: str) -> bool:
|
def has_table(self, name: str) -> bool:
|
||||||
"""Check if the Vault has a Table with the specified name."""
|
"""Check if the Vault has a Table with the specified name."""
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -159,14 +152,13 @@ class SQLite(Vault):
|
||||||
return cursor.fetchone()[0] == 1
|
return cursor.fetchone()[0] == 1
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
def create_table(self, name: str):
|
def create_table(self, name: str):
|
||||||
"""Create a Table with the specified name if not yet created."""
|
"""Create a Table with the specified name if not yet created."""
|
||||||
if self.has_table(name):
|
if self.has_table(name):
|
||||||
return
|
return
|
||||||
|
|
||||||
conn = self.con_pool.get()
|
conn = self.conn_factory.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -185,30 +177,17 @@ class SQLite(Vault):
|
||||||
finally:
|
finally:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.con_pool.put(conn)
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPool:
|
class ConnectionFactory:
|
||||||
def __init__(self, path: Union[str, Path], size: int):
|
def __init__(self, path: Union[str, Path]):
|
||||||
self._path = path
|
self._path = path
|
||||||
self._size = size
|
self._store = threading.local()
|
||||||
self._pool = Queue(self._size)
|
|
||||||
self._lock = Lock()
|
|
||||||
|
|
||||||
def _create_connection(self):
|
def _create_connection(self) -> Connection:
|
||||||
return sqlite3.connect(self._path)
|
return sqlite3.connect(self._path)
|
||||||
|
|
||||||
def get(self) -> Connection:
|
def get(self) -> Connection:
|
||||||
while True:
|
if not hasattr(self._store, "conn"):
|
||||||
try:
|
self._store.conn = self._create_connection()
|
||||||
return self._pool.get(block=False)
|
return self._store.conn
|
||||||
except Empty:
|
|
||||||
with self._lock:
|
|
||||||
if self._pool.qsize() < self._size:
|
|
||||||
return self._create_connection()
|
|
||||||
else:
|
|
||||||
# pool full, wait before retrying
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
def put(self, conn: Connection):
|
|
||||||
self._pool.put(conn)
|
|
||||||
|
|
Loading…
Reference in New Issue