
Modbus Server Design and Implementation

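Goal

Implement a read-only Modbus TCP server. At startup the program loads the point configuration from modbus_server_point, verifies that every point exists in pt_point, verifies that the Modbus holding register addresses do not overlap, and then initializes the register data through the real-time value Provider. Once startup completes, a background thread periodically refreshes the point values and updates the registers.

The real-time values currently come from an HTTP API. Other sources such as MQTT, a database, or a message queue can be added later.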

Tech Stack

  • Python 3.10
  • pymodbus 3.13.1
  • requests
  • psycopg2-binary
  • PyYAML
  • logging (standard library)

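Constraints

  • All points map to Holding Registers.
  • Modbus clients may only read; writes are not allowed.
  • The default Modbus TCP port is 502 and can be changed in the configuration file.
  • modbus_server_point.address is the starting holding register address; no extra offset is applied.
  • Byte order is fixed to ABCD.
  • The quality field in the HTTP response is ignored; value is converted according to data_type only.
  • The default refresh interval is 5 seconds.
  • Logs are retained for the most recent 3 days by default.
  • Client connections and disconnections must be logged, but the number of connections is not tracked.
  • Every step of the initialization process must print a log message in Chinese.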

Data Types and Address Usage

data_type  Registers  Encoding
int16      1          signed int16, big-endian
int32      2          signed int32, ABCD
float32    2          IEEE 754 float32, ABCD
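Address range calculation:

int16   address ~ address
int32   address ~ address + 1
float32 address ~ address + 1

For example:

float32 point_a address=0, occupies 0,1
int16   point_b address=1, occupies 1

Both points claim register 1, so this is an address overlap. The program logs an error at startup and exits.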
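
The range rule can be sketched in a few lines. This is only an illustration: the names REGISTER_COUNT, register_span, and spans_overlap are hypothetical helpers and are not part of the module layout described in this document.

```python
# Registers occupied per data_type, as listed in the table above.
REGISTER_COUNT = {"int16": 1, "int32": 2, "float32": 2}


def register_span(data_type: str, address: int) -> range:
    """Return the holding register addresses occupied by one point."""
    return range(address, address + REGISTER_COUNT[data_type])


def spans_overlap(a: range, b: range) -> bool:
    """Two half-open ranges intersect when each starts before the other ends."""
    return a.start < b.stop and b.start < a.stop


point_a = register_span("float32", 0)  # occupies 0 and 1
point_b = register_span("int16", 1)    # occupies 1
print(list(point_a), list(point_b), spans_overlap(point_a, point_b))
```

Moving point_b to address 2 would make the two spans disjoint, so the check would pass.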

Database Configuration

db:
  host: 192.168.1.109
  port: 48324
  database: proj_dev2024_config
  user: postgres
  password: aragronprod

Full Configuration Example

The configuration file is config.yaml. The Python standard library has no YAML parser, so the project needs PyYAML as a dependency.

db:
  host: 192.168.1.109
  port: 48324
  database: proj_dev2024_config
  user: postgres
  password: aragronprod

modbus:
  host: 0.0.0.0
  port: 502
  interval: 5

http_provider:
  url: http://192.168.1.109:18503/data/get_points_real_value
  timeout_seconds: 5

logging:
  dir: logs
  retention_days: 3
  level: INFO
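
One possible shape for app_config.py is sketched below. The dataclass names and the helper functions are assumptions made for illustration. The defaults for port, interval, and retention follow the constraints stated earlier, while the host, log directory, and level defaults are simply copied from the sample file above.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ModbusConfig:
    host: str = "0.0.0.0"  # assumed default, copied from the sample config
    port: int = 502
    interval: int = 5      # refresh interval in seconds


@dataclass(frozen=True)
class LoggingConfig:
    dir: str = "logs"      # assumed default, copied from the sample config
    retention_days: int = 3
    level: str = "INFO"


def load_raw_config(path: str) -> dict[str, Any]:
    """Parse the YAML file into a plain dict (requires PyYAML)."""
    import yaml  # imported lazily so the rest of the sketch runs without PyYAML

    with open(path, encoding="utf-8") as handle:
        return yaml.safe_load(handle) or {}


def build_modbus_config(raw: dict[str, Any]) -> ModbusConfig:
    # Keys missing from the file fall back to the dataclass defaults;
    # unknown keys raise TypeError, which surfaces typos early.
    return ModbusConfig(**(raw.get("modbus") or {}))


def build_logging_config(raw: dict[str, Any]) -> LoggingConfig:
    return LoggingConfig(**(raw.get("logging") or {}))


raw = {"modbus": {"port": 5020}}  # stands in for load_raw_config("config.yaml")
modbus = build_modbus_config(raw)
log_cfg = build_logging_config(raw)
print(modbus, log_cfg)
```

Rejecting unknown keys is a design choice in this sketch; a more lenient loader could filter them instead.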

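Initialization Flow

1. Read config.yaml
2. Initialize logging
3. Log the startup message and key configuration
4. Query all points from modbus_server_point
5. If the table has no points, log a warning, do not exit, and start an empty Modbus server
6. Verify that data_type is one of int16/int32/float32
7. Verify that addresses under the same slave_id do not overlap
8. If any addresses overlap, print and log all of them together, then exit
9. Query pt_point in batches to confirm that each point_id exists
10. If any point_id is missing, print and log all of them together after every batch has been queried, then exit
11. Fetch the initial real-time values through the ValueProvider
12. For points with no real-time value during initialization, write the default value 0 and log a warning
13. Initialize the read-only Holding Register store
14. Start the background refresh thread
15. Start the Modbus TCP server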

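Error Handling Strategy

Scenario                                              Handling
modbus_server_point has no points                     Log a warning, do not exit, start an empty server
Invalid data_type                                     Log an error and exit
Modbus address overlap                                Log an error and exit
point_id missing from pt_point                        Log all missing IDs together after every batch is queried, then exit
Real-time value missing during HTTP initialization    Write the default value 0, log a warning, do not exit
Real-time value missing during periodic HTTP refresh  Keep the old value, log a warning, do not exit
HTTP request fails                                    Log an error, skip this round, do not exit
Client writes a register                              Return a Modbus exception, do not modify data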

Module Layout

Suggested directory structure:

modbus_server_nd/
  main.py
  app_config.py
  constants.py
  logging_config.py
  db.py
  point_model.py
  point_loader.py
  value_provider.py
  http_value_provider.py
  modbus_codec.py
  register_store.py
  modbus_context.py
  modbus_server.py
config.yaml

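Module responsibilities:

Module                  Responsibility
main.py                 Program entry point; orchestrates initialization and startup
app_config.py           Loads configuration and defaults
constants.py            Code constants, e.g. DEFAULT_BATCH_SIZE = 200
logging_config.py       Sets up log rotation
db.py                   PostgreSQL connection
point_model.py          Point data structure
point_loader.py         Loads points, validates addresses, validates against pt_point
value_provider.py       Abstract interface for the real-time value Provider
http_value_provider.py  HTTP implementation of the real-time value Provider
modbus_codec.py         Encodes values into registers
register_store.py       Thread-safe register store
modbus_context.py       Custom read-only pymodbus context
modbus_server.py        Starts the Modbus TCP server and the background refresh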

Extensibility

The real-time value source is isolated behind an interface.

from abc import ABC, abstractmethod


class ValueProvider(ABC):
    @abstractmethod
    def fetch_values(self, point_ids: list[str]) -> dict[str, object]:
        """Return a mapping of point_id -> value."""

Current HTTP implementation:

import requests


class HttpValueProvider(ValueProvider):
    def __init__(self, url: str, timeout_seconds: int):
        self.url = url
        self.timeout_seconds = timeout_seconds

    def fetch_values(self, point_ids: list[str]) -> dict[str, object]:
        response = requests.post(
            self.url,
            json={"point_ids": point_ids},
            timeout=self.timeout_seconds,
        )
        response.raise_for_status()
        payload = response.json()
        if payload.get("state") != 0:
            raise RuntimeError(f"realtime api failed: {payload}")
        return {item["point_id"]: item.get("value") for item in payload.get("data", [])}

If the data source changes later, only a new ValueProvider implementation needs to be added.
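
As an illustration of how little a second implementation needs, here is a minimal sketch of a dict-backed provider. InMemoryValueProvider is a hypothetical class, not part of the module layout above; it could stand in for the HTTP source in unit tests.

```python
from abc import ABC, abstractmethod


class ValueProvider(ABC):
    @abstractmethod
    def fetch_values(self, point_ids: list[str]) -> dict[str, object]:
        """Return a mapping of point_id -> value."""


class InMemoryValueProvider(ValueProvider):
    """Hypothetical provider backed by a plain dict, e.g. for tests."""

    def __init__(self, values: dict[str, object]):
        self._values = dict(values)

    def fetch_values(self, point_ids: list[str]) -> dict[str, object]:
        # Unknown point_ids are omitted, so the caller's missing-value
        # handling is exercised the same way as with the HTTP provider.
        return {pid: self._values[pid] for pid in point_ids if pid in self._values}


provider = InMemoryValueProvider({"p1": 12, "p2": 3.5})
result = provider.fetch_values(["p1", "p3"])
print(result)  # {'p1': 12}
```

Because the refresh worker depends only on fetch_values, swapping providers should not require changes elsewhere.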

Point Model

from dataclasses import dataclass


@dataclass(frozen=True)
class ModbusPoint:
    point_id: str
    name: str
    data_type: str
    slave_id: int
    address: int

    @property
    def register_count(self) -> int:
        return 1 if self.data_type == "int16" else 2

    @property
    def end_address(self) -> int:
        return self.address + self.register_count - 1

Address Overlap Check

Group points by slave_id, expand each point into its address range, and check whether any ranges intersect.

def validate_address_overlaps(points: list[ModbusPoint]) -> list[str]:
    errors: list[str] = []
    by_slave: dict[int, list[ModbusPoint]] = {}
    for point in points:
        by_slave.setdefault(point.slave_id, []).append(point)

    for slave_id, slave_points in by_slave.items():
        sorted_points = sorted(slave_points, key=lambda item: item.address)
        previous: ModbusPoint | None = None
        for current in sorted_points:
            if previous and current.address <= previous.end_address:
                errors.append(
                    "从站=%s 地址重叠: %s(%s) 范围=%s-%s, %s(%s) 范围=%s-%s"
                    % (
                        slave_id,
                        previous.point_id,
                        previous.data_type,
                        previous.address,
                        previous.end_address,
                        current.point_id,
                        current.data_type,
                        current.address,
                        current.end_address,
                    )
                )
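            # Track the point that reaches furthest, so a later point is still
            # compared against a long range that started earlier.
            if previous is None or current.end_address > previous.end_address:
                previous = current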
    return errors

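Batched pt_point Check

Only check whether each point_id exists; do not scan the whole table.

The batch size is not exposed in the configuration file. It is fixed in code at a default of 200, ideally defined in constants.py.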

DEFAULT_BATCH_SIZE = 200


def check_point_exists(conn, point_ids: list[str]) -> list[str]:
    existing: set[str] = set()
    with conn.cursor() as cursor:
        for start in range(0, len(point_ids), DEFAULT_BATCH_SIZE):
            batch = point_ids[start:start + DEFAULT_BATCH_SIZE]
            cursor.execute(
                "SELECT point_id FROM pt_point WHERE point_id = ANY(%s)",
                (batch,),
            )
            existing.update(row[0] for row in cursor.fetchall())
    return sorted(set(point_ids) - existing)

Missing points must be reported together only after all batches have been queried; the program then exits.
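
The "collect across all batches, report once" behaviour can be checked without a database. In the sketch below, chunked, find_missing, and fake_lookup are hypothetical names; fake_lookup stands in for the SQL query against pt_point.

```python
DEFAULT_BATCH_SIZE = 200


def chunked(items: list[str], size: int = DEFAULT_BATCH_SIZE) -> list[list[str]]:
    """Split a list into consecutive batches of at most `size` items."""
    return [items[start:start + size] for start in range(0, len(items), size)]


def find_missing(point_ids: list[str], lookup_batch) -> list[str]:
    """Query every batch first, then return all missing point_ids at once."""
    existing: set[str] = set()
    for batch in chunked(point_ids):
        existing.update(lookup_batch(batch))
    return sorted(set(point_ids) - existing)


point_ids = [f"p{i}" for i in range(450)]
known = set(point_ids) - {"p7", "p399"}  # two ids are absent from the fake table
batch_sizes: list[int] = []


def fake_lookup(batch: list[str]) -> list[str]:
    batch_sizes.append(len(batch))
    return [pid for pid in batch if pid in known]


missing = find_missing(point_ids, fake_lookup)
print(batch_sizes, missing)
```

Both missing ids are reported together even though they fall in different batches.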

ABCD Encoding

import struct


def encode_registers(value: object, data_type: str) -> list[int]:
    if value is None:
        value = 0

    if data_type == "int16":
        packed = struct.pack(">h", int(value))
    elif data_type == "int32":
        packed = struct.pack(">i", int(value))
    elif data_type == "float32":
        packed = struct.pack(">f", float(value))
    else:
        raise ValueError(f"unsupported data_type: {data_type}")

    return [int.from_bytes(packed[index:index + 2], "big") for index in range(0, len(packed), 2)]
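
To sanity-check the ABCD layout, a decoder can be paired with the encoder for a round trip. The sketch below re-implements the encoder in compact form so that it runs on its own; decode_registers and the _FORMATS table are hypothetical additions, not part of modbus_codec.py as described.

```python
import struct

# Big-endian struct formats; the most significant 16-bit word comes first (ABCD).
_FORMATS = {"int16": ">h", "int32": ">i", "float32": ">f"}


def encode_registers(value, data_type: str) -> list[int]:
    number = float(value) if data_type == "float32" else int(value)
    packed = struct.pack(_FORMATS[data_type], number)
    return [int.from_bytes(packed[i:i + 2], "big") for i in range(0, len(packed), 2)]


def decode_registers(registers: list[int], data_type: str):
    """Inverse of encode_registers: join the 16-bit words and unpack them."""
    raw = b"".join(register.to_bytes(2, "big") for register in registers)
    return struct.unpack(_FORMATS[data_type], raw)[0]


print([hex(r) for r in encode_registers(0x12345678, "int32")])  # ['0x1234', '0x5678']
print(encode_registers(-1, "int16"))                            # [65535]
print(decode_registers(encode_registers(1.5, "float32"), "float32"))  # 1.5
```

The int32 case shows the word order directly: the high word 0x1234 lands in the first register and the low word 0x5678 in the second.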

Thread-Safe Register Store

The background refresh thread writes and the pymodbus request thread reads, so a lock is required.

from threading import RLock
from pymodbus.constants import ExcCodes


class RegisterStore:
    def __init__(self):
        self._lock = RLock()
        self._registers: dict[int, dict[int, int]] = {}
        self._valid_addresses: dict[int, set[int]] = {}

    def initialize_slave(self, slave_id: int, registers: dict[int, int]) -> None:
        with self._lock:
            self._registers[slave_id] = dict(registers)
            self._valid_addresses[slave_id] = set(registers)

    def read_holding_registers(self, slave_id: int, address: int, count: int):
        with self._lock:
            slave_registers = self._registers.get(slave_id)
            valid_addresses = self._valid_addresses.get(slave_id)
            if not slave_registers or not valid_addresses:
                return ExcCodes.ILLEGAL_ADDRESS
            addresses = range(address, address + count)
            if any(item not in valid_addresses for item in addresses):
                return ExcCodes.ILLEGAL_ADDRESS
            return [slave_registers[item] for item in addresses]

    def write_internal(self, slave_id: int, address: int, values: list[int]) -> None:
        with self._lock:
            slave_registers = self._registers.get(slave_id)
            valid_addresses = self._valid_addresses.get(slave_id)
            if slave_registers is None or valid_addresses is None:
                raise KeyError(f"slave_id is not initialized: {slave_id}")
            for offset, register in enumerate(values):
                register_address = address + offset
                if register_address not in valid_addresses:
                    raise KeyError(f"address is not configured: slave_id={slave_id}, address={register_address}")
                slave_registers[register_address] = register

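When the register store is initialized, every configured point is first written with the default value 0. That way, if the initial HTTP fetch returns no value for a point, its address is still valid and its registers read as 0.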

def initialize_register_store(points: list[ModbusPoint], store: RegisterStore) -> None:
    by_slave: dict[int, dict[int, int]] = {}
    for point in points:
        registers = encode_registers(0, point.data_type)
        slave_registers = by_slave.setdefault(point.slave_id, {})
        for offset, register in enumerate(registers):
            slave_registers[point.address + offset] = register

    for slave_id, registers in by_slave.items():
        store.initialize_slave(slave_id, registers)

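Custom Read-Only pymodbus Context

In pymodbus 3.13.1 the legacy ModbusDeviceContext/ModbusServerContext classes are deprecated, and the old-style examples are a poor fit for dynamic updates. The suggested approach is a custom context that subclasses ModbusServerContext to satisfy the server's type check, while the actual reads and writes go through RegisterStore.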

from pymodbus.constants import ExcCodes
from pymodbus.datastore import ModbusServerContext


class ReadonlyHoldingRegisterContext(ModbusServerContext):
    def __init__(self, store: RegisterStore):
        self.simdevices = []
        self.store = store

    def device_ids(self) -> list[int]:
        return sorted(self.store._registers.keys())

    async def async_getValues(self, device_id: int, func_code: int, address: int, count: int = 1):
        if func_code != 3:
            return ExcCodes.ILLEGAL_FUNCTION
        return self.store.read_holding_registers(device_id, address, count)

    async def async_setValues(self, device_id: int, func_code: int, address: int, values: list[int]):
        return ExcCodes.ILLEGAL_ADDRESS

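Notes:

  • Function code 03 is allowed to read Holding Registers.
  • Other read types return ILLEGAL_FUNCTION.
  • Function codes 06/16/22/23 and other holding register writes return an exception and do not modify data.
  • The background refresh thread does not go through async_setValues; it calls RegisterStore.write_internal instead.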

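Client Connect and Disconnect Logging

The trace_connect callback in pymodbus only receives a boolean for connect or disconnect and does not expose the client address. To print more useful logs, define a custom request handler and read peername from the transport.

The number of connections is not tracked.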

import logging
from pymodbus.server.requesthandler import ServerRequestHandler
from pymodbus.server.server import ModbusTcpServer

logger = logging.getLogger(__name__)


class LoggingServerRequestHandler(ServerRequestHandler):
    def _client_addr(self) -> str:
        if not self.transport:
            return "unknown"
        peer = self.transport.get_extra_info("peername")
        # peername is (host, port) for IPv4 but a 4-tuple for IPv6, so keep only the first two items.
        return "%s:%s" % peer[:2] if peer else "unknown"

    def callback_connected(self) -> None:
        super().callback_connected()
        logger.info("客户端已连接(Modbus): %s", self._client_addr())

    def callback_disconnected(self, exc: Exception | None) -> None:
        client_addr = self._client_addr()
        super().callback_disconnected(exc)
        if exc:
            logger.info("客户端已断开(Modbus): %s, 原因=%s", client_addr, exc)
        else:
            logger.info("客户端已断开(Modbus): %s", client_addr)


class LoggingModbusTcpServer(ModbusTcpServer):
    def callback_new_connection(self):
        return LoggingServerRequestHandler(
            self,
            self.trace_packet,
            self.trace_pdu,
            self.trace_connect,
        )

Background Refresh Thread

import logging
import threading
import time

logger = logging.getLogger(__name__)


class ValueRefreshWorker(threading.Thread):
    def __init__(self, points, provider, store, interval_seconds: int):
        super().__init__(name="value-refresh-worker", daemon=True)
        self.points = points
        self.provider = provider
        self.store = store
        self.interval_seconds = interval_seconds

    def run(self) -> None:
        logger.info("实时值刷新线程已启动,刷新周期=%s秒", self.interval_seconds)
        while True:
            try:
                self.refresh_once(initial=False)
            except Exception:
                logger.exception("实时值刷新失败")
            time.sleep(self.interval_seconds)

    def refresh_once(self, initial: bool) -> None:
        point_by_id = {point.point_id: point for point in self.points}
        point_ids = list(point_by_id)
        for start in range(0, len(point_ids), DEFAULT_BATCH_SIZE):
            batch = point_ids[start:start + DEFAULT_BATCH_SIZE]
            values = self.provider.fetch_values(batch)
            for point_id in batch:
                point = point_by_id[point_id]
                if point_id not in values:
                    if initial:
                        logger.warning("初始化实时值缺失,point_id=%s,使用默认值0", point_id)
                        value = 0
                    else:
                        logger.warning("周期刷新实时值缺失,point_id=%s,保持旧值", point_id)
                        continue
                else:
                    value = values[point_id]
                registers = encode_registers(value, point.data_type)
                self.store.write_internal(point.slave_id, point.address, registers)

Starting the Modbus TCP Server

import asyncio
import logging

logger = logging.getLogger(__name__)


async def start_modbus_server(context, host: str, port: int) -> None:
    server = LoggingModbusTcpServer(
        context,
        address=(host, port),
        ignore_missing_devices=False,
        broadcast_enable=False,
    )
    logger.info("服务已启动监听(Modbus TCP),地址=%s:%s", host, port)
    await server.serve_forever()


def run_modbus_server(context, host: str, port: int) -> None:
    asyncio.run(start_modbus_server(context, host, port))

Logging Design

Use TimedRotatingFileHandler to rotate the log file daily, retaining 3 days by default.

import logging
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path


def setup_logging(log_dir: str, retention_days: int, level: str) -> None:
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    formatter = logging.Formatter(
        "%(asctime)s %(levelname)s [%(name)s] %(message)s"
    )

    file_handler = TimedRotatingFileHandler(
        filename=str(Path(log_dir) / "modbus-server.log"),
        when="midnight",
        interval=1,
        backupCount=retention_days,
        encoding="utf-8",
    )
    file_handler.setFormatter(formatter)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, level.upper(), logging.INFO))
    root_logger.handlers.clear()
    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

Every initialization step must print a log message. At minimum, these include:

正在启动Modbus Server
日志系统初始化完成
配置文件加载完成
运行配置: 数据库=host:port/database, Modbus监听=host:port, 刷新周期=5秒, 批量大小=200
数据库连接成功
开始从modbus_server_point加载全部点位
点位加载完成,数量=...
数据表modbus_server_point没有点位,将启动空Modbus Server
开始校验点位data_type
开始校验Modbus地址重叠
开始校验pt_point点位是否存在,批量大小=200
开始请求初始化实时值
初始化实时值缺失,point_id=...,使用默认值0
寄存器存储初始化完成
实时值Provider初始化完成,类型=http
实时值刷新线程已启动,刷新周期=5秒
上下文初始化完成(Modbus)
服务已启动监听(Modbus TCP),地址=host:port
客户端已连接(Modbus): ip:port
客户端已断开(Modbus): ip:port

main.py Orchestration Logic

import logging

logger = logging.getLogger(__name__)


def main() -> int:
    config = load_config("config.yaml")
    setup_logging(
        config.logging.dir,
        config.logging.retention_days,
        config.logging.level,
    )

    logger.info("正在启动Modbus Server")
    logger.info("日志系统初始化完成")
    logger.info("配置文件加载完成")
    logger.info(
        "运行配置: 数据库=%s:%s/%s, Modbus监听=%s:%s, 刷新周期=%s秒, 批量大小=%s",
        config.db.host,
        config.db.port,
        config.db.database,
        config.modbus.host,
        config.modbus.port,
        config.modbus.interval,
        DEFAULT_BATCH_SIZE,
    )

    conn = create_connection(config.db)
    logger.info("数据库连接成功")

    logger.info("开始从modbus_server_point加载全部点位")
    points = load_points(conn)
    logger.info("点位加载完成,数量=%s", len(points))

    if not points:
        logger.warning("数据表modbus_server_point没有点位,将启动空Modbus Server")

    logger.info("开始校验点位data_type")
    validate_data_types(points)

    logger.info("开始校验Modbus地址重叠")
    overlap_errors = validate_address_overlaps(points)
    if overlap_errors:
        for error in overlap_errors:
            logger.error(error)
        return 1

    logger.info("开始校验pt_point点位是否存在,批量大小=%s", DEFAULT_BATCH_SIZE)
    missing_point_ids = check_point_exists(conn, [point.point_id for point in points])
    if missing_point_ids:
        logger.error("数据表pt_point中缺失以下point_id: %s", missing_point_ids)
        return 1

    store = RegisterStore()
    initialize_register_store(points, store)
    logger.info("寄存器存储初始化完成")

    provider = HttpValueProvider(
        config.http_provider.url,
        config.http_provider.timeout_seconds,
    )
    logger.info("实时值Provider初始化完成,类型=http")

    worker = ValueRefreshWorker(
        points,
        provider,
        store,
        config.modbus.interval,
    )
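    logger.info("开始请求初始化实时值")
    try:
        worker.refresh_once(initial=True)
    except Exception:
        # A failed HTTP request must not abort startup; registers keep the default 0.
        logger.exception("初始化实时值请求失败,寄存器保持默认值0")
    worker.start()

    context = ReadonlyHoldingRegisterContext(store)
    logger.info("上下文初始化完成(Modbus)")
    try:
        run_modbus_server(context, config.modbus.host, config.modbus.port)
    except OSError:
        # Covers a bind failure (e.g. port 502 without privileges), assuming it propagates as OSError.
        logger.exception("Modbus TCP 服务启动失败,地址=%s:%s", config.modbus.host, config.modbus.port)
        return 1
    return 0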

Empty Server Behavior

When modbus_server_point contains no points:

  • The program logs a warning.
  • It does not exit.
  • It starts an empty Modbus server.
  • A client reading any Holding Register address receives ILLEGAL_ADDRESS.

Database SQL

Load points:

SELECT point_id, name, data_type, slave_id, address
FROM modbus_server_point
ORDER BY slave_id, address, point_id;

Batched pt_point check:

SELECT point_id
FROM pt_point
WHERE point_id = ANY(%s);

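Startup Notes

  • On Linux, the default port 502 usually requires root privileges or an explicit port-binding grant.
  • If binding the port fails, the program must log an error and exit.
  • If the deployment should not run as root, change the port to 5020 in the configuration file and forward traffic to it with a firewall rule or a proxy.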