Explorar o código

修改配置文件参数

Lu Xianghui hai 3 días
pai
achega
266af6fa37
Modificáronse 7 ficheiros con 48 adicións e 28 borrados
  1. 10 2
      app_config.py
  2. 2 1
      config.yaml
  3. 4 4
      main.py
  4. 16 8
      tests/compare_modbus_http_client.py
  5. 7 4
      tests/test_app_config.py
  6. 5 5
      tests/test_value_refresh.py
  7. 4 4
      value_refresh.py

+ 10 - 2
app_config.py

@@ -20,13 +20,14 @@ class DbConfig:
 class ModbusConfig:
     host: str = "0.0.0.0"
     port: int = 502
-    interval: int = 5
 
 
 @dataclass(frozen=True)
 class HttpProviderConfig:
     url: str = "http://192.168.1.109:18503/data/get_points_real_value"
     timeout_seconds: int = 5
+    interval: int = 5
+    batch_size: int = 200
 
 
 @dataclass(frozen=True)
@@ -68,14 +69,21 @@ def _load_modbus(raw: dict[str, Any]) -> ModbusConfig:
     return ModbusConfig(
         host=str(raw.get("host", "0.0.0.0")),
         port=int(raw.get("port", 502)),
-        interval=int(raw.get("interval", 5)),
     )
 
 
 def _load_http_provider(raw: dict[str, Any]) -> HttpProviderConfig:
+    interval = int(raw.get("interval", 5))
+    batch_size = int(raw.get("batch_size", 200))
+    if interval <= 0:
+        raise ValueError("配置错误: http_provider.interval 必须大于0")
+    if batch_size <= 0:
+        raise ValueError("配置错误: http_provider.batch_size 必须大于0")
     return HttpProviderConfig(
         url=str(raw.get("url", "http://192.168.1.109:18503/data/get_points_real_value")),
         timeout_seconds=int(raw.get("timeout_seconds", 5)),
+        interval=interval,
+        batch_size=batch_size,
     )
 
 

+ 2 - 1
config.yaml

@@ -8,11 +8,12 @@ db:
 modbus:
   host: 0.0.0.0
   port: 502
-  interval: 5
 
 http_provider:
   url: http://192.168.1.109:18503/data/get_points_real_value
   timeout_seconds: 3
+  interval: 5
+  batch_size: 200
 
 logging:
   dir: logs

+ 4 - 4
main.py

@@ -25,14 +25,14 @@ def main(config_path: str = "config.yaml") -> int:
     logger.info("日志系统初始化完成")
     logger.info("配置文件加载完成")
     logger.info(
-        "运行配置: 数据库=%s:%s/%s, Modbus监听=%s:%s, 刷新周期=%s秒, 批量大小=%s",
+        "运行配置: 数据库=%s:%s/%s, Modbus监听=%s:%s, HTTP刷新周期=%s秒, HTTP批量大小=%s",
         config.db.host,
         config.db.port,
         config.db.database,
         config.modbus.host,
         config.modbus.port,
-        config.modbus.interval,
-        DEFAULT_BATCH_SIZE,
+        config.http_provider.interval,
+        config.http_provider.batch_size,
     )
 
     try:
@@ -48,7 +48,7 @@ def main(config_path: str = "config.yaml") -> int:
     provider = HttpValueProvider(config.http_provider.url, config.http_provider.timeout_seconds)
     logger.info("实时值Provider初始化完成,类型=http")
 
-    worker = ValueRefreshWorker(points, provider, store, config.modbus.interval)
+    worker = ValueRefreshWorker(points, provider, store, config.http_provider.interval, config.http_provider.batch_size)
     logger.info("开始请求初始化实时值")
     try:
         worker.refresh_once(initial=True)

+ 16 - 8
tests/compare_modbus_http_client.py

@@ -15,7 +15,6 @@ from pymodbus.client import ModbusTcpClient
 sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
 
 from app_config import load_config  # noqa: E402
-from constants import DEFAULT_BATCH_SIZE  # noqa: E402
 from db import create_connection  # noqa: E402
 from http_value_provider import HttpValueProvider  # noqa: E402
 from modbus_context import ReadonlyHoldingRegisterContext  # noqa: E402
@@ -57,8 +56,16 @@ def main() -> int:
     if args.self_start:
         host = args.modbus_host or "127.0.0.1"
         print(f"开始读取HTTP实时值接口快照,url={config.http_provider.url}, 点位数={len(points)}")
-        http_values = _read_http_values(provider, [point.point_id for point in points])
-        _start_embedded_modbus_server(host, port, points, http_values, config.modbus.interval, args.timeout)
+        http_values = _read_http_values(provider, [point.point_id for point in points], config.http_provider.batch_size)
+        _start_embedded_modbus_server(
+            host,
+            port,
+            points,
+            http_values,
+            config.http_provider.interval,
+            config.http_provider.batch_size,
+            args.timeout,
+        )
     else:
         http_values = None
 
@@ -67,7 +74,7 @@ def main() -> int:
 
     if http_values is None:
         print(f"开始读取HTTP实时值接口,url={config.http_provider.url}, 点位数={len(points)}")
-        http_values = _read_http_values(provider, [point.point_id for point in points])
+        http_values = _read_http_values(provider, [point.point_id for point in points], config.http_provider.batch_size)
 
     return _compare(points, modbus_results, http_values, args.max_details)
 
@@ -104,11 +111,12 @@ def _start_embedded_modbus_server(
     points: list[ModbusPoint],
     http_values: dict[str, object],
     interval_seconds: int,
+    batch_size: int,
     timeout_seconds: float,
 ) -> None:
     store = RegisterStore()
     initialize_register_store(points, store)
-    ValueRefreshWorker(points, SnapshotProvider(http_values), store, interval_seconds).refresh_once(initial=True)
+    ValueRefreshWorker(points, SnapshotProvider(http_values), store, interval_seconds, batch_size).refresh_once(initial=True)
     context = ReadonlyHoldingRegisterContext(store)
     thread = threading.Thread(
         target=run_modbus_server,
@@ -164,10 +172,10 @@ def _read_modbus_values(host: str, port: int, timeout: float, points: list[Modbu
         client.close()
 
 
-def _read_http_values(provider: HttpValueProvider, point_ids: list[str]) -> dict[str, object]:
+def _read_http_values(provider: HttpValueProvider, point_ids: list[str], batch_size: int) -> dict[str, object]:
     values: dict[str, object] = {}
-    for start in range(0, len(point_ids), DEFAULT_BATCH_SIZE):
-        batch = point_ids[start:start + DEFAULT_BATCH_SIZE]
+    for start in range(0, len(point_ids), batch_size):
+        batch = point_ids[start:start + batch_size]
         values.update(provider.fetch_values(batch))
     return values
 

+ 7 - 4
tests/test_app_config.py

@@ -27,10 +27,11 @@ db:
         self.assertEqual(config.db.port, 5432)
         self.assertEqual(config.modbus.host, "0.0.0.0")
         self.assertEqual(config.modbus.port, 502)
-        self.assertEqual(config.modbus.interval, 5)
+        self.assertEqual(config.http_provider.interval, 5)
+        self.assertEqual(config.http_provider.batch_size, 200)
         self.assertEqual(config.logging.retention_days, 3)
 
-    def test_load_modbus_interval(self):
+    def test_load_http_provider_refresh_options(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             path = Path(tmp_dir) / "config.yaml"
             path.write_text(
@@ -41,15 +42,17 @@ db:
   database: test_db
   user: postgres
   password: secret
-modbus:
+http_provider:
   interval: 10
+  batch_size: 50
 """.strip(),
                 encoding="utf-8",
             )
 
             config = load_config(path)
 
-        self.assertEqual(config.modbus.interval, 10)
+        self.assertEqual(config.http_provider.interval, 10)
+        self.assertEqual(config.http_provider.batch_size, 50)
 
 
 if __name__ == "__main__":

+ 5 - 5
tests/test_value_refresh.py

@@ -45,7 +45,7 @@ class ValueRefreshTest(unittest.TestCase):
         ]
         store = RegisterStore()
         initialize_register_store(points, store)
-        worker = ValueRefreshWorker(points, FakeProvider({"a": 12}), store, 5)
+        worker = ValueRefreshWorker(points, FakeProvider({"a": 12}), store, 5, 200)
 
         with self.assertRaises(RuntimeError):
             worker.refresh_once(initial=True)
@@ -56,8 +56,8 @@ class ValueRefreshTest(unittest.TestCase):
         points = [ModbusPoint("a", "A", "int16", 1, 0)]
         store = RegisterStore()
         initialize_register_store(points, store)
-        ValueRefreshWorker(points, FakeProvider({"a": 5}), store, 5).refresh_once(initial=True)
-        ValueRefreshWorker(points, FakeProvider({}), store, 5).refresh_once(initial=False)
+        ValueRefreshWorker(points, FakeProvider({"a": 5}), store, 5, 200).refresh_once(initial=True)
+        ValueRefreshWorker(points, FakeProvider({}), store, 5, 200).refresh_once(initial=False)
 
         self.assertEqual(store.read_holding_registers(1, 0, 1), [5])
 
@@ -67,7 +67,7 @@ class ValueRefreshTest(unittest.TestCase):
         initialize_register_store(points, store)
 
         with self.assertRaises(RuntimeError):
-            ValueRefreshWorker(points, FailingProvider(), store, 5).refresh_once(initial=True)
+            ValueRefreshWorker(points, FailingProvider(), store, 5, 200).refresh_once(initial=True)
 
         self.assertEqual(store.read_holding_registers(1, 0, 1), [0])
 
@@ -77,7 +77,7 @@ class ValueRefreshTest(unittest.TestCase):
         initialize_register_store(points, store)
         provider = FailingMiddleBatchProvider()
 
-        ValueRefreshWorker(points, provider, store, 5).refresh_once(initial=False)
+        ValueRefreshWorker(points, provider, store, 5, 200).refresh_once(initial=False)
 
         self.assertEqual(len(provider.called_batches), 3)
         self.assertEqual(store.read_holding_registers(1, 0, 1), [0])

+ 4 - 4
value_refresh.py

@@ -4,19 +4,19 @@ import logging
 import threading
 import time
 
-from constants import DEFAULT_BATCH_SIZE
 from modbus_codec import encode_registers
 
 logger = logging.getLogger(__name__)
 
 
 class ValueRefreshWorker(threading.Thread):
-    def __init__(self, points, provider, store, interval_seconds: int):
+    def __init__(self, points, provider, store, interval_seconds: int, batch_size: int):
         super().__init__(name="value-refresh-worker", daemon=True)
         self.points = points
         self.provider = provider
         self.store = store
         self.interval_seconds = interval_seconds
+        self.batch_size = batch_size
 
     def run(self) -> None:
         logger.info("实时值刷新线程已启动,刷新周期=%s秒", self.interval_seconds)
@@ -34,9 +34,9 @@ class ValueRefreshWorker(threading.Thread):
         total_batches = 0
         failed_batches = 0
         try:
-            for start in range(0, len(point_ids), DEFAULT_BATCH_SIZE):
+            for start in range(0, len(point_ids), self.batch_size):
                 total_batches += 1
-                batch = point_ids[start:start + DEFAULT_BATCH_SIZE]
+                batch = point_ids[start:start + self.batch_size]
                 try:
                     values = self.provider.fetch_values(batch)
                 except Exception: