# stress_modbus_clients.py
"""Stress test Modbus TCP reads with multiple concurrent clients.

Run this while the Modbus server is already running.
"""
  4. import argparse
  5. import statistics
  6. import sys
  7. import threading
  8. import time
  9. from concurrent.futures import ThreadPoolExecutor, as_completed
  10. from dataclasses import dataclass
  11. from pathlib import Path
  12. from pymodbus.client import ModbusTcpClient
  13. sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
  14. from app_config import load_config # noqa: E402
  15. from db import create_connection # noqa: E402
  16. from point_loader import load_points # noqa: E402
  17. from point_model import ModbusPoint # noqa: E402
  18. @dataclass(frozen=True)
  19. class ClientResult:
  20. client_index: int
  21. elapsed_seconds: float
  22. success_points: int
  23. failed_points: int
  24. success_registers: int
  25. latencies_seconds: list[float]
  26. errors: list[str]
  27. def parse_args() -> argparse.Namespace:
  28. parser = argparse.ArgumentParser(description="Stress test concurrent pymodbus clients reading all configured points.")
  29. parser.add_argument("--config", default="config.yaml", help="config file path")
  30. parser.add_argument("--modbus-host", help="Modbus TCP host. Defaults to config host, 0.0.0.0 becomes 127.0.0.1")
  31. parser.add_argument("--modbus-port", type=int, help="Modbus TCP port. Defaults to config port")
  32. parser.add_argument("--client-counts", default="4,8", help="comma-separated concurrent client counts")
  33. parser.add_argument("--repeat", type=int, default=1, help="repeat count for each client count")
  34. parser.add_argument("--duration", type=float, default=0, help="seconds to keep reading; 0 means read all points once")
  35. parser.add_argument("--timeout", type=float, default=3, help="Modbus client timeout seconds")
  36. parser.add_argument("--limit", type=int, default=0, help="limit point count; 0 means all points")
  37. parser.add_argument("--max-errors", type=int, default=20, help="max error samples to print per run")
  38. return parser.parse_args()
  39. def main() -> int:
  40. args = parse_args()
  41. config = load_config(args.config)
  42. points = _load_points(config)
  43. if args.limit > 0:
  44. points = points[:args.limit]
  45. if not points:
  46. print("没有可读取的点位")
  47. return 2
  48. host = args.modbus_host or _client_host(config.modbus.host)
  49. port = args.modbus_port or config.modbus.port
  50. client_counts = _parse_client_counts(args.client_counts)
  51. print(
  52. f"压力测试开始: host={host}, port={port}, 点位数={len(points)}, "
  53. f"client并发={client_counts}, repeat={args.repeat}, duration={args.duration}s, timeout={args.timeout}s"
  54. )
  55. exit_code = 0
  56. for client_count in client_counts:
  57. for run_index in range(1, args.repeat + 1):
  58. results = _run_once(host, port, args.timeout, args.duration, points, client_count)
  59. if _print_run_summary(client_count, run_index, args.duration, points, results, args.max_errors) != 0:
  60. exit_code = 1
  61. return exit_code
  62. def _load_points(config) -> list[ModbusPoint]:
  63. conn = create_connection(config.db)
  64. try:
  65. return load_points(conn)
  66. finally:
  67. conn.close()
  68. def _client_host(host: str) -> str:
  69. return "127.0.0.1" if host in {"", "0.0.0.0", "::"} else host
  70. def _parse_client_counts(raw: str) -> list[int]:
  71. client_counts = [int(item.strip()) for item in raw.split(",") if item.strip()]
  72. if not client_counts or any(item <= 0 for item in client_counts):
  73. raise ValueError("--client-counts 必须是正整数列表,例如 4,8")
  74. return client_counts
  75. def _run_once(
  76. host: str,
  77. port: int,
  78. timeout: float,
  79. duration: float,
  80. points: list[ModbusPoint],
  81. client_count: int,
  82. ) -> list[ClientResult]:
  83. start_barrier = threading.Barrier(client_count)
  84. with ThreadPoolExecutor(max_workers=client_count) as executor:
  85. futures = [
  86. executor.submit(_read_all_points, index, host, port, timeout, duration, points, start_barrier)
  87. for index in range(1, client_count + 1)
  88. ]
  89. return [future.result() for future in as_completed(futures)]
  90. def _read_all_points(
  91. client_index: int,
  92. host: str,
  93. port: int,
  94. timeout: float,
  95. duration: float,
  96. points: list[ModbusPoint],
  97. start_barrier: threading.Barrier,
  98. ) -> ClientResult:
  99. errors: list[str] = []
  100. success_points = 0
  101. failed_points = 0
  102. success_registers = 0
  103. latencies_seconds: list[float] = []
  104. start_barrier.wait(timeout + 10)
  105. started_at = time.perf_counter()
  106. client = ModbusTcpClient(host, port=port, timeout=timeout)
  107. try:
  108. if not client.connect():
  109. return ClientResult(client_index, time.perf_counter() - started_at, 0, len(points), 0, [], ["连接失败"])
  110. deadline = started_at + duration if duration > 0 else None
  111. while True:
  112. for point in points:
  113. if deadline is not None and time.perf_counter() >= deadline:
  114. break
  115. request_started_at = time.perf_counter()
  116. try:
  117. response = client.read_holding_registers(
  118. point.address,
  119. count=point.register_count,
  120. device_id=point.slave_id,
  121. )
  122. except Exception as exc:
  123. latencies_seconds.append(time.perf_counter() - request_started_at)
  124. failed_points += 1
  125. errors.append(f"{point.point_id}: exception={exc}")
  126. continue
  127. latencies_seconds.append(time.perf_counter() - request_started_at)
  128. if response.isError():
  129. failed_points += 1
  130. errors.append(f"{point.point_id}: response={response}")
  131. continue
  132. success_points += 1
  133. success_registers += len(response.registers)
  134. if deadline is None or time.perf_counter() >= deadline:
  135. break
  136. finally:
  137. client.close()
  138. return ClientResult(
  139. client_index=client_index,
  140. elapsed_seconds=time.perf_counter() - started_at,
  141. success_points=success_points,
  142. failed_points=failed_points,
  143. success_registers=success_registers,
  144. latencies_seconds=latencies_seconds,
  145. errors=errors,
  146. )
  147. def _print_run_summary(
  148. client_count: int,
  149. run_index: int,
  150. duration: float,
  151. points: list[ModbusPoint],
  152. results: list[ClientResult],
  153. max_errors: int,
  154. ) -> int:
  155. durations = [result.elapsed_seconds for result in results]
  156. wall_seconds = max(durations) if durations else 0
  157. total_success_points = sum(result.success_points for result in results)
  158. total_failed_points = sum(result.failed_points for result in results)
  159. total_success_registers = sum(result.success_registers for result in results)
  160. expected_points = None if duration > 0 else client_count * len(points)
  161. point_reads_per_second = total_success_points / wall_seconds if wall_seconds else 0
  162. register_reads_per_second = total_success_registers / wall_seconds if wall_seconds else 0
  163. latencies = [latency for result in results for latency in result.latencies_seconds]
  164. expected_text = "持续读取" if expected_points is None else f"预期点位读取={expected_points}"
  165. print(
  166. f"\n并发client={client_count}, run={run_index}: "
  167. f"总耗时={wall_seconds:.3f}s, {expected_text}, "
  168. f"成功点位读取={total_success_points}, 失败点位读取={total_failed_points}, "
  169. f"点位吞吐={point_reads_per_second:.2f}/s, 寄存器吞吐={register_reads_per_second:.2f}/s"
  170. )
  171. print(
  172. f"client耗时: min={min(durations):.3f}s, avg={statistics.mean(durations):.3f}s, "
  173. f"p95={_percentile(durations, 95):.3f}s, max={max(durations):.3f}s"
  174. )
  175. if latencies:
  176. print(
  177. "单次请求响应耗时: "
  178. f"min={min(latencies) * 1000:.2f}ms, avg={statistics.mean(latencies) * 1000:.2f}ms, "
  179. f"p50={_percentile(latencies, 50) * 1000:.2f}ms, "
  180. f"p95={_percentile(latencies, 95) * 1000:.2f}ms, "
  181. f"p99={_percentile(latencies, 99) * 1000:.2f}ms, max={max(latencies) * 1000:.2f}ms"
  182. )
  183. for result in sorted(results, key=lambda item: item.client_index):
  184. print(
  185. f"client#{result.client_index}: 耗时={result.elapsed_seconds:.3f}s, "
  186. f"成功={result.success_points}, 失败={result.failed_points}, 寄存器={result.success_registers}"
  187. )
  188. error_samples = [error for result in results for error in result.errors[:max_errors]]
  189. for error in error_samples[:max_errors]:
  190. print(f"错误样例: {error}")
  191. if len(error_samples) > max_errors:
  192. print(f"错误样例还有 {len(error_samples) - max_errors} 条未显示")
  193. if expected_points is None:
  194. return 0 if total_success_points > 0 and total_failed_points == 0 else 1
  195. return 0 if total_success_points == expected_points and total_failed_points == 0 else 1
  196. def _percentile(values: list[float], percentile: int) -> float:
  197. if not values:
  198. return 0
  199. ordered = sorted(values)
  200. index = round((len(ordered) - 1) * percentile / 100)
  201. return ordered[index]
  202. if __name__ == "__main__":
  203. raise SystemExit(main())