train.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import os
  2. from datetime import datetime
  3. from pathlib import Path
  4. from pprint import pprint
  5. import pandas as pd
  6. from ...model.DHU.DHU_AB import DHU_AB
  7. from .config_reader import ConfigReader
  8. from ...tools.data_loader import DataLoader
  9. NOW = datetime.now().replace(second=0,microsecond=0)
  10. PATH = os.path.dirname(os.path.realpath(__file__)).replace('\\','/')
  11. MODEL_FUNC_PATH = f'{PATH}/model_func.py'
  12. MODEL_FILE_PATH = f'./model.pkl'
  13. def train(*inputs,config=None):
  14. config = {} if config is None else config
  15. if '__LOCAL' in config.keys():
  16. config_reader_path = config['__LOCAL']
  17. data_URL = config['__URL']
  18. plot_metric = True
  19. else:
  20. config_reader_path = '/mnt/workflow_data'
  21. data_URL = 'http://basedataportal-svc:8080/data/getpointsdata'
  22. plot_metric = False
  23. config_reader = ConfigReader(path=f'{config_reader_path}/DHU_AB配置.xlsx')
  24. ALL_RESULT = {
  25. 'EXCEPTION':{
  26. 'Data': {},
  27. 'Fit' : {},
  28. 'Save': {}
  29. }
  30. }
  31. for each_eaup_name in config_reader.all_equp_names:
  32. equp_type = config_reader.get_equp_info(each_eaup_name,key='设备类型',info_type='str')
  33. # 获取数据
  34. try:
  35. # 部分情况下设备不需要部分点位表中的点位
  36. rm_point_name = []
  37. if not config_reader.get_equp_info(each_eaup_name,'存在回风口','bool'):
  38. rm_point_name += ['mixed_1_TinM','mixed_1_DinM']
  39. if not config_reader.get_equp_info(each_eaup_name,'存在补风口','bool'):
  40. rm_point_name += ['mixed_2_TinM','mixed_2_DinM']
  41. # 获取历史数据
  42. data_loader = DataLoader(
  43. path = f'{config_reader_path}/data/train/data_his/',
  44. start_time = config_reader.get_app_info(each_eaup_name,app_type='模型训练',key='开始时间',info_type='datetime'),
  45. end_time = config_reader.get_app_info(each_eaup_name,app_type='模型训练',key='结束时间',info_type='datetime'),
  46. print_process = config_reader.get_app_info(each_eaup_name,app_type='模型训练',key='打印取数日志',info_type='bool'),
  47. )
  48. data_loader.download_equp_data(
  49. equp_name = each_eaup_name,
  50. point = config_reader.get_equp_point(each_eaup_name,equp_type,equp_class=['A','B']),
  51. url = data_URL,
  52. clean_cache = False,
  53. rm_point_name = rm_point_name
  54. )
  55. equp_data = data_loader.get_equp_data(each_eaup_name)
  56. save_data(f'{config_reader_path}/data/train/data_his_raw',f'{each_eaup_name}.pkl',equp_data)
  57. except Exception as E:
  58. ALL_RESULT['EXCEPTION']['Data'][each_eaup_name] = E
  59. continue
  60. # 训练模型
  61. try:
  62. equp_model = DHU_AB(
  63. DHU_type = equp_type,
  64. exist_Fa_H = config_reader.get_equp_info(each_eaup_name,'存在回风口','bool'),
  65. exist_Fa_B = config_reader.get_equp_info(each_eaup_name,'存在补风口','bool'),
  66. other_info={
  67. 'heatingcoil_1_Fs_rated': config_reader.get_equp_info(each_eaup_name,'前蒸汽盘管额定流量','float'),
  68. 'heatingcoil_2_Fs_rated': config_reader.get_equp_info(each_eaup_name,'后蒸汽盘管额定流量','float'),
  69. }
  70. )
  71. # 清洗数据
  72. Path(f'{config_reader_path}/data/train/clean_log/').mkdir(parents=True, exist_ok=True)
  73. equp_data = equp_model.clean_data(
  74. data = equp_data,
  75. data_type = ['input','observed'],
  76. print_process = True,
  77. fill_zero = False,
  78. save_log = f'{config_reader_path}/data/train/clean_log/{each_eaup_name}.txt',
  79. )
  80. equp_data = equp_data.resample('60min').mean().dropna()
  81. save_data(f'{config_reader_path}/data/train/data_his_clean',f'{each_eaup_name}.pkl',equp_data)
  82. if not config_reader.get_app_info(each_eaup_name,'模型训练','训练模型','bool'):
  83. continue
  84. equp_model.fit(
  85. input_data = equp_data,
  86. observed_data = equp_data,
  87. plot_TVP = False,
  88. rw_FA_val = config_reader.get_app_info(each_eaup_name,'模型训练','新风阀门开度参数','bool')
  89. )
  90. Path(f'{config_reader_path}/model').mkdir(parents=True, exist_ok=True)
  91. equp_model.save(f'{config_reader_path}/model/{each_eaup_name}.pkl')
  92. save_data(f'{config_reader_path}/data/train/data_TVP',f'{each_eaup_name}.csv',equp_model.TVP_data)
  93. save_data(f'{config_reader_path}/data/train/data_metric',f'{each_eaup_name}.csv',equp_model.TVP_metric.round(2))
  94. if plot_metric:
  95. path = f'{config_reader_path}/plot/TVP'
  96. Path(path).mkdir(parents=True, exist_ok=True)
  97. equp_model.plot_TVP(equp_model.TVP_data,save_path=f'{path}/{each_eaup_name}.png')
  98. except Exception as E:
  99. ALL_RESULT['EXCEPTION']['Fit'][each_eaup_name] = E
  100. continue
  101. # 模型迭代
  102. if not config_reader.get_app_info(each_eaup_name,'模型训练','迭代模型','bool'):
  103. continue
  104. try:
  105. monitor_point = config_reader.point.loc[lambda dt:dt.类型=='B']
  106. model_update_info = {}
  107. for i in range(len(monitor_point)):
  108. name = monitor_point.loc[:,'编号'].iat[i]
  109. name_cn = monitor_point.loc[:,'名称'].iat[i]
  110. MAE = monitor_point.loc[:,'指标MAE'].iat[i]
  111. model_update_info[name] = {
  112. 'point_id' : name,
  113. 'point_name' : name_cn,
  114. 'point_class': name,
  115. 'thre_mae' : MAE,
  116. 'thre_mape' : 1,
  117. 'thre_days' : 7
  118. }
  119. equp_model.save_to_platform(
  120. version_id = datetime.now().strftime('%Y%m'),
  121. model_id = config_reader.get_equp_info(each_eaup_name,'模型编号','str'),
  122. update_method = 'update',
  123. model_info = model_update_info,
  124. MODEL_FILE_PATH = MODEL_FILE_PATH,
  125. MODEL_FUNC_PATH = MODEL_FUNC_PATH,
  126. )
  127. except Exception as E:
  128. ALL_RESULT['EXCEPTION']['Save'][each_eaup_name] = E
  129. continue
  130. pprint(ALL_RESULT)
  131. def save_data(dir,file:str,data:pd.DataFrame):
  132. Path(dir).mkdir(parents=True,exist_ok=True)
  133. if file.endswith('.csv'):
  134. data.to_csv(os.path.join(dir,file),index=True)
  135. elif file.endswith('.pkl'):
  136. data.to_pickle(os.path.join(dir,file))
  137. else:
  138. raise Exception('file type error')