db.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from __future__ import annotations
  2. import os
  3. from typing import Any
  4. from sqlalchemy import Integer, JSON, String, create_engine, select
  5. from sqlalchemy.engine import Engine
  6. from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
  def normalize_database_url(raw: str) -> str:
      """Return *raw* with leading and trailing whitespace removed.

      Whitespace trimming is currently the only normalization applied.
      """
      cleaned = raw.strip()
      return cleaned
  class Base(DeclarativeBase):
      """Declarative base class for the ORM models defined in this module."""
  class SysConfig(Base):
      """ORM mapping for the ``sys_config`` table: one JSON value per unique string key."""
      __tablename__ = "sys_config"
      # Surrogate integer primary key.
      id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
      # Lookup key: at most 128 characters, required, and unique across rows.
      key: Mapped[str] = mapped_column(String(128), nullable=False, unique=True)
      # JSON-typed payload; the column is declared NOT NULL.
      value: Mapped[Any] = mapped_column(JSON, nullable=False)
  # Module-level cache used by sql_engine(): the URL the cached engine was built for, and the engine itself.
  _ENGINE_URL: str | None = None
  _ENGINE: Engine | None = None
  def database_url() -> str:
      """Return the effective SQLAlchemy database URL.

      Reads the ``DATABASE_URL`` environment variable. When it is unset, empty,
      or whitespace-only, falls back to the local SQLite file ``llm_proxy.db``.
      The chosen URL is passed through ``normalize_database_url``.
      """
      configured = (os.getenv("DATABASE_URL") or "").strip()
      # Strip *before* the fallback check: a whitespace-only value is truthy and
      # would otherwise bypass the default and produce an empty URL.
      return normalize_database_url(configured or "sqlite:///llm_proxy.db")
  def sql_engine() -> Engine:
      """Return a cached Engine for the current ``database_url()``.

      The engine is rebuilt only when the effective URL differs from the one the
      cached engine was created with. In that case the superseded engine is
      disposed so its connection pool is released instead of leaked.

      Returns:
          The Engine bound to the current database URL.
      """
      global _ENGINE, _ENGINE_URL
      effective_url = database_url()
      if _ENGINE is not None and _ENGINE_URL == effective_url:
          return _ENGINE
      engine_options: dict[str, Any] = {}
      # Match plain ``sqlite:`` URLs as well as driver-qualified ones such as
      # ``sqlite+pysqlite:``.
      if effective_url.startswith(("sqlite:", "sqlite+")):
          # By default the sqlite3 driver rejects using a connection from a thread
          # other than the one that created it; disable that check so pooled
          # connections can be used across threads.
          engine_options["connect_args"] = {"check_same_thread": False}
      # Build the new engine first: if create_engine() raises, the previous
      # cache entry is left intact.
      new_engine = create_engine(effective_url, future=True, **engine_options)
      previous_engine = _ENGINE
      _ENGINE = new_engine
      _ENGINE_URL = effective_url
      if previous_engine is not None:
          # Release the old pool. Connections still checked out elsewhere are not
          # force-closed; they are detached and closed when they are returned.
          previous_engine.dispose()
      return _ENGINE
  def read_sys_config_value(config_key: str) -> Any | None:
      """Fetch the value stored under *config_key* in ``sys_config``.

      Returns:
          The stored value, or ``None`` when no row has that key. A stored JSON
          null may be indistinguishable from a missing key here.

      Raises:
          ValueError: wrapping any exception raised while obtaining the engine
              or running the query.
      """
      try:
          engine = sql_engine()
          with Session(engine) as session:
              stmt = select(SysConfig.value).where(SysConfig.key == config_key)
              found = session.scalar(stmt)
      except Exception as exc:
          raise ValueError(f"failed to read sys_config key={config_key}: {exc}") from exc
      return found
  def write_sys_config_value(config_key: str, value: Any) -> None:
      """Insert or update the ``sys_config`` row for *config_key*.

      Args:
          config_key: Key of the row to write.
          value: Value to store in the JSON ``value`` column.

      Raises:
          ValueError: wrapping any exception raised while obtaining the engine,
              querying, or committing.
      """
      from sqlalchemy.exc import IntegrityError

      try:
          with Session(sql_engine()) as session:
              by_key = select(SysConfig).where(SysConfig.key == config_key)
              row = session.scalar(by_key)
              if row is None:
                  session.add(SysConfig(key=config_key, value=value))
                  try:
                      session.commit()
                      return
                  except IntegrityError:
                      # ``key`` is unique. Another writer may have inserted the
                      # same key between our SELECT and INSERT. Roll back and
                      # update the row it created. If no such row exists, the
                      # violation came from something else, so re-raise it.
                      session.rollback()
                      row = session.scalar(by_key)
                      if row is None:
                          raise
              row.value = value
              session.commit()
      except Exception as exc:
          raise ValueError(f"failed to write sys_config key={config_key}: {exc}") from exc