Remove AtomicSQL, use Connection Pools in Vaults

This allows the use of vaults in any thread, while keeping database-file level thread safety. AtomicSQL doesn't actually do anything that useful.
This commit is contained in:
rlaphoenix 2023-02-21 05:38:39 +00:00
parent 707469d252
commit c925cb8af9
6 changed files with 255 additions and 272 deletions

View File

@ -89,7 +89,7 @@ def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) -
log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}") log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}")
try: try:
added = to_vault.add_keys(service_, content_keys, commit=True) added = to_vault.add_keys(service_, content_keys)
except PermissionError: except PermissionError:
log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...") log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...")
continue continue
@ -171,7 +171,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None:
for vault in vaults_: for vault in vaults_:
log.info(f"Adding {total_count} Content Keys to {vault}") log.info(f"Adding {total_count} Content Keys to {vault}")
added_count = vault.add_keys(service, kid_keys, commit=True) added_count = vault.add_keys(service, kid_keys)
existed_count = total_count - added_count existed_count = total_count - added_count
log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)") log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)")

View File

@ -1,105 +0,0 @@
"""
AtomicSQL - Race-condition and Threading safe SQL Database Interface.
Copyright (C) 2020-2023 rlaphoenix
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sqlite3
import time
from threading import Lock
from typing import Any, Callable, Union
import pymysql.cursors
Connections = Union[sqlite3.Connection, pymysql.connections.Connection]
Cursors = Union[sqlite3.Cursor, pymysql.cursors.Cursor]
class AtomicSQL:
"""
Race-condition and Threading safe SQL Database Interface.
"""
def __init__(self) -> None:
self.master_lock = Lock() # prevents race condition
self.db: dict[bytes, Connections] = {} # used to hold the database connections and commit changes and such
self.cursor: dict[bytes, Cursors] = {} # used to execute queries and receive results
self.session_lock: dict[bytes, Lock] = {} # like master_lock, but per-session
def load(self, connection: Connections) -> bytes:
"""
Store SQL Connection object and return a reference ticket.
:param connection: SQLite3 or pymysql Connection object.
:returns: Session ID in which the database connection is referenced with.
"""
self.master_lock.acquire()
try:
# obtain a unique cryptographically random session_id
session_id = None
while not session_id or session_id in self.db:
session_id = os.urandom(16)
self.db[session_id] = connection
self.cursor[session_id] = self.db[session_id].cursor()
self.session_lock[session_id] = Lock()
return session_id
finally:
self.master_lock.release()
def safe_execute(self, session_id: bytes, action: Callable) -> Any:
"""
Execute code on the Database Connection in a race-condition safe way.
:param session_id: Database Connection's Session ID.
:param action: Function or lambda in which to execute, it's provided `db` and `cursor` arguments.
:returns: Whatever `action` returns.
"""
if session_id not in self.db:
raise ValueError(f"Session ID {session_id!r} is invalid.")
self.master_lock.acquire()
self.session_lock[session_id].acquire()
try:
failures = 0
while True:
try:
action(
db=self.db[session_id],
cursor=self.cursor[session_id]
)
break
except sqlite3.OperationalError as e:
failures += 1
delay = 3 * failures
print(f"AtomicSQL.safe_execute failed, {e}, retrying in {delay} seconds...")
time.sleep(delay)
if failures == 10:
raise ValueError("AtomicSQL.safe_execute failed too many time's. Aborting.")
return self.cursor[session_id]
finally:
self.session_lock[session_id].release()
self.master_lock.release()
def commit(self, session_id: bytes) -> bool:
"""
Commit changes to the Database Connection immediately.
This isn't necessary to be run every time you make changes, just ensure it's run
at least before termination.
:param session_id: Database Connection's Session ID.
:returns: True if it committed.
"""
self.safe_execute(
session_id,
lambda db, cursor: db.commit()
)
return True # todo ; actually check if db.commit worked

View File

@ -31,11 +31,11 @@ class Vault(metaclass=ABCMeta):
"""Get All Keys from Vault by Service.""" """Get All Keys from Vault by Service."""
@abstractmethod @abstractmethod
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool: def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
"""Add KID:KEY to the Vault.""" """Add KID:KEY to the Vault."""
@abstractmethod @abstractmethod
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int: def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
""" """
Add Multiple Content Keys with Key IDs for Service to the Vault. Add Multiple Content Keys with Key IDs for Service to the Vault.
Pre-existing Content Keys are ignored/skipped. Pre-existing Content Keys are ignored/skipped.

View File

@ -57,7 +57,7 @@ class Vaults:
for vault in self.vaults: for vault in self.vaults:
if vault != excluding: if vault != excluding:
try: try:
success += vault.add_key(self.service, kid, key, commit=True) success += vault.add_key(self.service, kid, key)
except (PermissionError, NotImplementedError): except (PermissionError, NotImplementedError):
pass pass
return success return success
@ -70,7 +70,7 @@ class Vaults:
success = 0 success = 0
for vault in self.vaults: for vault in self.vaults:
try: try:
success += bool(vault.add_keys(self.service, kid_keys, commit=True)) success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError): except (PermissionError, NotImplementedError):
pass pass
return success return success

View File

@ -1,5 +1,8 @@
from __future__ import annotations from __future__ import annotations
import time
from queue import Empty, Queue
from threading import Lock
from typing import Iterator, Optional, Union from typing import Iterator, Optional, Union
from uuid import UUID from uuid import UUID
@ -7,7 +10,6 @@ import pymysql
from pymysql.cursors import DictCursor from pymysql.cursors import DictCursor
from devine.core.services import Services from devine.core.services import Services
from devine.core.utils.atomicsql import AtomicSQL
from devine.core.vault import Vault from devine.core.vault import Vault
@ -21,15 +23,13 @@ class MySQL(Vault):
""" """
super().__init__(name) super().__init__(name)
self.slug = f"{host}:{database}:{username}" self.slug = f"{host}:{database}:{username}"
self.con = pymysql.connect( self.con_pool = ConnectionPool(dict(
host=host, host=host,
db=database, db=database,
user=username, user=username,
cursorclass=DictCursor, cursorclass=DictCursor,
**kwargs **kwargs
) ), 5)
self.adb = AtomicSQL()
self.ticket = self.adb.load(self.con)
self.permissions = self.get_permissions() self.permissions = self.get_permissions()
if not self.has_permission("SELECT"): if not self.has_permission("SELECT"):
@ -43,37 +43,44 @@ class MySQL(Vault):
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
try:
cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s", f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
[kid, "0" * 32] (kid, "0" * 32)
) )
).fetchone() cek = cursor.fetchone()
if not c: if not cek:
return None return None
return cek["key_"]
return c["key_"] finally:
cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]: def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service): if not self.has_table(service):
# no table, no keys, simple # no table, no keys, simple
return None return None
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
try:
cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s", f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
["0" * 32] ("0" * 32,)
) )
) for row in cursor.fetchall():
yield row["kid"], row["key_"]
finally:
cursor.close()
self.con_pool.put(conn)
for row in c.fetchall(): def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
yield row["kid"], row["key_"]
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
if not key or key.count("0") == len(key): if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.") raise ValueError("You cannot add a NULL Content Key to a Vault.")
@ -82,39 +89,38 @@ class MySQL(Vault):
if not self.has_table(service): if not self.has_table(service):
try: try:
self.create_table(service, commit) self.create_table(service)
except PermissionError: except PermissionError:
return False return False
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
if self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
try:
cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s", f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
[kid, key] (kid, key)
) )
).fetchone(): if cursor.fetchone():
# table already has this exact KID:KEY stored # table already has this exact KID:KEY stored
return True return True
cursor.execute(
self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)", f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
(kid, key) (kid, key)
) )
) finally:
conn.commit()
if commit: cursor.close()
self.commit() self.con_pool.put(conn)
return True return True
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int: def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items(): for kid, key in kid_keys.items():
if not key or key.count("0") == len(key): if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.") raise ValueError("You cannot add a NULL Content Key to a Vault.")
@ -124,7 +130,7 @@ class MySQL(Vault):
if not self.has_table(service): if not self.has_table(service):
try: try:
self.create_table(service, commit) self.create_table(service)
except PermissionError: except PermissionError:
return 0 return 0
@ -139,40 +145,50 @@ class MySQL(Vault):
for kid, key_ in kid_keys.items() for kid, key_ in kid_keys.items()
} }
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.executemany(
try:
cursor.executemany(
# TODO: SQL injection risk # TODO: SQL injection risk
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)", f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
kid_keys.items() kid_keys.items()
) )
) return cursor.rowcount
finally:
if commit: conn.commit()
self.commit() cursor.close()
self.con_pool.put(conn)
return c.rowcount
def get_services(self) -> Iterator[str]: def get_services(self) -> Iterator[str]:
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute("SHOW TABLES")
) try:
for table in c.fetchall(): cursor.execute("SHOW TABLES")
# each entry has a key named `Tables_in_<db name>` for table in cursor.fetchall():
yield Services.get_tag(list(table.values())[0]) # each entry has a key named `Tables_in_<db name>`
yield Services.get_tag(list(table.values())[0])
finally:
cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool: def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name.""" """Check if the Vault has a Table with the specified name."""
return list(self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
[self.con.db, name]
)
).fetchone().values())[0] == 1
def create_table(self, name: str, commit: bool = False): try:
cursor.execute(
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
(conn.db, name)
)
return list(cursor.fetchone().values())[0] == 1
finally:
cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created.""" """Create a Table with the specified name if not yet created."""
if self.has_table(name): if self.has_table(name):
return return
@ -180,9 +196,11 @@ class MySQL(Vault):
if not self.has_permission("CREATE"): if not self.has_permission("CREATE"):
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.") raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
try:
cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f""" f"""
CREATE TABLE IF NOT EXISTS {name} ( CREATE TABLE IF NOT EXISTS {name} (
@ -193,23 +211,30 @@ class MySQL(Vault):
); );
""" """
) )
) finally:
conn.commit()
if commit: cursor.close()
self.commit() self.con_pool.put(conn)
def get_permissions(self) -> list: def get_permissions(self) -> list:
"""Get and parse Grants to a more easily usable list tuple array.""" """Get and parse Grants to a more easily usable list tuple array."""
with self.con.cursor() as c: conn = self.con_pool.get()
c.execute("SHOW GRANTS") cursor = conn.cursor()
grants = c.fetchall()
try:
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
grants = [next(iter(x.values())) for x in grants] grants = [next(iter(x.values())) for x in grants]
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)] grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
grants = [( grants = [(
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))), list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
location.replace("`", "").split(".") location.replace("`", "").split(".")
) for perms, location in grants] ) for perms, location in grants]
return grants return grants
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool: def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
"""Check if the current connection has a specific permission.""" """Check if the current connection has a specific permission."""
@ -220,6 +245,28 @@ class MySQL(Vault):
grants = [x for x in grants if x[1][1] in (table, "*")] grants = [x for x in grants if x[1][1] in (table, "*")]
return bool(grants) return bool(grants)
def commit(self):
"""Commit any changes made that has not been written to db.""" class ConnectionPool:
self.adb.commit(self.ticket) def __init__(self, con: dict, size: int):
self._con = con
self._size = size
self._pool = Queue(self._size)
self._lock = Lock()
def _create_connection(self):
return pymysql.connect(**self._con)
def get(self) -> pymysql.Connection:
while True:
try:
return self._pool.get(block=False)
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: pymysql.Connection):
self._pool.put(conn)

View File

@ -1,12 +1,15 @@
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
import time
from pathlib import Path from pathlib import Path
from queue import Empty, Queue
from sqlite3 import Connection
from threading import Lock
from typing import Iterator, Optional, Union from typing import Iterator, Optional, Union
from uuid import UUID from uuid import UUID
from devine.core.services import Services from devine.core.services import Services
from devine.core.utils.atomicsql import AtomicSQL
from devine.core.vault import Vault from devine.core.vault import Vault
@ -17,9 +20,7 @@ class SQLite(Vault):
super().__init__(name) super().__init__(name)
self.path = Path(path).expanduser() self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict? # TODO: Use a DictCursor or such to get fetches as dict?
self.con = sqlite3.connect(self.path) self.con_pool = ConnectionPool(self.path, 5)
self.adb = AtomicSQL()
self.ticket = self.adb.load(self.con)
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]: def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
if not self.has_table(service): if not self.has_table(service):
@ -29,76 +30,82 @@ class SQLite(Vault):
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
# TODO: SQL injection risk
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
[kid, "0" * 32]
)
).fetchone()
if not c:
return None
return c[1] # `key_` try:
cursor.execute(
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
(kid, "0" * 32)
)
cek = cursor.fetchone()
if not cek:
return None
return cek[1]
finally:
cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]: def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service): if not self.has_table(service):
# no table, no keys, simple # no table, no keys, simple
return None return None
c = self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
# TODO: SQL injection risk
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
["0" * 32]
)
)
for (kid, key_) in c.fetchall():
yield kid, key_
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool: conn = self.con_pool.get()
cursor = conn.cursor()
try:
cursor.execute(
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
("0" * 32,)
)
for (kid, key_) in cursor.fetchall():
yield kid, key_
finally:
cursor.close()
self.con_pool.put(conn)
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key): if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.") raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_table(service): if not self.has_table(service):
self.create_table(service, commit) self.create_table(service)
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
if self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
try:
cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?", f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
[kid, key] (kid, key)
) )
).fetchone(): if cursor.fetchone():
# table already has this exact KID:KEY stored # table already has this exact KID:KEY stored
return True return True
cursor.execute(
self.adb.safe_execute(
self.ticket,
lambda db, cursor: cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)", f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
(kid, key) (kid, key)
) )
) finally:
conn.commit()
if commit: cursor.close()
self.commit() self.con_pool.put(conn)
return True return True
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int: def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items(): for kid, key in kid_keys.items():
if not key or key.count("0") == len(key): if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.") raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_table(service): if not self.has_table(service):
self.create_table(service, commit) self.create_table(service)
if not isinstance(kid_keys, dict): if not isinstance(kid_keys, dict):
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}") raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
@ -111,47 +118,59 @@ class SQLite(Vault):
for kid, key_ in kid_keys.items() for kid, key_ in kid_keys.items()
} }
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.executemany(
try:
cursor.executemany(
# TODO: SQL injection risk # TODO: SQL injection risk
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)", f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
kid_keys.items() kid_keys.items()
) )
) return cursor.rowcount
finally:
if commit: conn.commit()
self.commit() cursor.close()
self.con_pool.put(conn)
return c.rowcount
def get_services(self) -> Iterator[str]: def get_services(self) -> Iterator[str]:
c = self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
) try:
for (name,) in c.fetchall(): cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
if name != "sqlite_sequence": for (name,) in cursor.fetchall():
yield Services.get_tag(name) if name != "sqlite_sequence":
yield Services.get_tag(name)
finally:
cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool: def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name.""" """Check if the Vault has a Table with the specified name."""
return self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
[name]
)
).fetchone()[0] == 1
def create_table(self, name: str, commit: bool = False): try:
cursor.execute(
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
(name,)
)
return cursor.fetchone()[0] == 1
finally:
cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created.""" """Create a Table with the specified name if not yet created."""
if self.has_table(name): if self.has_table(name):
return return
self.adb.safe_execute( conn = self.con_pool.get()
self.ticket, cursor = conn.cursor()
lambda db, cursor: cursor.execute(
try:
cursor.execute(
# TODO: SQL injection risk # TODO: SQL injection risk
f""" f"""
CREATE TABLE IF NOT EXISTS {name} ( CREATE TABLE IF NOT EXISTS {name} (
@ -163,11 +182,33 @@ class SQLite(Vault):
); );
""" """
) )
) finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
if commit:
self.commit()
def commit(self): class ConnectionPool:
"""Commit any changes made that has not been written to db.""" def __init__(self, path: Union[str, Path], size: int):
self.adb.commit(self.ticket) self._path = path
self._size = size
self._pool = Queue(self._size)
self._lock = Lock()
def _create_connection(self):
return sqlite3.connect(self._path)
def get(self) -> Connection:
while True:
try:
return self._pool.get(block=False)
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: Connection):
self._pool.put(conn)