Bootstrap
Committed 893cfc
index 0000000..f209c4c
--- /dev/null
+poetry.lock
+.coverage
+.test_coverage.xml
+.test_results.xml
+_test_data
index 0000000..3de51a6
--- /dev/null
+[tool.poetry]
+name = "sqlyte"
+version = "0.1.1"
+description = "a simple SQLite interface"
+keywords = ["SQLite"]
+authors = ["Angelo Gladding <angelo@ragt.ag>"]
+license = "BSD-2-Clause"
+
+[tool.poetry.dependencies]
+python = ">=3.10,<3.11"
+pendulum = "^2.1.2"
+
+[tool.poetry.group.dev.dependencies]
+gmpg = {path="../gmpg", develop=true}
+
+# [[tool.poetry.source]]
+# name = "main"
+# url = "https://ragt.ag/code/pypi"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
index 0000000..ed8fdde
--- /dev/null
+{
+ "reportGeneralTypeIssues": false
+}
index 0000000..f011201
--- /dev/null
+"""
+A simple SQLite interface.
+
+[SQLite](https://sqlite.org) is a C-language library that implements a small,
+fast, self-contained, high-reliability, full-featured, SQL database engine.
+In contrast to many other database management systems, SQLite is not a
+client–server database engine. Rather, it is embedded into the end program.
+
+[sqlite3](https://docs.python.org/3/library/sqlite3) is a Python interface to
+SQLite.
+
+`sqlyte` provides an opinionated interface on top of `sqlite3`.
+
+"""
+
+# TODO when a new table is created, ask to import from table no longer in use
+# -- deprecate no longer used tables if no longer needed
+
+import contextlib
+import datetime
+import decimal
+import functools
+import json
+import logging
+import os
+import pathlib
+
+import pendulum
+
+try:
+ from pysqlite3 import dbapi2 as sqlite3
+except (ImportError, NameError): # NOTE pysqlite3.dbapi2 raises NameError
+ logging.info("falling back to sqlite3 in the standard library")
+ import sqlite3
+
+__all__ = ["db"]
+
+
+# TODO register and handle JSON type
+
+
+def from_datetime(val):
+ if isinstance(val, datetime.datetime):
+ return pendulum.instance(val)
+ val = val.decode("utf-8")
+ # remove timezone column
+ if val[-6] in "-+":
+ val = "".join(val.rpartition(":")[::2])
+ return pendulum.parse(val)
+
+
+sqlite3.register_converter("DATETIME", from_datetime)
+sqlite3.register_adapter(pendulum.DateTime, lambda val: val.isoformat(" "))
+
+
+def from_json(val):
+ # TODO traverse looking for nested published/updated
+ def f(dct):
+ def upgrade_date(key):
+ if key not in dct:
+ return
+ item = dct[key]
+ tz = None
+ # XXX print(type(item))
+ if isinstance(item, dict):
+ val = item["datetime"]
+ tz = item["timezone"]
+ elif isinstance(item, list):
+ if isinstance(item[0], dict):
+ val = item[0]["datetime"]
+ tz = item[0]["timezone"]
+ else:
+ val = item[0]
+ else:
+ val = item
+ if not val:
+ return
+ if val[-6] in "-+":
+ val = "".join(val.rpartition(":")[::2])
+ try:
+ dt = pendulum.parse(val.strip())
+ except pendulum.exceptions.ParserError:
+ dt = "?"
+ else:
+ if tz:
+ try:
+ dt = dt.in_timezone(tz)
+ except pendulum.tz.zoneinfo.exceptions.InvalidTimezone:
+ # XXX print("tz conversion failing silently...") # TODO use log
+ pass
+ dct[key] = [dt]
+
+ upgrade_date("published")
+ upgrade_date("updated")
+ return dct
+
+ return json.loads(val, object_hook=f)
+
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ # XXX if isinstance(obj, uri.URI):
+ # XXX return str(obj)
+ if isinstance(obj, (datetime.date, datetime.datetime)):
+ obj = pendulum.instance(obj)
+ return {
+ "datetime": obj.in_timezone("UTC").isoformat(),
+ "timezone": obj.tzinfo.name,
+ }
+ return json.JSONEncoder.default(self, obj)
+
+
+sqlite3.register_converter("JSON", from_json)
+sqlite3.register_adapter(dict, lambda val: JSONEncoder(indent=2).encode(val))
+
+
+def dttz_to_iso(raw_dttz):
+ """
+ Return the ISO format of given `raw_dttz`.
+
+ `raw_dttz` is a JSON string in the form of {"datetime": ..., "timezone": ...}).
+ """
+ dttz = json.loads(raw_dttz)
+ return (
+ pendulum.parse(dttz["datetime"])
+ .astimezone(pendulum.timezone(dttz["timezone"]))
+ .isoformat()
+ )
+
+
+class Model:
+ def __init__(self, name, **schemas):
+ self.name = name
+ self.schemas = schemas
+ self.version = 0
+ self.migrations = {}
+ self.controllers = {}
+
+ def migrate(self, version):
+ if version > self.version:
+ self.version = version
+
+ def add_migration(f):
+ self.migrations[version] = f
+ return f
+
+ return add_migration
+
+ def control(self, controller):
+ self.controllers[controller.__name__] = controller
+ return controller
+
+ def __call__(self, db):
+ return ModelController(self.controllers, db)
+
+
+class ModelController:
+ def __init__(self, controllers, db):
+ self.controllers = controllers
+ self.db = db
+
+ def __getattr__(self, attr):
+ # XXX TODO print(self.controllers, self.db)
+ return functools.partial(self.controllers[attr], self.db)
+
+
+model = Model
+
+
+def ors(item, values, fuzzy=False):
+ template = "{} LIKE '{}%'" if fuzzy else "{} = '{}'"
+ return " OR ".join(template.format(item, v) for v in values)
+
+
+def adapt(x):
+ return x
+
+
+def get_icu():
+ current_dir = pathlib.Path(__file__).parent
+ icuext_path = current_dir / "libsqliteicu.so"
+ if not icuext_path.exists():
+ icuext_source_path = current_dir / "icu.c"
+ os.system(
+ f"gcc -fPIC -shared {icuext_source_path} "
+ f"`pkg-config --libs icu-i18n` -o {icuext_path}"
+ )
+ return icuext_path
+
+
+class Database:
+ """"""
+
+ IntegrityError = sqlite3.IntegrityError
+ OperationalError = sqlite3.OperationalError
+ ProgrammingError = sqlite3.ProgrammingError
+
+ def __init__(self, path):
+ self.path = path
+ for command in (
+ "pragma",
+ "create",
+ "rename_table",
+ "drop",
+ "insert",
+ "replace",
+ "select",
+ "update",
+ "delete",
+ "columns",
+ "add_column",
+ "drop_column",
+ "rename_column",
+ ):
+
+ def single_statement_cursor(command):
+ @functools.wraps(getattr(Cursor, command))
+ def proxy(_self, *args, **kwargs):
+ with self.transaction as cur:
+ return getattr(cur, command)(_self, *args, **kwargs)
+
+ return proxy
+
+ setattr(self, command, single_statement_cursor(command))
+
+ conn = sqlite3.connect(
+ path, detect_types=sqlite3.PARSE_DECLTYPES, isolation_level=None
+ )
+ conn.execute("PRAGMA journal_mode=WAL")
+ conn.create_function("dttz_to_iso", 1, dttz_to_iso)
+ # TODO conn.cursor().execute("PRAGMA foreign_keys = ON;")
+
+ # TODO try:
+ # TODO conn.enable_load_extension(True)
+ # TODO except AttributeError:
+ # TODO pass
+ # TODO else:
+ # TODO icuext_path = get_icu()
+ # TODO try:
+ # TODO conn.load_extension(str(icuext_path))
+ # TODO except sqlite3.OperationalError:
+ # TODO pass # TODO make ICU available for all platforms
+ # TODO else:
+ # TODO conn.enable_load_extension(False)
+ # TODO conn.execute("SELECT icu_load_collation('en_US', 'UNICODE');")
+
+ conn.row_factory = sqlite3.Row
+ # conn.execute("PRAGMA user_version")
+ self.conn = conn
+
+ self.debug = False
+
+ def __repr__(self):
+ return f"sql.db: {self.path}"
+
+ # XXX def define(self, table, **schema):
+ # XXX """define multiple tables at once migrating them if necessary"""
+ # XXX # TODO bump version a la "PRAGMA user_version = 1;" and store change
+ # XXX # TODO store backups
+ # XXX try:
+ # XXX self.create(
+ # XXX table,
+ # XXX ", ".join(
+ # XXX f"{row} {definition}" for row, definition in list(schema.items())
+ # XXX ),
+ # XXX )
+ # XXX except self.OperationalError:
+ # XXX pass
+ # XXX # while table_schemas:
+ # XXX # for table, schema in list(table_schemas.items()):
+ # XXX # print(table)
+ # XXX # import textwrap
+ # XXX # print(textwrap.dedent(schema))
+ # XXX # table_schemas.pop(table)
+ # XXX # new_table = "new_{}".format(table)
+ # XXX # self.create(table, schema)
+ # XXX # self.create(new_table, schema)
+ # XXX # with self.transaction as cur:
+ # XXX # old_columns = cur.columns(table)
+ # XXX # new_columns = cur.columns(new_table)
+ # XXX # if old_columns == new_columns:
+ # XXX # cur.drop(new_table)
+ # XXX # continue
+ # XXX # old_names = {col[0] for col in old_columns}
+ # XXX # new_names = {col[0] for col in new_columns}
+ # XXX # cols = list(old_names.intersection(new_names))
+ # XXX # print("Migrating table `{}`..".format(table), end=" ")
+ # XXX # for row in cur.select(table, what=", ".join(cols)):
+ # XXX # cur.insert(new_table, dict(zip(cols, list(row))))
+ # XXX # cur.drop(table)
+ # XXX # cur.cur.execute(f"""ALTER TABLE {new_table}
+ # XXX # RENAME TO {table}""")
+ # XXX # print("success")
+
+ @property
+ def tables(self):
+ return [
+ r[0]
+ for r in self.select("sqlite_master", what="name", where="type='table'")
+ ]
+
+ @property
+ @contextlib.contextmanager
+ def transaction(self):
+ """
+ enter a transaction context and return its `Cursor`
+
+ >>> with Database().transaction as cur: # doctest: +SKIP
+ ... cur.insert(...)
+ ... cur.insert(...)
+ ... cur.select(...)
+ ... cur.insert(...)
+
+ """
+ # TODO log transaction begin, complete, etc..
+ with self.conn:
+ cursor = Cursor(self.conn.cursor())
+ cursor.debug = self.debug
+ yield cursor
+ # with sqlite3.connect(self.path,
+ # detect_types=sqlite3.PARSE_DECLTYPES) as conn:
+ # # conn.cursor().execute("PRAGMA foreign_keys = ON;")
+ # conn.enable_load_extension(True)
+ # icuext_path = pathlib.Path(__file__).parent / "libsqliteicu"
+ # conn.load_extension(str(icuext_path))
+ # conn.enable_load_extension(False)
+ # conn.execute("SELECT icu_load_collation('en_US', 'UNICODE');")
+ # conn.row_factory = sqlite3.Row
+ # cursor = Cursor(conn.cursor())
+ # cursor.debug = self.debug
+ # yield cursor
+
+ def destroy(self):
+ pathlib.Path(self.path).unlink()
+
+
+def db(path, *models) -> Database:
+ """
+ return a connection to a `SQLite` database
+
+ Database supplied by given `path` or in environment variable `$SQLDB`.
+
+ Note: `table_schemas` should not include a table (dict key) named "path".
+
+ """
+ # XXX if not path:
+ # XXX path = os.environ.get("SQLDB", None)
+ # XXX if path:
+
+ dbi = Database(path)
+ current_models = {}
+ try:
+ dbi.create("_models", "name TEXT, version INTEGER")
+ except dbi.OperationalError:
+ for model in dbi.select("_models"):
+ current_models[model["name"]] = model["version"]
+ for model in models:
+ try:
+ current_version = current_models[model.name]
+ except KeyError: # doesn't exist, create all tables in model
+ for table, schema in model.schemas.items():
+ fts = schema.get("FTS", False)
+ if fts:
+ dbi.create(
+ table,
+ ", ".join(f"{col}" for col in schema),
+ fts=True,
+ )
+ else:
+ dbi.create(
+ table,
+ ", ".join(
+ f"{col} {definition}" for col, definition in schema.items()
+ ),
+ fts=False,
+ )
+ dbi.insert("_models", name=model.name, version=model.version)
+ current_models[model.name] = model.version
+ continue
+ # XXX TODO print("V", current_version, model.version)
+ if current_version == model.version:
+ # TODO check schema in code against schema in db, suggest migration
+ continue # model exists and is up-to-date
+ elif current_version > model.version:
+ raise Exception("Your database version is ahead of your software version.")
+ elif current_version < model.version:
+ for migration in range(current_version + 1, model.version + 1):
+ model.migrations[migration](dbi)
+ dbi.update(
+ "_models", where="name = ?", vals=[model.name], version=model.version
+ )
+ return dbi
+
+
+class Cursor:
+
+ """"""
+
+ IntegrityError = sqlite3.IntegrityError
+ OperationalError = sqlite3.OperationalError
+
+ def __init__(self, cur):
+ self.cur = cur
+ self.debug = False
+
+ def pragma(self, command, value=None):
+ if value is None:
+ self.cur.execute(f"PRAGMA {command}")
+ return self.cur.fetchone()[command]
+ self.cur.execute(f"PRAGMA {command} = {value}")
+
+ def create(self, table, schema, fts=False):
+ """
+ create a table with given column schema
+
+ """
+ if fts:
+ sql = f"CREATE VIRTUAL TABLE {table} USING fts5 ({schema})"
+ else:
+ sql = f"CREATE TABLE {table} ({schema})"
+ self.cur.execute(sql)
+
+ def rename_table(self, table, new_table):
+ """Rename a table."""
+ self.cur.execute(f"ALTER TABLE {table} RENAME TO {new_table}")
+
+ def drop(self, *tables):
+ """
+ drop one or more tables
+
+ """
+ for table in tables:
+ self.cur.execute(f"DROP TABLE {table}")
+
+ def insert(self, table, *records, _force=False, **record):
+ return self._insert("insert", table, *records, _force=False, **record)
+
+ def replace(self, table, *records, _force=False, **record):
+ return self._insert("replace", table, *records, _force=False, **record)
+
+ def _insert(self, operation, table, *records, _force=False, **record):
+ """Insert one or more records into given table."""
+ if record:
+ if records:
+ raise TypeError("use `record` *or* `records` not both")
+ records = (record,) # XXX += (record,)
+ values = []
+ for record in records:
+ for column, val in record.items():
+ if isinstance(val, dict):
+ record[column] = JSONEncoder().encode(val)
+ elif isinstance(val, decimal.Decimal):
+ record[column] = float(val)
+ columns, vals = zip(*record.items())
+ values.append(vals)
+ sql = "{} INTO {}({})".format(
+ operation.upper(), table, ", ".join(columns)
+ ) + " VALUES ({})".format(", ".join("?" * len(vals)))
+ if self.debug:
+ print(sql)
+ try:
+ if len(values) == 1:
+ self.cur.execute(sql, vals)
+ self.cur.execute("SELECT last_insert_rowid()")
+ return self.cur.fetchone()[0]
+ else:
+ self.cur.executemany(sql, values)
+ except sqlite3.IntegrityError as err:
+ if not _force:
+ raise err
+
+ def select(
+ self,
+ table,
+ what="*",
+ where=None,
+ order=None,
+ group=None,
+ join=None,
+ limit=None,
+ offset=None,
+ vals=None,
+ ):
+ """
+ select records from one or more tables
+
+ """
+ sql = self._select_sql(
+ table,
+ what=what,
+ where=where,
+ order=order,
+ group=group,
+ join=join,
+ limit=limit,
+ offset=offset,
+ vals=vals,
+ )[1:-1]
+ if self.debug:
+ print(sql)
+ if vals:
+ print(" ", vals)
+ if vals:
+ self.cur.execute(sql, vals)
+ else:
+ self.cur.execute(sql)
+
+ class Results:
+ def __init__(innerself, results):
+ innerself.results = list(results)
+
+ def pop(innerself, index):
+ return innerself.results.pop(index)
+
+ def __getitem__(innerself, index):
+ return innerself.results[index]
+
+ def __len__(innerself):
+ return len(innerself.results)
+
+ def _repr_html_(innerself):
+ results = "<tr>"
+ types = {}
+ for column in self.columns(table):
+ types[column[0]] = column[1]
+ results += f"<td>{column[0]} " f"<small>{column[1]}</small></td>"
+ results += "</tr>"
+ for result in innerself.results:
+ results += "<tr>"
+ for key, value in dict(result).items():
+ if types[key] == "JSON":
+ encoded_json = JSONEncoder(indent=2).encode(value)
+ value = solarized.highlight(encoded_json, ".json")
+ results += f"<td>{value}</td>"
+ results += "</tr>"
+ return f"<table>{results}</table>"
+
+ return Results(self.cur.fetchall())
+
+ def _select_sql(
+ self,
+ table,
+ what="*",
+ where=None,
+ order=None,
+ group=None,
+ join=None,
+ limit=None,
+ offset=None,
+ vals=None,
+ suffix="",
+ ):
+ sql_parts = ["SELECT {}".format(what), "FROM {}".format(table)]
+ if join:
+ if not isinstance(join, (list, tuple)):
+ join = [join]
+ for join_statement in join:
+ sql_parts.append("LEFT JOIN {}".format(join_statement))
+ if where:
+ # if vals:
+ # where = where.format(*[str(adapt(v)) for v in vals])
+ sql_parts.append("WHERE {}".format(where))
+ if group:
+ sql_parts.append("GROUP BY {}".format(group))
+ if order:
+ sql_parts.append("ORDER BY {}".format(order))
+ if limit:
+ limitsql = "LIMIT {}".format(limit)
+ if offset:
+ limitsql += " {}".format(offset)
+ sql_parts.append(limitsql)
+ return "({}) {}".format("\n".join(sql_parts), suffix).rstrip()
+
+ def update(self, table, what=None, where=None, vals=None, **record):
+ """
+ update one or more records
+
+ Use `what` *or* `record`.
+
+ """
+ sql_parts = ["UPDATE {}".format(table)]
+ if what:
+ what_sql = what
+ else:
+ keys, values = zip(*record.items())
+ what_sql = ", ".join("{}=?".format(key) for key in keys)
+ if vals is None:
+ vals = []
+ vals = list(values) + vals
+ sql_parts.append("SET {}".format(what_sql))
+ if where:
+ sql_parts.append("WHERE {}".format(where))
+ sql = "\n".join(sql_parts)
+ if self.debug:
+ print(sql)
+ if vals:
+ print(vals)
+ if vals:
+ self.cur.execute(sql, vals)
+ else:
+ self.cur.execute(sql)
+
+ def delete(self, table, where=None, vals=None):
+ """
+ delete one or more records
+
+ """
+ sql_parts = ["DELETE FROM {}".format(table)]
+ if where:
+ sql_parts.append("WHERE {}".format(where))
+ sql = "\n".join(sql_parts)
+ if vals:
+ self.cur.execute(sql, vals)
+ else:
+ self.cur.execute(sql)
+
+ def columns(self, table):
+ """Return columns for given table."""
+ return [
+ list(column)[1:]
+ for column in self.cur.execute("PRAGMA table_info({})".format(table))
+ ]
+
+ def add_column(self, table, column_def):
+ """Add a column to given table."""
+ self.cur.execute(f"ALTER TABLE {table} ADD COLUMN {column_def}")
+
+ def drop_column(self, table, column):
+ """Add a column to given table."""
+ self.cur.execute(f"ALTER TABLE {table} DROP COLUMN {column}")
+
+ def rename_column(self, table, column, new_column):
+ """Rename a column of given table."""
+ self.cur.execute(f"ALTER TABLE {table} RENAME COLUMN {column} TO {new_column}")
index 0000000..74a8366
--- /dev/null
+import pathlib
+import shutil
+
+import sqlyte
+
+
+def setup_module(module):
+ shutil.rmtree("_test_data", ignore_errors=True)
+ pathlib.Path("_test_data").mkdir()
+
+
+def test_json():
+ db = sqlyte.db("_test_data/1.db", sqlyte.Model("test", testing={"name": "TEXT"}))
+ db.insert("testing", name="foo")
+ db.insert("testing", name="bar")
+
+
+def test_transaction():
+ db = sqlyte.db("_test_data/2.db", sqlyte.Model("test", testing={"name": "TEXT"}))
+ with db.conn:
+ cursor = sqlyte.Cursor(db.conn.cursor())
+ cursor.insert("testing", name="bar")