"""Unit tests for ``point_loader`` built on small in-memory DB cursor fakes."""

import unittest

from point_loader import (
    check_point_exists,
    load_points,
    validate_address_overlaps,
    validate_data_types,
)
from point_model import ModbusPoint


class FakeCursor:
    """Cursor stub that records how many ids each ``execute`` call received.

    ``execute`` takes ``params[0]`` as the batch of point ids for the query.
    ``fetchall`` then returns, as single-column rows, those ids from that batch
    that are present in ``existing``, in batch order.
    """

    def __init__(self, existing):
        self.existing = set(existing)
        # One entry per execute() call: the length of the id batch it was given.
        self.batch_sizes = []
        # The ids passed to the most recent execute() call.
        self.current_batch = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never swallow exceptions raised inside the ``with`` body.
        return False

    def execute(self, sql, params=None):
        # The SQL text is ignored; only the id batch in params[0] is recorded.
        batch = list(params[0])
        self.current_batch = batch
        self.batch_sizes.append(len(batch))

    def fetchall(self):
        hits = filter(self.existing.__contains__, self.current_batch)
        return [(point_id,) for point_id in hits]


class FakeConnection:
    """Connection stub whose ``cursor()`` always returns the same ``FakeCursor``."""

    def __init__(self, existing):
        self.cursor_obj = FakeCursor(existing)

    def cursor(self):
        return self.cursor_obj


class LoadPointsCursor:
    """Cursor stub that keeps the last SQL text and always yields one fixed row."""

    def __init__(self):
        self.sql = ""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never swallow exceptions raised inside the ``with`` body.
        return False

    def execute(self, sql, params=None):
        # Remember the query so a test can inspect it; params are ignored.
        self.sql = sql

    def fetchall(self):
        # NOTE(review): presumably mirrors the column order load_points() selects.
        return [("p1", "P1", "int16", 1, 0)]


class LoadPointsConnection:
    """Connection stub whose ``cursor()`` always returns one ``LoadPointsCursor``."""

    def __init__(self):
        self.cursor_obj = LoadPointsCursor()

    def cursor(self):
        return self.cursor_obj


class PointLoaderTest(unittest.TestCase):
    """Behavioral checks for the loading and validation helpers in point_loader."""

    # The query issued by load_points must not mention an "enabled" filter.
    def test_load_points_does_not_filter_enabled(self):
        connection = LoadPointsConnection()
        loaded = load_points(connection)
        issued_sql = connection.cursor_obj.sql
        self.assertEqual(loaded[0].point_id, "p1")
        self.assertNotIn("enabled", issued_sql.lower())

    # Exactly one overlap is expected, and it must name points "a" and "b".
    # Assumption (to confirm against ModbusPoint): the 4th/5th constructor args
    # are (unit id, register address).
    def test_validate_address_overlaps(self):
        points = [
            ModbusPoint("a", "A", "float32", 1, 0),
            ModbusPoint("b", "B", "int16", 1, 1),
            ModbusPoint("c", "C", "int16", 2, 1),
        ]
        errors = validate_address_overlaps(points)
        self.assertEqual(len(errors), 1)
        message = errors[0]
        # "地址重叠" translates to "address overlap".
        for fragment in ("地址重叠", "a", "b"):
            self.assertIn(fragment, message)

    # Only the point with an unknown data type is reported.
    def test_validate_data_types(self):
        good = ModbusPoint("a", "A", "float32", 1, 0)
        bad = ModbusPoint("b", "B", "bad", 1, 2)
        self.assertEqual(
            validate_data_types([good, bad]),
            ["point_id=b, data_type=bad"],
        )

    # 205 ids must be queried as batches of 200 and 5, and the two ids absent
    # from the fake table must be returned as missing.
    def test_check_point_exists_batches_by_200(self):
        all_ids = list(map("p{}".format, range(205)))
        connection = FakeConnection(existing=all_ids[:203])
        missing = check_point_exists(connection, all_ids)
        self.assertEqual(missing, ["p203", "p204"])
        self.assertEqual(connection.cursor_obj.batch_sizes, [200, 5])


if __name__ == "__main__":
    unittest.main()