auth.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from __future__ import annotations
  2. import os
  3. import time
  4. from datetime import datetime, timezone
  5. from typing import Any
  6. import requests
  7. from .db import read_sys_config_value, write_sys_config_value
  8. PROJECTS_CONFIG_KEY = "mcp_project_data_projects"
  9. PROJECT_TOKEN_CONFIG_KEY_PREFIX = "mcp_project_data_project_token:"
  10. PROJECT_AUTH_LOGIN_PATH = "/api/ai/auth/password_login"
  11. def _to_iso_utc(timestamp: float) -> str:
  12. return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat()
  13. def _request_timeout_seconds() -> int:
  14. return max(1, int(os.getenv("UPSTREAM_REQUEST_TIMEOUT", "60")))
  15. def _safe_int(raw_value: Any, field_name: str) -> int:
  16. try:
  17. return int(str(raw_value).strip())
  18. except Exception as exc:
  19. raise ValueError(f"invalid {field_name}: {raw_value}") from exc
  20. def _normalize_project_key(raw_value: Any) -> str:
  21. project_key = str(raw_value or "").strip()
  22. if not project_key:
  23. raise ValueError("project_key is required")
  24. return project_key
  25. def _coerce_bool(raw_value: Any, default: bool) -> bool:
  26. if raw_value is None:
  27. return default
  28. if isinstance(raw_value, bool):
  29. return raw_value
  30. text = str(raw_value).strip().lower()
  31. if text in {"1", "true", "yes", "y", "on"}:
  32. return True
  33. if text in {"0", "false", "no", "n", "off"}:
  34. return False
  35. return default
  36. def _request_json(
  37. method: str,
  38. url: str,
  39. authorization: str | None = None,
  40. *,
  41. json_payload: dict | None = None,
  42. ) -> Any:
  43. headers: dict[str, str] = {}
  44. token_text = str(authorization or "").strip()
  45. if token_text:
  46. headers["Authorization"] = token_text
  47. response = requests.request(
  48. method=method,
  49. url=url,
  50. headers=headers,
  51. json=json_payload,
  52. timeout=_request_timeout_seconds(),
  53. )
  54. try:
  55. payload = response.json()
  56. except ValueError as exc:
  57. text_preview = response.text[:300]
  58. raise ValueError(
  59. "upstream returned non-JSON response, "
  60. f"status={response.status_code}, body={text_preview}"
  61. ) from exc
  62. if response.status_code >= 400:
  63. raise ValueError(f"upstream HTTP {response.status_code}: {payload}")
  64. return payload
  65. def _project_token_config_key(project_key: str) -> str:
  66. return f"{PROJECT_TOKEN_CONFIG_KEY_PREFIX}{project_key}"
  67. def _load_projects_config() -> list[dict[str, Any]]:
  68. raw_value = read_sys_config_value(PROJECTS_CONFIG_KEY)
  69. if raw_value is None:
  70. raise ValueError("missing sys_config key: mcp_project_data_projects")
  71. if not isinstance(raw_value, list):
  72. raise ValueError("sys_config mcp_project_data_projects must be a JSON array")
  73. normalized: list[dict[str, Any]] = []
  74. latest_by_key: dict[str, dict[str, Any]] = {}
  75. for item in raw_value:
  76. if not isinstance(item, dict):
  77. continue
  78. project_key = _normalize_project_key(item.get("project_key"))
  79. project_name = str(item.get("project_name") or project_key).strip() or project_key
  80. base_url = str(item.get("base_url") or "").strip().rstrip("/")
  81. username = str(item.get("username") or "").strip()
  82. password = str(item.get("password") or "").strip()
  83. enabled = _coerce_bool(item.get("enabled"), True)
  84. if not base_url:
  85. raise ValueError(f"project '{project_key}' missing base_url")
  86. latest_by_key[project_key] = {
  87. "project_key": project_key,
  88. "project_name": project_name,
  89. "base_url": base_url,
  90. "username": username,
  91. "password": password,
  92. "enabled": enabled,
  93. }
  94. normalized.extend(latest_by_key.values())
  95. if not normalized:
  96. raise ValueError("mcp_project_data_projects has no valid project entries")
  97. return normalized
  98. def load_projects_config() -> list[dict[str, Any]]:
  99. return _load_projects_config()
  100. def find_project_config(project_key: str) -> dict[str, Any]:
  101. expected = _normalize_project_key(project_key)
  102. for item in _load_projects_config():
  103. if item["project_key"] != expected:
  104. continue
  105. if not item["enabled"]:
  106. raise ValueError(f"project '{expected}' is disabled")
  107. if not item["username"]:
  108. raise ValueError(f"project '{expected}' missing username")
  109. if not item["password"]:
  110. raise ValueError(f"project '{expected}' missing password")
  111. return item
  112. raise ValueError(f"project_key not found: {expected}")
  113. def _read_project_token_cache(project_key: str) -> dict[str, Any] | None:
  114. raw_value = read_sys_config_value(_project_token_config_key(project_key))
  115. if not isinstance(raw_value, dict):
  116. return None
  117. return raw_value
  118. def _write_project_token_cache(project_key: str, *, auth_token: str, expire_at: int) -> None:
  119. write_sys_config_value(
  120. _project_token_config_key(project_key),
  121. {
  122. "auth_token": auth_token,
  123. "expire_at": expire_at,
  124. "updated_at": _to_iso_utc(time.time()),
  125. },
  126. )
  127. def _login_project(project_cfg: dict[str, Any]) -> tuple[str, int]:
  128. payload = _request_json(
  129. "POST",
  130. f"{project_cfg['base_url']}{PROJECT_AUTH_LOGIN_PATH}",
  131. authorization=None,
  132. json_payload={
  133. "username": project_cfg["username"],
  134. "password": project_cfg["password"],
  135. },
  136. )
  137. if not isinstance(payload, dict):
  138. raise ValueError("login failed: invalid response payload")
  139. errcode = payload.get("errcode")
  140. if str(errcode) not in {"0", "0.0"}:
  141. message = str(payload.get("msg") or payload.get("message") or "").strip()
  142. raise ValueError(f"login failed: {message or payload}")
  143. auth_token = str(payload.get("token") or "").strip()
  144. if not auth_token:
  145. raise ValueError("login failed: missing token")
  146. expire_raw = payload.get("token_expire_time")
  147. if expire_raw is None:
  148. expire_at = int(time.time()) + 3600
  149. else:
  150. expire_at = _safe_int(expire_raw, "token_expire_time")
  151. if expire_at <= int(time.time()):
  152. expire_at = int(time.time()) + 3600
  153. return auth_token, expire_at
  154. def resolve_project_token(project_cfg: dict[str, Any]) -> str:
  155. project_key = project_cfg["project_key"]
  156. cached = _read_project_token_cache(project_key)
  157. if isinstance(cached, dict):
  158. token_text = str(cached.get("auth_token") or "").strip()
  159. expire_raw = cached.get("expire_at")
  160. if token_text and expire_raw is not None:
  161. try:
  162. expire_at = int(str(expire_raw).strip())
  163. if int(time.time()) < expire_at:
  164. return token_text
  165. except (TypeError, ValueError):
  166. pass
  167. auth_token, expire_at = _login_project(project_cfg)
  168. _write_project_token_cache(project_key, auth_token=auth_token, expire_at=expire_at)
  169. return auth_token