"""Stress test Modbus TCP reads with multiple concurrent clients.

Run this while the Modbus server is already running.
"""

import argparse
import statistics
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path

from pymodbus.client import ModbusTcpClient

# Put the parent of this script's directory on sys.path so the project-local
# modules imported below can be resolved when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from app_config import load_config  # noqa: E402
from db import create_connection  # noqa: E402
from point_loader import load_points  # noqa: E402
from point_model import ModbusPoint  # noqa: E402


@dataclass(frozen=True)
class ClientResult:
    """Outcome of one client's read session within a single run."""

    # 1-based index of the client inside the run.
    client_index: int
    # Seconds from passing the start barrier until the client finished.
    elapsed_seconds: float
    # Number of point reads that returned a non-error response.
    success_points: int
    # Number of point reads that raised or returned an error response.
    failed_points: int
    # Total registers returned by the successful reads.
    success_registers: int
    # Latency in seconds of every request that was issued (success or failure).
    latencies_seconds: list[float]
    # One human-readable message per failure.
    errors: list[str]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the stress test.

    Returns:
        The parsed namespace. ``client_counts`` is still the raw
        comma-separated string; ``main`` converts it.
    """
    parser = argparse.ArgumentParser(
        description="Stress test concurrent pymodbus clients reading all configured points."
    )
    parser.add_argument("--config", default="config.yaml", help="config file path")
    parser.add_argument(
        "--modbus-host",
        help="Modbus TCP host. Defaults to config host, 0.0.0.0 becomes 127.0.0.1",
    )
    parser.add_argument("--modbus-port", type=int, help="Modbus TCP port. Defaults to config port")
    parser.add_argument("--client-counts", default="4,8", help="comma-separated concurrent client counts")
    parser.add_argument("--repeat", type=int, default=1, help="repeat count for each client count")
    parser.add_argument(
        "--duration",
        type=float,
        default=0,
        help="seconds to keep reading; 0 means read all points once",
    )
    parser.add_argument("--timeout", type=float, default=3, help="Modbus client timeout seconds")
    parser.add_argument("--limit", type=int, default=0, help="limit point count; 0 means all points")
    parser.add_argument("--max-errors", type=int, default=20, help="max error samples to print per run")
    return parser.parse_args()


def main() -> int:
    """Run the stress test for every requested client count.

    Returns:
        0 when every run passed, 1 when at least one run reported a failure,
        2 when the arguments are invalid or there are no points to read.
    """
    args = parse_args()

    # Validate pure command-line input first so a typo fails fast with a clear
    # message, before the config file and the database are touched.
    try:
        client_counts = _parse_client_counts(args.client_counts)
    except ValueError as exc:
        print(exc)
        return 2
    if args.repeat < 1:
        # Without this check the loops below would not execute at all and the
        # script would report success without having tested anything.
        print("--repeat 必须是正整数")
        return 2

    config = load_config(args.config)
    points = _load_points(config)
    if args.limit > 0:
        points = points[:args.limit]
    if not points:
        print("没有可读取的点位")
        return 2

    # Explicit command-line values win; otherwise fall back to the config.
    host = args.modbus_host or _client_host(config.modbus.host)
    port = args.modbus_port or config.modbus.port
    print(
        f"压力测试开始: host={host}, port={port}, 点位数={len(points)}, "
        f"client并发={client_counts}, repeat={args.repeat}, duration={args.duration}s, timeout={args.timeout}s"
    )

    exit_code = 0
    for client_count in client_counts:
        for run_index in range(1, args.repeat + 1):
            results = _run_once(host, port, args.timeout, args.duration, points, client_count)
            if _print_run_summary(client_count, run_index, args.duration, points, results, args.max_errors) != 0:
                # Keep going so every configuration is still measured, but
                # remember that something failed.
                exit_code = 1
    return exit_code


def _load_points(config) -> list[ModbusPoint]:
    """Load the configured points using a short-lived database connection."""
    conn = create_connection(config.db)
    try:
        return load_points(conn)
    finally:
        conn.close()


def _client_host(host: str) -> str:
    """Map an empty or wildcard host to the IPv4 loopback address.

    Presumably the config value is the server's bind address, which a client
    cannot connect to when it is a wildcard; any other value is returned as is.
    """
    return "127.0.0.1" if host in {"", "0.0.0.0", "::"} else host


def _parse_client_counts(raw: str) -> list[int]:
    """Convert a comma-separated string such as ``"4,8"`` into client counts.

    Empty items (for example from a trailing comma) are ignored.

    Raises:
        ValueError: with the same explanatory message when the list is empty,
            or when any item is not an integer or is not positive.
    """
    message = "--client-counts 必须是正整数列表,例如 4,8"
    items = [item.strip() for item in raw.split(",")]
    try:
        client_counts = [int(item) for item in items if item]
    except ValueError as exc:
        # int() alone would report "invalid literal for int()", which does not
        # tell the user which option is wrong.
        raise ValueError(message) from exc
    if not client_counts or any(count <= 0 for count in client_counts):
        raise ValueError(message)
    return client_counts
def _run_once(
    host: str,
    port: int,
    timeout: float,
    duration: float,
    points: list[ModbusPoint],
    client_count: int,
) -> list[ClientResult]:
    """Run ``client_count`` clients concurrently and collect their results.

    Each client runs in its own worker thread. All workers share a barrier so
    that they start reading together.

    Returns:
        One ``ClientResult`` per client, in completion order (not ordered by
        client index).
    """
    start_barrier = threading.Barrier(client_count)
    # One worker per client is needed: every client must reach the barrier
    # before any of them is released.
    with ThreadPoolExecutor(max_workers=client_count) as executor:
        futures = [
            executor.submit(_read_all_points, index, host, port, timeout, duration, points, start_barrier)
            for index in range(1, client_count + 1)
        ]
        return [future.result() for future in as_completed(futures)]


def _read_all_points(
    client_index: int,
    host: str,
    port: int,
    timeout: float,
    duration: float,
    points: list[ModbusPoint],
    start_barrier: threading.Barrier,
) -> ClientResult:
    """Read every point with a dedicated Modbus TCP client and record the outcome.

    With ``duration <= 0`` the points are read exactly once. Otherwise the
    point list is read repeatedly until ``duration`` seconds have elapsed
    since the client passed the start barrier.

    Returns:
        The counters, per-request latencies and error messages of this client.
        When the connection cannot be established, all points are reported as
        failed with a single error message.
    """
    errors: list[str] = []
    success_points = 0
    failed_points = 0
    success_registers = 0
    latencies_seconds: list[float] = []

    # Wait until all clients are ready; the wait is bounded so a missing peer
    # cannot block this worker forever.
    start_barrier.wait(timeout + 10)
    started_at = time.perf_counter()
    client = ModbusTcpClient(host, port=port, timeout=timeout)
    try:
        if not client.connect():
            return ClientResult(client_index, time.perf_counter() - started_at, 0, len(points), 0, [], ["连接失败"])

        deadline = started_at + duration if duration > 0 else None
        while True:
            for point in points:
                if deadline is not None and time.perf_counter() >= deadline:
                    break
                request_started_at = time.perf_counter()
                try:
                    response = client.read_holding_registers(
                        point.address,
                        count=point.register_count,
                        device_id=point.slave_id,
                    )
                except Exception as exc:
                    # Any failure of a single request is recorded as a failed
                    # read, and its latency still counts; the test goes on.
                    latencies_seconds.append(time.perf_counter() - request_started_at)
                    failed_points += 1
                    errors.append(f"{point.point_id}: exception={exc}")
                    continue
                latencies_seconds.append(time.perf_counter() - request_started_at)
                if response.isError():
                    failed_points += 1
                    errors.append(f"{point.point_id}: response={response}")
                    continue
                success_points += 1
                success_registers += len(response.registers)
            # Single-pass mode stops after one sweep; timed mode stops once
            # the deadline has passed.
            if deadline is None or time.perf_counter() >= deadline:
                break
    finally:
        client.close()

    return ClientResult(
        client_index=client_index,
        elapsed_seconds=time.perf_counter() - started_at,
        success_points=success_points,
        failed_points=failed_points,
        success_registers=success_registers,
        latencies_seconds=latencies_seconds,
        errors=errors,
    )


def _print_run_summary(
    client_count: int,
    run_index: int,
    duration: float,
    points: list[ModbusPoint],
    results: list[ClientResult],
    max_errors: int,
) -> int:
    """Print the statistics of one run and decide whether it passed.

    Args:
        client_count: Number of concurrent clients used in the run.
        run_index: 1-based repetition number, used only for display.
        duration: Requested run duration; a value ``<= 0`` means single-pass
            mode, where every client is expected to read every point once.
        points: The points that were read (only their count is used).
        results: One result per client.
        max_errors: Maximum number of error samples to print.

    Returns:
        0 when the run passed, 1 otherwise. A single-pass run passes when all
        expected reads succeeded; a timed run passes when at least one read
        succeeded and none failed.
    """
    durations = [result.elapsed_seconds for result in results]
    # Clients start together, so the slowest client approximates wall time.
    wall_seconds = max(durations) if durations else 0
    total_success_points = sum(result.success_points for result in results)
    total_failed_points = sum(result.failed_points for result in results)
    total_success_registers = sum(result.success_registers for result in results)
    expected_points = None if duration > 0 else client_count * len(points)
    point_reads_per_second = total_success_points / wall_seconds if wall_seconds else 0
    register_reads_per_second = total_success_registers / wall_seconds if wall_seconds else 0
    latencies = [latency for result in results for latency in result.latencies_seconds]

    expected_text = "持续读取" if expected_points is None else f"预期点位读取={expected_points}"
    print(
        f"\n并发client={client_count}, run={run_index}: "
        f"总耗时={wall_seconds:.3f}s, {expected_text}, "
        f"成功点位读取={total_success_points}, 失败点位读取={total_failed_points}, "
        f"点位吞吐={point_reads_per_second:.2f}/s, 寄存器吞吐={register_reads_per_second:.2f}/s"
    )
    # min()/mean() raise on an empty sequence, so skip the line when there
    # are no client results at all.
    if durations:
        print(
            f"client耗时: min={min(durations):.3f}s, avg={statistics.mean(durations):.3f}s, "
            f"p95={_percentile(durations, 95):.3f}s, max={max(durations):.3f}s"
        )
    if latencies:
        print(
            "单次请求响应耗时: "
            f"min={min(latencies) * 1000:.2f}ms, avg={statistics.mean(latencies) * 1000:.2f}ms, "
            f"p50={_percentile(latencies, 50) * 1000:.2f}ms, "
            f"p95={_percentile(latencies, 95) * 1000:.2f}ms, "
            f"p99={_percentile(latencies, 99) * 1000:.2f}ms, max={max(latencies) * 1000:.2f}ms"
        )
    for result in sorted(results, key=lambda item: item.client_index):
        print(
            f"client#{result.client_index}: 耗时={result.elapsed_seconds:.3f}s, "
            f"成功={result.success_points}, 失败={result.failed_points}, 寄存器={result.success_registers}"
        )

    # A negative limit would slice from the wrong end; treat it as "show none".
    sample_limit = max(max_errors, 0)
    # Per-client truncation only bounds the work; the overall limit is applied
    # by the outer slice.
    error_samples = [error for result in results for error in result.errors[:sample_limit]][:sample_limit]
    for error in error_samples:
        print(f"错误样例: {error}")
    # Count the hidden errors from the untruncated lists, otherwise errors
    # dropped by the per-client truncation would never be reported.
    total_errors = sum(len(result.errors) for result in results)
    hidden_errors = total_errors - len(error_samples)
    if hidden_errors > 0:
        print(f"错误样例还有 {hidden_errors} 条未显示")

    if expected_points is None:
        return 0 if total_success_points > 0 and total_failed_points == 0 else 1
    return 0 if total_success_points == expected_points and total_failed_points == 0 else 1


def _percentile(values: list[float], percentile: int) -> float:
    """Return the nearest-rank percentile of ``values``.

    The rank is ``round((n - 1) * percentile / 100)`` on the sorted values;
    no interpolation is performed. ``percentile`` is clamped to 0..100.
    Returns 0.0 for an empty list.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    # Clamp so an out-of-range percentile cannot index past either end.
    bounded = min(max(percentile, 0), 100)
    index = round((len(ordered) - 1) * bounded / 100)
    return ordered[index]


if __name__ == "__main__":
    raise SystemExit(main())