Remove AtomicSQL, use Connection Pools in Vaults

This allows the use of vaults in any thread, while keeping database-file level thread safety. AtomicSQL doesn't actually do anything that useful.
This commit is contained in:
rlaphoenix 2023-02-21 05:38:39 +00:00
parent 707469d252
commit c925cb8af9
6 changed files with 255 additions and 272 deletions

View File

@ -89,7 +89,7 @@ def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) -
log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}")
try:
added = to_vault.add_keys(service_, content_keys, commit=True)
added = to_vault.add_keys(service_, content_keys)
except PermissionError:
log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...")
continue
@ -171,7 +171,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None:
for vault in vaults_:
log.info(f"Adding {total_count} Content Keys to {vault}")
added_count = vault.add_keys(service, kid_keys, commit=True)
added_count = vault.add_keys(service, kid_keys)
existed_count = total_count - added_count
log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)")

View File

@ -1,105 +0,0 @@
"""
AtomicSQL - Race-condition and Threading safe SQL Database Interface.
Copyright (C) 2020-2023 rlaphoenix
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sqlite3
import time
from threading import Lock
from typing import Any, Callable, Union
import pymysql.cursors
Connections = Union[sqlite3.Connection, pymysql.connections.Connection]
Cursors = Union[sqlite3.Cursor, pymysql.cursors.Cursor]
class AtomicSQL:
"""
Race-condition and Threading safe SQL Database Interface.
"""
def __init__(self) -> None:
self.master_lock = Lock() # prevents race condition
self.db: dict[bytes, Connections] = {} # used to hold the database connections and commit changes and such
self.cursor: dict[bytes, Cursors] = {} # used to execute queries and receive results
self.session_lock: dict[bytes, Lock] = {} # like master_lock, but per-session
def load(self, connection: Connections) -> bytes:
"""
Store SQL Connection object and return a reference ticket.
:param connection: SQLite3 or pymysql Connection object.
:returns: Session ID in which the database connection is referenced with.
"""
self.master_lock.acquire()
try:
# obtain a unique cryptographically random session_id
session_id = None
while not session_id or session_id in self.db:
session_id = os.urandom(16)
self.db[session_id] = connection
self.cursor[session_id] = self.db[session_id].cursor()
self.session_lock[session_id] = Lock()
return session_id
finally:
self.master_lock.release()
def safe_execute(self, session_id: bytes, action: Callable) -> Any:
"""
Execute code on the Database Connection in a race-condition safe way.
:param session_id: Database Connection's Session ID.
:param action: Function or lambda in which to execute, it's provided `db` and `cursor` arguments.
:returns: Whatever `action` returns.
"""
if session_id not in self.db:
raise ValueError(f"Session ID {session_id!r} is invalid.")
self.master_lock.acquire()
self.session_lock[session_id].acquire()
try:
failures = 0
while True:
try:
action(
db=self.db[session_id],
cursor=self.cursor[session_id]
)
break
except sqlite3.OperationalError as e:
failures += 1
delay = 3 * failures
print(f"AtomicSQL.safe_execute failed, {e}, retrying in {delay} seconds...")
time.sleep(delay)
if failures == 10:
raise ValueError("AtomicSQL.safe_execute failed too many time's. Aborting.")
return self.cursor[session_id]
finally:
self.session_lock[session_id].release()
self.master_lock.release()
def commit(self, session_id: bytes) -> bool:
"""
Commit changes to the Database Connection immediately.
This isn't necessary to be run every time you make changes, just ensure it's run
at least before termination.
:param session_id: Database Connection's Session ID.
:returns: True if it committed.
"""
self.safe_execute(
session_id,
lambda db, cursor: db.commit()
)
return True # todo ; actually check if db.commit worked

View File

@ -31,11 +31,11 @@ class Vault(metaclass=ABCMeta):
"""Get All Keys from Vault by Service."""
@abstractmethod
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
"""Add KID:KEY to the Vault."""
@abstractmethod
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
"""
Add Multiple Content Keys with Key IDs for Service to the Vault.
Pre-existing Content Keys are ignored/skipped.

View File

@ -57,7 +57,7 @@ class Vaults:
for vault in self.vaults:
if vault != excluding:
try:
success += vault.add_key(self.service, kid, key, commit=True)
success += vault.add_key(self.service, kid, key)
except (PermissionError, NotImplementedError):
pass
return success
@ -70,7 +70,7 @@ class Vaults:
success = 0
for vault in self.vaults:
try:
success += bool(vault.add_keys(self.service, kid_keys, commit=True))
success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError):
pass
return success

View File

@ -1,5 +1,8 @@
from __future__ import annotations
import time
from queue import Empty, Queue
from threading import Lock
from typing import Iterator, Optional, Union
from uuid import UUID
@ -7,7 +10,6 @@ import pymysql
from pymysql.cursors import DictCursor
from devine.core.services import Services
from devine.core.utils.atomicsql import AtomicSQL
from devine.core.vault import Vault
@ -21,15 +23,13 @@ class MySQL(Vault):
"""
super().__init__(name)
self.slug = f"{host}:{database}:{username}"
self.con = pymysql.connect(
self.con_pool = ConnectionPool(dict(
host=host,
db=database,
user=username,
cursorclass=DictCursor,
**kwargs
)
self.adb = AtomicSQL()
self.ticket = self.adb.load(self.con)
), 5)
self.permissions = self.get_permissions()
if not self.has_permission("SELECT"):
@ -43,37 +43,44 @@ class MySQL(Vault):
if isinstance(kid, UUID):
kid = kid.hex
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
[kid, "0" * 32]
(kid, "0" * 32)
)
).fetchone()
if not c:
return None
return c["key_"]
cek = cursor.fetchone()
if not cek:
return None
return cek["key_"]
finally:
cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
["0" * 32]
("0" * 32,)
)
)
for row in cursor.fetchall():
yield row["kid"], row["key_"]
finally:
cursor.close()
self.con_pool.put(conn)
for row in c.fetchall():
yield row["kid"], row["key_"]
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
@ -82,39 +89,38 @@ class MySQL(Vault):
if not self.has_table(service):
try:
self.create_table(service, commit)
self.create_table(service)
except PermissionError:
return False
if isinstance(kid, UUID):
kid = kid.hex
if self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
[kid, key]
(kid, key)
)
).fetchone():
# table already has this exact KID:KEY stored
return True
self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
if cursor.fetchone():
# table already has this exact KID:KEY stored
return True
cursor.execute(
# TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
(kid, key)
)
)
if commit:
self.commit()
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
return True
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items():
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
@ -124,7 +130,7 @@ class MySQL(Vault):
if not self.has_table(service):
try:
self.create_table(service, commit)
self.create_table(service)
except PermissionError:
return 0
@ -139,40 +145,50 @@ class MySQL(Vault):
for kid, key_ in kid_keys.items()
}
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.executemany(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.executemany(
# TODO: SQL injection risk
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
kid_keys.items()
)
)
if commit:
self.commit()
return c.rowcount
return cursor.rowcount
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def get_services(self) -> Iterator[str]:
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute("SHOW TABLES")
)
for table in c.fetchall():
# each entry has a key named `Tables_in_<db name>`
yield Services.get_tag(list(table.values())[0])
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute("SHOW TABLES")
for table in cursor.fetchall():
# each entry has a key named `Tables_in_<db name>`
yield Services.get_tag(list(table.values())[0])
finally:
cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
return list(self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
[self.con.db, name]
)
).fetchone().values())[0] == 1
conn = self.con_pool.get()
cursor = conn.cursor()
def create_table(self, name: str, commit: bool = False):
try:
cursor.execute(
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
(conn.db, name)
)
return list(cursor.fetchone().values())[0] == 1
finally:
cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
if self.has_table(name):
return
@ -180,9 +196,11 @@ class MySQL(Vault):
if not self.has_permission("CREATE"):
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"""
CREATE TABLE IF NOT EXISTS {name} (
@ -193,23 +211,30 @@ class MySQL(Vault):
);
"""
)
)
if commit:
self.commit()
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def get_permissions(self) -> list:
"""Get and parse Grants to a more easily usable list tuple array."""
with self.con.cursor() as c:
c.execute("SHOW GRANTS")
grants = c.fetchall()
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
grants = [next(iter(x.values())) for x in grants]
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
grants = [(
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
location.replace("`", "").split(".")
) for perms, location in grants]
return grants
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
grants = [(
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
location.replace("`", "").split(".")
) for perms, location in grants]
return grants
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
"""Check if the current connection has a specific permission."""
@ -220,6 +245,28 @@ class MySQL(Vault):
grants = [x for x in grants if x[1][1] in (table, "*")]
return bool(grants)
def commit(self):
"""Commit any changes made that has not been written to db."""
self.adb.commit(self.ticket)
class ConnectionPool:
def __init__(self, con: dict, size: int):
self._con = con
self._size = size
self._pool = Queue(self._size)
self._lock = Lock()
def _create_connection(self):
return pymysql.connect(**self._con)
def get(self) -> pymysql.Connection:
while True:
try:
return self._pool.get(block=False)
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: pymysql.Connection):
self._pool.put(conn)

View File

@ -1,12 +1,15 @@
from __future__ import annotations
import sqlite3
import time
from pathlib import Path
from queue import Empty, Queue
from sqlite3 import Connection
from threading import Lock
from typing import Iterator, Optional, Union
from uuid import UUID
from devine.core.services import Services
from devine.core.utils.atomicsql import AtomicSQL
from devine.core.vault import Vault
@ -17,9 +20,7 @@ class SQLite(Vault):
super().__init__(name)
self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict?
self.con = sqlite3.connect(self.path)
self.adb = AtomicSQL()
self.ticket = self.adb.load(self.con)
self.con_pool = ConnectionPool(self.path, 5)
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
if not self.has_table(service):
@ -29,76 +30,82 @@ class SQLite(Vault):
if isinstance(kid, UUID):
kid = kid.hex
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
# TODO: SQL injection risk
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
[kid, "0" * 32]
)
).fetchone()
if not c:
return None
conn = self.con_pool.get()
cursor = conn.cursor()
return c[1] # `key_`
try:
cursor.execute(
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
(kid, "0" * 32)
)
cek = cursor.fetchone()
if not cek:
return None
return cek[1]
finally:
cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
# TODO: SQL injection risk
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
["0" * 32]
)
)
for (kid, key_) in c.fetchall():
yield kid, key_
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
("0" * 32,)
)
for (kid, key_) in cursor.fetchall():
yield kid, key_
finally:
cursor.close()
self.con_pool.put(conn)
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_table(service):
self.create_table(service, commit)
self.create_table(service)
if isinstance(kid, UUID):
kid = kid.hex
if self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
[kid, key]
(kid, key)
)
).fetchone():
# table already has this exact KID:KEY stored
return True
self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
if cursor.fetchone():
# table already has this exact KID:KEY stored
return True
cursor.execute(
# TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
(kid, key)
)
)
if commit:
self.commit()
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
return True
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items():
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_table(service):
self.create_table(service, commit)
self.create_table(service)
if not isinstance(kid_keys, dict):
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
@ -111,47 +118,59 @@ class SQLite(Vault):
for kid, key_ in kid_keys.items()
}
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.executemany(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.executemany(
# TODO: SQL injection risk
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
kid_keys.items()
)
)
if commit:
self.commit()
return c.rowcount
return cursor.rowcount
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def get_services(self) -> Iterator[str]:
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
)
for (name,) in c.fetchall():
if name != "sqlite_sequence":
yield Services.get_tag(name)
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
for (name,) in cursor.fetchall():
if name != "sqlite_sequence":
yield Services.get_tag(name)
finally:
cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
return self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
[name]
)
).fetchone()[0] == 1
conn = self.con_pool.get()
cursor = conn.cursor()
def create_table(self, name: str, commit: bool = False):
try:
cursor.execute(
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
(name,)
)
return cursor.fetchone()[0] == 1
finally:
cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
if self.has_table(name):
return
self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"""
CREATE TABLE IF NOT EXISTS {name} (
@ -163,11 +182,33 @@ class SQLite(Vault):
);
"""
)
)
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
if commit:
self.commit()
def commit(self):
"""Commit any changes made that has not been written to db."""
self.adb.commit(self.ticket)
class ConnectionPool:
def __init__(self, path: Union[str, Path], size: int):
self._path = path
self._size = size
self._pool = Queue(self._size)
self._lock = Lock()
def _create_connection(self):
return sqlite3.connect(self._path)
def get(self) -> Connection:
while True:
try:
return self._pool.get(block=False)
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: Connection):
self._pool.put(conn)