train.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import os
  2. from datetime import datetime
  3. from pathlib import Path
  4. import pandas as pd
  5. from ...model.DHU.DHU_A import DHU_A
  6. from ...model.DHU.DHU_B import DHU_B
  7. from ...tools.config_reader import ConfigReader
  8. from ...tools.data_loader import DataLoader
  9. NOW = datetime.now()
  10. PATH = os.path.dirname(os.path.realpath(__file__)).replace('\\','/')
  11. MODEL_FUNC_PATH = f'{PATH}/model_func.py'
  12. MODEL_FILE_PATH = f'./model.pkl'
  13. def train(*inputs,config=None):
  14. config = {} if config is None else config
  15. if '__LOCAL' in config.keys():
  16. config_reader_path = config['__LOCAL']
  17. data_URL = config['__URL']
  18. else:
  19. config_reader_path = '/mnt/workflow_data'
  20. data_URL = 'http://basedataportal-svc:8080/data/getpointsdata'
  21. config_reader = ConfigReader(path=f'{config_reader_path}/DHU_A配置.xlsx')
  22. if config_reader.meta['设备类型'] == 'DHU_A':
  23. MODEL = DHU_A
  24. elif config_reader.meta['设备类型'] == 'DHU_B':
  25. MODEL = DHU_B
  26. else:
  27. raise NotImplementedError(config_reader.meta['设备类型'])
  28. ALL_RESULT = {
  29. 'EXCEPTION':{
  30. 'Data': {},
  31. 'Fit' : {},
  32. 'Save': {}
  33. }
  34. }
  35. for each_eaup_name in config_reader.all_equp_names:
  36. # 获取数据
  37. try:
  38. data_loader = DataLoader(
  39. path = f'{config_reader_path}/data/train/data_his/',
  40. start_time = config_reader.get_app_info(each_eaup_name,app_type='模型训练',key='开始时间',info_type='datetime'),
  41. end_time = config_reader.get_app_info(each_eaup_name,app_type='模型训练',key='结束时间',info_type='datetime')
  42. )
  43. data_loader.dowload_equp_data(
  44. equp_name = each_eaup_name,
  45. point = config_reader.get_equp_point(each_eaup_name,equp_class=['A','B']),
  46. url = data_URL,
  47. clean_cache = False
  48. )
  49. equp_data = data_loader.get_equp_data(each_eaup_name)
  50. equp_data = clean_data(equp_data)
  51. save_data(f'{config_reader_path}/data/train/data_his_clean',f'{each_eaup_name}.pkl',equp_data)
  52. except Exception as E:
  53. ALL_RESULT['EXCEPTION']['Data'][each_eaup_name] = E
  54. continue
  55. # 训练模型
  56. try:
  57. equp_model:DHU_A = MODEL()
  58. equp_model.fit(
  59. input_data = equp_data,
  60. observed_data = equp_data,
  61. plot_TVP = False,
  62. rw_FA_val = True, #TODO
  63. exist_Fa_H = config_reader.get_equp_info(each_eaup_name,'存在回风口','bool'),
  64. exist_Fa_B = config_reader.get_equp_info(each_eaup_name,'存在补风口','bool'),
  65. )
  66. Path(f'{config_reader_path}/data/train/model').mkdir(parents=True, exist_ok=True)
  67. equp_model.save(f'{config_reader_path}/data/train/model/{each_eaup_name}.pkl')
  68. save_data(f'{config_reader_path}/data/train/data_TVP',f'{each_eaup_name}.csv',equp_model.TVP_data)
  69. save_data(f'{config_reader_path}/data/train/data_metric',f'{each_eaup_name}.csv',equp_model.TVP_metric)
  70. except Exception as E:
  71. ALL_RESULT['EXCEPTION']['Fit'][each_eaup_name] = E
  72. continue
  73. # 模型迭代
  74. if not config_reader.get_app_info(each_eaup_name,'模型训练','保存模型','bool'):
  75. continue
  76. try:
  77. monitor_point = config_reader.point.loc[lambda dt:dt.类型=='B']
  78. model_update_info = {}
  79. for i in range(len(monitor_point)):
  80. name = monitor_point.loc[:,'编号'].iat[i]
  81. name_cn = monitor_point.loc[:,'名称'].iat[i]
  82. MAE = monitor_point.loc[:,'指标MAE'].iat[i]
  83. model_update_info[name] = {
  84. 'point_id' : name,
  85. 'point_name' : name_cn,
  86. 'point_class': name,
  87. 'thre_mae' : MAE,
  88. 'thre_mape' : 1,
  89. 'thre_days' : 7
  90. }
  91. equp_model.save_to_platform(
  92. version_id = datetime.now().strftime('%Y%m'),
  93. model_id = config_reader.get_equp_info(each_eaup_name,'模型编号','str'),
  94. update_method = 'update',
  95. model_info = model_update_info,
  96. MODEL_FILE_PATH = MODEL_FILE_PATH,
  97. MODEL_FUNC_PATH = MODEL_FUNC_PATH,
  98. )
  99. except Exception as E:
  100. ALL_RESULT['EXCEPTION']['Save'][each_eaup_name] = E
  101. continue
  102. print(ALL_RESULT)
  103. def clean_data(data) -> pd.DataFrame:
  104. data = (
  105. data
  106. .resample('60min')
  107. .mean()
  108. )
  109. return data
  110. def save_data(dir,file:str,data:pd.DataFrame):
  111. Path(dir).mkdir(parents=True,exist_ok=True)
  112. if file.endswith('.csv'):
  113. data.to_csv(os.path.join(dir,file),index=True)
  114. elif file.endswith('.pkl'):
  115. data.to_pickle(os.path.join(dir,file))
  116. else:
  117. raise Exception('file type error')