# app_config.py
  1. """YAML configuration loading."""
  2. from dataclasses import dataclass
  3. from pathlib import Path
  4. from typing import Any
  5. import yaml
  6. @dataclass(frozen=True)
  7. class DbConfig:
  8. host: str
  9. port: int
  10. database: str
  11. user: str
  12. password: str
  13. @dataclass(frozen=True)
  14. class ModbusConfig:
  15. host: str = "0.0.0.0"
  16. port: int = 502
  17. @dataclass(frozen=True)
  18. class HttpProviderConfig:
  19. url: str = "http://192.168.1.109:18503/data/get_points_real_value"
  20. timeout_seconds: int = 5
  21. interval: int = 5
  22. batch_size: int = 200
  23. @dataclass(frozen=True)
  24. class LoggingConfig:
  25. dir: str = "logs"
  26. retention_days: int = 3
  27. level: str = "INFO"
  28. @dataclass(frozen=True)
  29. class AppConfig:
  30. db: DbConfig
  31. modbus: ModbusConfig
  32. http_provider: HttpProviderConfig
  33. logging: LoggingConfig
  34. def load_config(path: str | Path) -> AppConfig:
  35. config_path = Path(path)
  36. with config_path.open("r", encoding="utf-8") as file:
  37. raw = yaml.safe_load(file) or {}
  38. db = _require_mapping(raw, "db")
  39. return AppConfig(
  40. db=DbConfig(
  41. host=str(_require(db, "host")),
  42. port=int(_require(db, "port")),
  43. database=str(_require(db, "database")),
  44. user=str(_require(db, "user")),
  45. password=str(_require(db, "password")),
  46. ),
  47. modbus=_load_modbus(raw.get("modbus") or {}),
  48. http_provider=_load_http_provider(raw.get("http_provider") or {}),
  49. logging=_load_logging(raw.get("logging") or {}),
  50. )
  51. def _load_modbus(raw: dict[str, Any]) -> ModbusConfig:
  52. return ModbusConfig(
  53. host=str(raw.get("host", "0.0.0.0")),
  54. port=int(raw.get("port", 502)),
  55. )
  56. def _load_http_provider(raw: dict[str, Any]) -> HttpProviderConfig:
  57. interval = int(raw.get("interval", 5))
  58. batch_size = int(raw.get("batch_size", 200))
  59. if interval <= 0:
  60. raise ValueError("配置错误: http_provider.interval 必须大于0")
  61. if batch_size <= 0:
  62. raise ValueError("配置错误: http_provider.batch_size 必须大于0")
  63. return HttpProviderConfig(
  64. url=str(raw.get("url", "http://192.168.1.109:18503/data/get_points_real_value")),
  65. timeout_seconds=int(raw.get("timeout_seconds", 5)),
  66. interval=interval,
  67. batch_size=batch_size,
  68. )
  69. def _load_logging(raw: dict[str, Any]) -> LoggingConfig:
  70. return LoggingConfig(
  71. dir=str(raw.get("dir", "logs")),
  72. retention_days=int(raw.get("retention_days", 3)),
  73. level=str(raw.get("level", "INFO")),
  74. )
  75. def _require_mapping(raw: dict[str, Any], key: str) -> dict[str, Any]:
  76. value = raw.get(key)
  77. if not isinstance(value, dict):
  78. raise ValueError(f"配置缺失或格式错误: {key}")
  79. return value
  80. def _require(raw: dict[str, Any], key: str) -> Any:
  81. if key not in raw or raw[key] is None:
  82. raise ValueError(f"配置缺失: {key}")
  83. return raw[key]