forked from DRMTalks/devine
Remove AtomicSQL, use Connection Pools in Vaults
This allows the use of vaults in any thread, while keeping database-file level thread safety. AtomicSQL doesn't actually do anything that useful.
This commit is contained in:
parent
707469d252
commit
c925cb8af9
|
@ -89,7 +89,7 @@ def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) -
|
||||||
log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}")
|
log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
added = to_vault.add_keys(service_, content_keys, commit=True)
|
added = to_vault.add_keys(service_, content_keys)
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...")
|
log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...")
|
||||||
continue
|
continue
|
||||||
|
@ -171,7 +171,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None:
|
||||||
|
|
||||||
for vault in vaults_:
|
for vault in vaults_:
|
||||||
log.info(f"Adding {total_count} Content Keys to {vault}")
|
log.info(f"Adding {total_count} Content Keys to {vault}")
|
||||||
added_count = vault.add_keys(service, kid_keys, commit=True)
|
added_count = vault.add_keys(service, kid_keys)
|
||||||
existed_count = total_count - added_count
|
existed_count = total_count - added_count
|
||||||
log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)")
|
log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)")
|
||||||
|
|
||||||
|
|
|
@ -1,105 +0,0 @@
|
||||||
"""
|
|
||||||
AtomicSQL - Race-condition and Threading safe SQL Database Interface.
|
|
||||||
Copyright (C) 2020-2023 rlaphoenix
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU General Public License as published by
|
|
||||||
the Free Software Foundation, either version 3 of the License, or
|
|
||||||
(at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
import time
|
|
||||||
from threading import Lock
|
|
||||||
from typing import Any, Callable, Union
|
|
||||||
|
|
||||||
import pymysql.cursors
|
|
||||||
|
|
||||||
Connections = Union[sqlite3.Connection, pymysql.connections.Connection]
|
|
||||||
Cursors = Union[sqlite3.Cursor, pymysql.cursors.Cursor]
|
|
||||||
|
|
||||||
|
|
||||||
class AtomicSQL:
|
|
||||||
"""
|
|
||||||
Race-condition and Threading safe SQL Database Interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.master_lock = Lock() # prevents race condition
|
|
||||||
self.db: dict[bytes, Connections] = {} # used to hold the database connections and commit changes and such
|
|
||||||
self.cursor: dict[bytes, Cursors] = {} # used to execute queries and receive results
|
|
||||||
self.session_lock: dict[bytes, Lock] = {} # like master_lock, but per-session
|
|
||||||
|
|
||||||
def load(self, connection: Connections) -> bytes:
|
|
||||||
"""
|
|
||||||
Store SQL Connection object and return a reference ticket.
|
|
||||||
:param connection: SQLite3 or pymysql Connection object.
|
|
||||||
:returns: Session ID in which the database connection is referenced with.
|
|
||||||
"""
|
|
||||||
self.master_lock.acquire()
|
|
||||||
try:
|
|
||||||
# obtain a unique cryptographically random session_id
|
|
||||||
session_id = None
|
|
||||||
while not session_id or session_id in self.db:
|
|
||||||
session_id = os.urandom(16)
|
|
||||||
self.db[session_id] = connection
|
|
||||||
self.cursor[session_id] = self.db[session_id].cursor()
|
|
||||||
self.session_lock[session_id] = Lock()
|
|
||||||
return session_id
|
|
||||||
finally:
|
|
||||||
self.master_lock.release()
|
|
||||||
|
|
||||||
def safe_execute(self, session_id: bytes, action: Callable) -> Any:
|
|
||||||
"""
|
|
||||||
Execute code on the Database Connection in a race-condition safe way.
|
|
||||||
:param session_id: Database Connection's Session ID.
|
|
||||||
:param action: Function or lambda in which to execute, it's provided `db` and `cursor` arguments.
|
|
||||||
:returns: Whatever `action` returns.
|
|
||||||
"""
|
|
||||||
if session_id not in self.db:
|
|
||||||
raise ValueError(f"Session ID {session_id!r} is invalid.")
|
|
||||||
self.master_lock.acquire()
|
|
||||||
self.session_lock[session_id].acquire()
|
|
||||||
try:
|
|
||||||
failures = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
action(
|
|
||||||
db=self.db[session_id],
|
|
||||||
cursor=self.cursor[session_id]
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except sqlite3.OperationalError as e:
|
|
||||||
failures += 1
|
|
||||||
delay = 3 * failures
|
|
||||||
print(f"AtomicSQL.safe_execute failed, {e}, retrying in {delay} seconds...")
|
|
||||||
time.sleep(delay)
|
|
||||||
if failures == 10:
|
|
||||||
raise ValueError("AtomicSQL.safe_execute failed too many time's. Aborting.")
|
|
||||||
return self.cursor[session_id]
|
|
||||||
finally:
|
|
||||||
self.session_lock[session_id].release()
|
|
||||||
self.master_lock.release()
|
|
||||||
|
|
||||||
def commit(self, session_id: bytes) -> bool:
|
|
||||||
"""
|
|
||||||
Commit changes to the Database Connection immediately.
|
|
||||||
This isn't necessary to be run every time you make changes, just ensure it's run
|
|
||||||
at least before termination.
|
|
||||||
:param session_id: Database Connection's Session ID.
|
|
||||||
:returns: True if it committed.
|
|
||||||
"""
|
|
||||||
self.safe_execute(
|
|
||||||
session_id,
|
|
||||||
lambda db, cursor: db.commit()
|
|
||||||
)
|
|
||||||
return True # todo ; actually check if db.commit worked
|
|
|
@ -31,11 +31,11 @@ class Vault(metaclass=ABCMeta):
|
||||||
"""Get All Keys from Vault by Service."""
|
"""Get All Keys from Vault by Service."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
|
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||||
"""Add KID:KEY to the Vault."""
|
"""Add KID:KEY to the Vault."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
|
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||||
"""
|
"""
|
||||||
Add Multiple Content Keys with Key IDs for Service to the Vault.
|
Add Multiple Content Keys with Key IDs for Service to the Vault.
|
||||||
Pre-existing Content Keys are ignored/skipped.
|
Pre-existing Content Keys are ignored/skipped.
|
||||||
|
|
|
@ -57,7 +57,7 @@ class Vaults:
|
||||||
for vault in self.vaults:
|
for vault in self.vaults:
|
||||||
if vault != excluding:
|
if vault != excluding:
|
||||||
try:
|
try:
|
||||||
success += vault.add_key(self.service, kid, key, commit=True)
|
success += vault.add_key(self.service, kid, key)
|
||||||
except (PermissionError, NotImplementedError):
|
except (PermissionError, NotImplementedError):
|
||||||
pass
|
pass
|
||||||
return success
|
return success
|
||||||
|
@ -70,7 +70,7 @@ class Vaults:
|
||||||
success = 0
|
success = 0
|
||||||
for vault in self.vaults:
|
for vault in self.vaults:
|
||||||
try:
|
try:
|
||||||
success += bool(vault.add_keys(self.service, kid_keys, commit=True))
|
success += bool(vault.add_keys(self.service, kid_keys))
|
||||||
except (PermissionError, NotImplementedError):
|
except (PermissionError, NotImplementedError):
|
||||||
pass
|
pass
|
||||||
return success
|
return success
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from queue import Empty, Queue
|
||||||
|
from threading import Lock
|
||||||
from typing import Iterator, Optional, Union
|
from typing import Iterator, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
@ -7,7 +10,6 @@ import pymysql
|
||||||
from pymysql.cursors import DictCursor
|
from pymysql.cursors import DictCursor
|
||||||
|
|
||||||
from devine.core.services import Services
|
from devine.core.services import Services
|
||||||
from devine.core.utils.atomicsql import AtomicSQL
|
|
||||||
from devine.core.vault import Vault
|
from devine.core.vault import Vault
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,15 +23,13 @@ class MySQL(Vault):
|
||||||
"""
|
"""
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.slug = f"{host}:{database}:{username}"
|
self.slug = f"{host}:{database}:{username}"
|
||||||
self.con = pymysql.connect(
|
self.con_pool = ConnectionPool(dict(
|
||||||
host=host,
|
host=host,
|
||||||
db=database,
|
db=database,
|
||||||
user=username,
|
user=username,
|
||||||
cursorclass=DictCursor,
|
cursorclass=DictCursor,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
), 5)
|
||||||
self.adb = AtomicSQL()
|
|
||||||
self.ticket = self.adb.load(self.con)
|
|
||||||
|
|
||||||
self.permissions = self.get_permissions()
|
self.permissions = self.get_permissions()
|
||||||
if not self.has_permission("SELECT"):
|
if not self.has_permission("SELECT"):
|
||||||
|
@ -43,37 +43,44 @@ class MySQL(Vault):
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
|
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
|
||||||
[kid, "0" * 32]
|
(kid, "0" * 32)
|
||||||
)
|
)
|
||||||
).fetchone()
|
cek = cursor.fetchone()
|
||||||
if not c:
|
if not cek:
|
||||||
return None
|
return None
|
||||||
|
return cek["key_"]
|
||||||
return c["key_"]
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
# no table, no keys, simple
|
# no table, no keys, simple
|
||||||
return None
|
return None
|
||||||
|
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
|
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
|
||||||
["0" * 32]
|
("0" * 32,)
|
||||||
)
|
)
|
||||||
)
|
for row in cursor.fetchall():
|
||||||
|
yield row["kid"], row["key_"]
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
for row in c.fetchall():
|
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||||
yield row["kid"], row["key_"]
|
|
||||||
|
|
||||||
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
|
|
||||||
if not key or key.count("0") == len(key):
|
if not key or key.count("0") == len(key):
|
||||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||||
|
|
||||||
|
@ -82,39 +89,38 @@ class MySQL(Vault):
|
||||||
|
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
try:
|
try:
|
||||||
self.create_table(service, commit)
|
self.create_table(service)
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
if self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
|
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
|
||||||
[kid, key]
|
(kid, key)
|
||||||
)
|
)
|
||||||
).fetchone():
|
if cursor.fetchone():
|
||||||
# table already has this exact KID:KEY stored
|
# table already has this exact KID:KEY stored
|
||||||
return True
|
return True
|
||||||
|
cursor.execute(
|
||||||
self.adb.safe_execute(
|
|
||||||
self.ticket,
|
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
||||||
(kid, key)
|
(kid, key)
|
||||||
)
|
)
|
||||||
)
|
finally:
|
||||||
|
conn.commit()
|
||||||
if commit:
|
cursor.close()
|
||||||
self.commit()
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
|
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||||
for kid, key in kid_keys.items():
|
for kid, key in kid_keys.items():
|
||||||
if not key or key.count("0") == len(key):
|
if not key or key.count("0") == len(key):
|
||||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||||
|
@ -124,7 +130,7 @@ class MySQL(Vault):
|
||||||
|
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
try:
|
try:
|
||||||
self.create_table(service, commit)
|
self.create_table(service)
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@ -139,40 +145,50 @@ class MySQL(Vault):
|
||||||
for kid, key_ in kid_keys.items()
|
for kid, key_ in kid_keys.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.executemany(
|
|
||||||
|
try:
|
||||||
|
cursor.executemany(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
||||||
kid_keys.items()
|
kid_keys.items()
|
||||||
)
|
)
|
||||||
)
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
if commit:
|
conn.commit()
|
||||||
self.commit()
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
return c.rowcount
|
|
||||||
|
|
||||||
def get_services(self) -> Iterator[str]:
|
def get_services(self) -> Iterator[str]:
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute("SHOW TABLES")
|
|
||||||
)
|
try:
|
||||||
for table in c.fetchall():
|
cursor.execute("SHOW TABLES")
|
||||||
# each entry has a key named `Tables_in_<db name>`
|
for table in cursor.fetchall():
|
||||||
yield Services.get_tag(list(table.values())[0])
|
# each entry has a key named `Tables_in_<db name>`
|
||||||
|
yield Services.get_tag(list(table.values())[0])
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
def has_table(self, name: str) -> bool:
|
def has_table(self, name: str) -> bool:
|
||||||
"""Check if the Vault has a Table with the specified name."""
|
"""Check if the Vault has a Table with the specified name."""
|
||||||
return list(self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
|
|
||||||
[self.con.db, name]
|
|
||||||
)
|
|
||||||
).fetchone().values())[0] == 1
|
|
||||||
|
|
||||||
def create_table(self, name: str, commit: bool = False):
|
try:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
|
||||||
|
(conn.db, name)
|
||||||
|
)
|
||||||
|
return list(cursor.fetchone().values())[0] == 1
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
|
def create_table(self, name: str):
|
||||||
"""Create a Table with the specified name if not yet created."""
|
"""Create a Table with the specified name if not yet created."""
|
||||||
if self.has_table(name):
|
if self.has_table(name):
|
||||||
return
|
return
|
||||||
|
@ -180,9 +196,11 @@ class MySQL(Vault):
|
||||||
if not self.has_permission("CREATE"):
|
if not self.has_permission("CREATE"):
|
||||||
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
||||||
|
|
||||||
self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"""
|
f"""
|
||||||
CREATE TABLE IF NOT EXISTS {name} (
|
CREATE TABLE IF NOT EXISTS {name} (
|
||||||
|
@ -193,23 +211,30 @@ class MySQL(Vault):
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
finally:
|
||||||
|
conn.commit()
|
||||||
if commit:
|
cursor.close()
|
||||||
self.commit()
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
def get_permissions(self) -> list:
|
def get_permissions(self) -> list:
|
||||||
"""Get and parse Grants to a more easily usable list tuple array."""
|
"""Get and parse Grants to a more easily usable list tuple array."""
|
||||||
with self.con.cursor() as c:
|
conn = self.con_pool.get()
|
||||||
c.execute("SHOW GRANTS")
|
cursor = conn.cursor()
|
||||||
grants = c.fetchall()
|
|
||||||
|
try:
|
||||||
|
cursor.execute("SHOW GRANTS")
|
||||||
|
grants = cursor.fetchall()
|
||||||
grants = [next(iter(x.values())) for x in grants]
|
grants = [next(iter(x.values())) for x in grants]
|
||||||
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
|
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
|
||||||
grants = [(
|
grants = [(
|
||||||
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
|
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
|
||||||
location.replace("`", "").split(".")
|
location.replace("`", "").split(".")
|
||||||
) for perms, location in grants]
|
) for perms, location in grants]
|
||||||
return grants
|
return grants
|
||||||
|
finally:
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
||||||
"""Check if the current connection has a specific permission."""
|
"""Check if the current connection has a specific permission."""
|
||||||
|
@ -220,6 +245,28 @@ class MySQL(Vault):
|
||||||
grants = [x for x in grants if x[1][1] in (table, "*")]
|
grants = [x for x in grants if x[1][1] in (table, "*")]
|
||||||
return bool(grants)
|
return bool(grants)
|
||||||
|
|
||||||
def commit(self):
|
|
||||||
"""Commit any changes made that has not been written to db."""
|
class ConnectionPool:
|
||||||
self.adb.commit(self.ticket)
|
def __init__(self, con: dict, size: int):
|
||||||
|
self._con = con
|
||||||
|
self._size = size
|
||||||
|
self._pool = Queue(self._size)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def _create_connection(self):
|
||||||
|
return pymysql.connect(**self._con)
|
||||||
|
|
||||||
|
def get(self) -> pymysql.Connection:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self._pool.get(block=False)
|
||||||
|
except Empty:
|
||||||
|
with self._lock:
|
||||||
|
if self._pool.qsize() < self._size:
|
||||||
|
return self._create_connection()
|
||||||
|
else:
|
||||||
|
# pool full, wait before retrying
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
def put(self, conn: pymysql.Connection):
|
||||||
|
self._pool.put(conn)
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from queue import Empty, Queue
|
||||||
|
from sqlite3 import Connection
|
||||||
|
from threading import Lock
|
||||||
from typing import Iterator, Optional, Union
|
from typing import Iterator, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from devine.core.services import Services
|
from devine.core.services import Services
|
||||||
from devine.core.utils.atomicsql import AtomicSQL
|
|
||||||
from devine.core.vault import Vault
|
from devine.core.vault import Vault
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,9 +20,7 @@ class SQLite(Vault):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.path = Path(path).expanduser()
|
self.path = Path(path).expanduser()
|
||||||
# TODO: Use a DictCursor or such to get fetches as dict?
|
# TODO: Use a DictCursor or such to get fetches as dict?
|
||||||
self.con = sqlite3.connect(self.path)
|
self.con_pool = ConnectionPool(self.path, 5)
|
||||||
self.adb = AtomicSQL()
|
|
||||||
self.ticket = self.adb.load(self.con)
|
|
||||||
|
|
||||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
|
@ -29,76 +30,82 @@ class SQLite(Vault):
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
# TODO: SQL injection risk
|
|
||||||
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
|
|
||||||
[kid, "0" * 32]
|
|
||||||
)
|
|
||||||
).fetchone()
|
|
||||||
if not c:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return c[1] # `key_`
|
try:
|
||||||
|
cursor.execute(
|
||||||
|
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
|
||||||
|
(kid, "0" * 32)
|
||||||
|
)
|
||||||
|
cek = cursor.fetchone()
|
||||||
|
if not cek:
|
||||||
|
return None
|
||||||
|
return cek[1]
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
# no table, no keys, simple
|
# no table, no keys, simple
|
||||||
return None
|
return None
|
||||||
c = self.adb.safe_execute(
|
|
||||||
self.ticket,
|
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
# TODO: SQL injection risk
|
|
||||||
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
|
|
||||||
["0" * 32]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for (kid, key_) in c.fetchall():
|
|
||||||
yield kid, key_
|
|
||||||
|
|
||||||
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
|
conn = self.con_pool.get()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
|
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
|
||||||
|
("0" * 32,)
|
||||||
|
)
|
||||||
|
for (kid, key_) in cursor.fetchall():
|
||||||
|
yield kid, key_
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
|
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||||
if not key or key.count("0") == len(key):
|
if not key or key.count("0") == len(key):
|
||||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||||
|
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
self.create_table(service, commit)
|
self.create_table(service)
|
||||||
|
|
||||||
if isinstance(kid, UUID):
|
if isinstance(kid, UUID):
|
||||||
kid = kid.hex
|
kid = kid.hex
|
||||||
|
|
||||||
if self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
|
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
|
||||||
[kid, key]
|
(kid, key)
|
||||||
)
|
)
|
||||||
).fetchone():
|
if cursor.fetchone():
|
||||||
# table already has this exact KID:KEY stored
|
# table already has this exact KID:KEY stored
|
||||||
return True
|
return True
|
||||||
|
cursor.execute(
|
||||||
self.adb.safe_execute(
|
|
||||||
self.ticket,
|
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
|
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
|
||||||
(kid, key)
|
(kid, key)
|
||||||
)
|
)
|
||||||
)
|
finally:
|
||||||
|
conn.commit()
|
||||||
if commit:
|
cursor.close()
|
||||||
self.commit()
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
|
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||||
for kid, key in kid_keys.items():
|
for kid, key in kid_keys.items():
|
||||||
if not key or key.count("0") == len(key):
|
if not key or key.count("0") == len(key):
|
||||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||||
|
|
||||||
if not self.has_table(service):
|
if not self.has_table(service):
|
||||||
self.create_table(service, commit)
|
self.create_table(service)
|
||||||
|
|
||||||
if not isinstance(kid_keys, dict):
|
if not isinstance(kid_keys, dict):
|
||||||
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
|
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
|
||||||
|
@ -111,47 +118,59 @@ class SQLite(Vault):
|
||||||
for kid, key_ in kid_keys.items()
|
for kid, key_ in kid_keys.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.executemany(
|
|
||||||
|
try:
|
||||||
|
cursor.executemany(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
|
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
|
||||||
kid_keys.items()
|
kid_keys.items()
|
||||||
)
|
)
|
||||||
)
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
if commit:
|
conn.commit()
|
||||||
self.commit()
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
return c.rowcount
|
|
||||||
|
|
||||||
def get_services(self) -> Iterator[str]:
|
def get_services(self) -> Iterator[str]:
|
||||||
c = self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
|
||||||
)
|
try:
|
||||||
for (name,) in c.fetchall():
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||||
if name != "sqlite_sequence":
|
for (name,) in cursor.fetchall():
|
||||||
yield Services.get_tag(name)
|
if name != "sqlite_sequence":
|
||||||
|
yield Services.get_tag(name)
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
def has_table(self, name: str) -> bool:
|
def has_table(self, name: str) -> bool:
|
||||||
"""Check if the Vault has a Table with the specified name."""
|
"""Check if the Vault has a Table with the specified name."""
|
||||||
return self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
|
|
||||||
[name]
|
|
||||||
)
|
|
||||||
).fetchone()[0] == 1
|
|
||||||
|
|
||||||
def create_table(self, name: str, commit: bool = False):
|
try:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
|
||||||
|
(name,)
|
||||||
|
)
|
||||||
|
return cursor.fetchone()[0] == 1
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
|
def create_table(self, name: str):
|
||||||
"""Create a Table with the specified name if not yet created."""
|
"""Create a Table with the specified name if not yet created."""
|
||||||
if self.has_table(name):
|
if self.has_table(name):
|
||||||
return
|
return
|
||||||
|
|
||||||
self.adb.safe_execute(
|
conn = self.con_pool.get()
|
||||||
self.ticket,
|
cursor = conn.cursor()
|
||||||
lambda db, cursor: cursor.execute(
|
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
# TODO: SQL injection risk
|
# TODO: SQL injection risk
|
||||||
f"""
|
f"""
|
||||||
CREATE TABLE IF NOT EXISTS {name} (
|
CREATE TABLE IF NOT EXISTS {name} (
|
||||||
|
@ -163,11 +182,33 @@ class SQLite(Vault):
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
finally:
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
self.con_pool.put(conn)
|
||||||
|
|
||||||
if commit:
|
|
||||||
self.commit()
|
|
||||||
|
|
||||||
def commit(self):
|
class ConnectionPool:
|
||||||
"""Commit any changes made that has not been written to db."""
|
def __init__(self, path: Union[str, Path], size: int):
|
||||||
self.adb.commit(self.ticket)
|
self._path = path
|
||||||
|
self._size = size
|
||||||
|
self._pool = Queue(self._size)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def _create_connection(self):
|
||||||
|
return sqlite3.connect(self._path)
|
||||||
|
|
||||||
|
def get(self) -> Connection:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self._pool.get(block=False)
|
||||||
|
except Empty:
|
||||||
|
with self._lock:
|
||||||
|
if self._pool.qsize() < self._size:
|
||||||
|
return self._create_connection()
|
||||||
|
else:
|
||||||
|
# pool full, wait before retrying
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
def put(self, conn: Connection):
|
||||||
|
self._pool.put(conn)
|
||||||
|
|
Loading…
Reference in New Issue