# test_point_loader.py — unittest tests for the point_loader helpers.
  1. import unittest
  2. from point_loader import check_point_exists, load_points, validate_address_overlaps, validate_data_types
  3. from point_model import ModbusPoint
  4. class FakeCursor:
  5. def __init__(self, existing):
  6. self.existing = set(existing)
  7. self.batch_sizes = []
  8. self.current_batch = []
  9. def __enter__(self):
  10. return self
  11. def __exit__(self, exc_type, exc, tb):
  12. return False
  13. def execute(self, sql, params=None):
  14. self.current_batch = list(params[0])
  15. self.batch_sizes.append(len(self.current_batch))
  16. def fetchall(self):
  17. return [(point_id,) for point_id in self.current_batch if point_id in self.existing]
  18. class FakeConnection:
  19. def __init__(self, existing):
  20. self.cursor_obj = FakeCursor(existing)
  21. def cursor(self):
  22. return self.cursor_obj
  23. class LoadPointsCursor:
  24. def __init__(self):
  25. self.sql = ""
  26. def __enter__(self):
  27. return self
  28. def __exit__(self, exc_type, exc, tb):
  29. return False
  30. def execute(self, sql, params=None):
  31. self.sql = sql
  32. def fetchall(self):
  33. return [("p1", "P1", "int16", 1, 0)]
  34. class LoadPointsConnection:
  35. def __init__(self):
  36. self.cursor_obj = LoadPointsCursor()
  37. def cursor(self):
  38. return self.cursor_obj
  39. class PointLoaderTest(unittest.TestCase):
  40. def test_load_points_does_not_filter_enabled(self):
  41. conn = LoadPointsConnection()
  42. points = load_points(conn)
  43. self.assertEqual(points[0].point_id, "p1")
  44. self.assertNotIn("enabled", conn.cursor_obj.sql.lower())
  45. def test_validate_address_overlaps(self):
  46. points = [
  47. ModbusPoint("a", "A", "float32", 1, 0),
  48. ModbusPoint("b", "B", "int16", 1, 1),
  49. ModbusPoint("c", "C", "int16", 2, 1),
  50. ]
  51. errors = validate_address_overlaps(points)
  52. self.assertEqual(len(errors), 1)
  53. self.assertIn("地址重叠", errors[0])
  54. self.assertIn("a", errors[0])
  55. self.assertIn("b", errors[0])
  56. def test_validate_data_types(self):
  57. points = [
  58. ModbusPoint("a", "A", "float32", 1, 0),
  59. ModbusPoint("b", "B", "bad", 1, 2),
  60. ]
  61. errors = validate_data_types(points)
  62. self.assertEqual(errors, ["point_id=b, data_type=bad"])
  63. def test_check_point_exists_batches_by_200(self):
  64. point_ids = [f"p{i}" for i in range(205)]
  65. conn = FakeConnection(existing=point_ids[:203])
  66. missing = check_point_exists(conn, point_ids)
  67. self.assertEqual(missing, ["p203", "p204"])
  68. self.assertEqual(conn.cursor_obj.batch_sizes, [200, 5])
  69. if __name__ == "__main__":
  70. unittest.main()