_base.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import importlib
  2. import pickle
  3. from typing import Union
  4. from datetime import datetime
  5. import numpy as np
  6. import pandas as pd
  7. from sklearn.metrics import (
  8. r2_score,
  9. mean_absolute_error,
  10. mean_absolute_percentage_error
  11. )
  12. try:
  13. import plotnine as gg
  14. except:
  15. pass
  16. class BaseModel:
  17. def __init__(self) -> None:
  18. self.last_fit_y_true :Union[np.ndarray,None] = None
  19. self.last_fit_y_pred :Union[np.ndarray,None] = None
  20. self.model_info = {'LOAD_INFO':{}}
  21. def record_load_info(self,**info):
  22. for attr_name,attr_value in info.items():
  23. setattr(self,attr_name,attr_value)
  24. self.model_info['LOAD_INFO'].update(info)
  25. def record_model(
  26. self,
  27. model_name : str,
  28. model : dict,
  29. train_data : dict,
  30. train_metric: dict,
  31. keep_raw : bool = False
  32. ) -> None:
  33. self.model_info[f'model_{model_name}'] = model
  34. self.model_info[f'model_train_info_{model_name}'] = {}
  35. self.model_info[f'model_train_info_{model_name}']['datetime'] = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
  36. self.model_info[f'model_train_info_{model_name}']['metric'] = train_metric
  37. for name,value in train_data.items():
  38. if value is None:
  39. continue
  40. if keep_raw:
  41. self.model_info[f'model_train_info_{model_name}'][f'{name}_raw'] = value
  42. self.model_info[f'model_train_info_{model_name}'][f'{name}_max'] = round(np.max(value),2)
  43. self.model_info[f'model_train_info_{model_name}'][f'{name}_min'] = round(np.min(value),2)
  44. def _get_train_data(self,model_name) -> dict:
  45. if not self.is_model_exist(model_name):
  46. return {}
  47. raw_train_data = {}
  48. train_data_info = self.model_info[f'model_train_info_{model_name}']
  49. for name,value in train_data_info.items():
  50. if value is None:
  51. continue
  52. if name.endswith('_raw'):
  53. name_adj = name.strip('_raw')
  54. raw_train_data[name_adj] = value
  55. return raw_train_data
  56. @property
  57. def all_model_names(self):
  58. names = []
  59. for key in self.model_info:
  60. if ('_train_info_' not in key) and ('model_' in key):
  61. name = key.replace('model_','')
  62. names.append(name)
  63. return names
  64. def is_model_exist(self,model_name) -> bool:
  65. if model_name in self.all_model_names:
  66. return True
  67. else:
  68. return False
  69. def is_model_train_data_exist(self,model_name,data_name) -> bool:
  70. if not self.is_model_exist(model_name):
  71. return False
  72. if f'{data_name}_max' in self.model_info[f'model_train_info_{model_name}']:
  73. return True
  74. else:
  75. return False
  76. def save(self,path):
  77. pd.to_pickle(self.model_info,path)
  78. @classmethod
  79. def load(cls,path):
  80. model_info = pd.read_pickle(path)
  81. load_info = model_info.get('LOAD_INFO',{})
  82. model = cls(**load_info)
  83. model.model_info = model_info
  84. return model
  85. def save_to_platform(
  86. self,
  87. version_id : int,
  88. model_id : str,
  89. update_method : str,
  90. model_info : dict,
  91. MODEL_FILE_PATH: str,
  92. MODEL_FUNC_PATH: str,
  93. ) -> None:
  94. """
  95. model_info = {
  96. NAME:{
  97. point_id : ...,
  98. point_name : ...,
  99. point_class: ...,
  100. thre_mae : ...,
  101. thre_mape : ...,
  102. thre_days : ...,
  103. }
  104. }
  105. """
  106. self.save(MODEL_FILE_PATH)
  107. model_update_info = {}
  108. for model_name in self.all_model_names:
  109. if model_name not in model_info:
  110. continue
  111. model_update_info[model_name] = {
  112. 'metric' : {
  113. 'MAE' : self.model_info[f'model_train_info_{model_name}']['metric']['MAE'],
  114. 'MAPE': self.model_info[f'model_train_info_{model_name}']['metric']['MAPE'],
  115. },
  116. 'point_id' : model_info[model_name]['point_id'],
  117. 'point_name' : model_info[model_name]['point_name'],
  118. 'point_class': model_info[model_name]['point_class'],
  119. 'thre_mae' : model_info[model_name]['thre_mae'],
  120. 'thre_mape' : model_info[model_name]['thre_mape'],
  121. 'thre_days' : model_info[model_name]['thre_days'],
  122. }
  123. from ._update_utils import update
  124. update(
  125. version_id = version_id,
  126. model_id = model_id,
  127. model_info = model_update_info,
  128. update_method = update_method,
  129. MODEL_FUNC_PATH = MODEL_FUNC_PATH,
  130. MODEL_FILE_PATH = MODEL_FILE_PATH
  131. )
  132. @classmethod
  133. def load_from_platform(
  134. cls,
  135. reload = False,
  136. source = 'file', # file / id
  137. model_id = None,
  138. ):
  139. # 设备模型组件对应模型管理中的文件
  140. if source == 'file':
  141. MODEL_PACKAGE = importlib.import_module(
  142. '..models.model_func', package='.'.join(__name__.split('.')[:-1]))
  143. if reload:
  144. importlib.reload(MODEL_PACKAGE)
  145. model_info = getattr(MODEL_PACKAGE,'model')
  146. # 通过模型的id获取到模型文件
  147. elif source == 'id':
  148. if model_id is None:
  149. raise Exception('必须输入模型的id')
  150. from workflow_utils import get_model_version_file
  151. try:
  152. model_info = get_model_version_file(model_id=model_id,filename='model.pkl')
  153. model_info = pickle.loads(model_info)
  154. except Exception as e:
  155. print(e)
  156. raise Exception('模型文件获取失败')
  157. load_info = model_info.get('LOAD_INFO',{})
  158. model = cls(**load_info)
  159. model.model_info = model_info
  160. return model
  161. def metric(self,y_true,y_pred,show=True):
  162. mask = ~(np.isnan(y_true) | np.isnan(y_pred))
  163. y_true = y_true[mask]
  164. y_pred = y_pred[mask]
  165. r2 = r2_score(y_true,y_pred)
  166. mae = mean_absolute_error(y_true,y_pred)
  167. mape = mean_absolute_percentage_error(y_true,y_pred)
  168. if show:
  169. print(f'R2\t: {r2}\nMAE\t: {mae} \nMAPE\t: {mape}')
  170. return {'R2':r2,'MAE':mae,'MAPE':mape}
  171. def last_metric(self):
  172. y_true = self.last_fit_y_true
  173. y_pred = self.last_fit_y_pred
  174. return self.metric(y_true=y_true,y_pred=y_pred)
  175. def summary(self):
  176. ...
  177. def plot_TVP(self):
  178. plot = (
  179. pd.DataFrame(
  180. {
  181. 'Real' : self.last_fit_y_true.flatten(),
  182. 'Predict': self.last_fit_y_pred.flatten()
  183. }
  184. )
  185. .pipe(gg.ggplot)
  186. + gg.aes(x='Real',y='Predict')
  187. + gg.geom_point()
  188. + gg.geom_abline(slope=1,intercept=0,color='red')
  189. )
  190. return plot
  191. def plot_contour(data,x,y,z,labs=None):
  192. data_pivot = (
  193. data.pivot(index=x,columns=y,values=z)
  194. .sort_index(axis=1,ascending=False)
  195. .sort_index(axis=0,ascending=False)
  196. )
  197. X = np.repeat(data_pivot.index.values.reshape(-1,1),len(data_pivot.columns),axis=1)
  198. Y = np.repeat(data_pivot.columns.values.reshape(1,-1),len(data_pivot.index),axis=0)
  199. Z = data_pivot.values
  200. labs = {} if labs is None else labs
  201. plot = (
  202. data.pipe(gg.ggplot)
  203. + gg.aes(x=x,y=y,fill=z)
  204. + gg.geom_tile()
  205. + gg.coord_cartesian(expand=False)
  206. + gg.theme(legend_position='none')
  207. + gg.labs(**labs)
  208. )
  209. fig = plot.draw()
  210. axs = fig.get_axes()
  211. contour = axs[0].contour(X, Y, Z, levels=10,colors='black',linewidths=1)
  212. axs[0].clabel(contour, inline=True, fontsize=8)
  213. return fig