_base.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import importlib
  2. import pickle
  3. import inspect
  4. from typing import Union
  5. from datetime import datetime
  6. import numpy as np
  7. import pandas as pd
  8. from sklearn.metrics import (
  9. r2_score,
  10. mean_absolute_error,
  11. mean_absolute_percentage_error
  12. )
  13. try:
  14. import plotnine as gg
  15. except:
  16. pass
  17. class BaseModel:
  18. def __init__(self) -> None:
  19. self.last_fit_y_true :Union[np.ndarray,None] = None
  20. self.last_fit_y_pred :Union[np.ndarray,None] = None
  21. self.model_info = {'LOAD_INFO':{}}
  22. def record_load_info(self,**info):
  23. self.model_info['LOAD_INFO'].update(info)
  24. def record_model(
  25. self,
  26. model_name : str,
  27. model : dict,
  28. train_data : dict,
  29. train_metric: dict,
  30. keep_raw : bool = False
  31. ) -> None:
  32. self.model_info[f'model_{model_name}'] = model
  33. self.model_info[f'model_train_info_{model_name}'] = {}
  34. self.model_info[f'model_train_info_{model_name}']['datetime'] = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
  35. self.model_info[f'model_train_info_{model_name}']['metric'] = train_metric
  36. for name,value in train_data.items():
  37. if value is None:
  38. continue
  39. if keep_raw:
  40. self.model_info[f'model_train_info_{model_name}'][f'{name}_raw'] = value
  41. self.model_info[f'model_train_info_{model_name}'][f'{name}_max'] = round(np.max(value),2)
  42. self.model_info[f'model_train_info_{model_name}'][f'{name}_min'] = round(np.min(value),2)
  43. def _get_train_data(self,model_name) -> dict:
  44. if not self.is_model_exist(model_name):
  45. return {}
  46. raw_train_data = {}
  47. train_data_info = self.model_info[f'model_train_info_{model_name}']
  48. for name,value in train_data_info.items():
  49. if value is None:
  50. continue
  51. if name.endswith('_raw'):
  52. name_adj = name.strip('_raw')
  53. raw_train_data[name_adj] = value
  54. return raw_train_data
  55. @property
  56. def all_model_names(self):
  57. names = []
  58. for key in self.model_info:
  59. if ('_train_info_' not in key) and ('model_' in key):
  60. name = key.replace('model_','')
  61. names.append(name)
  62. return names
  63. def is_model_exist(self,model_name) -> bool:
  64. if model_name in self.all_model_names:
  65. return True
  66. else:
  67. return False
  68. def is_model_train_data_exist(self,model_name,data_name) -> bool:
  69. if not self.is_model_exist(model_name):
  70. return False
  71. if f'{data_name}_max' in self.model_info[f'model_train_info_{model_name}']:
  72. return True
  73. else:
  74. return False
  75. def save(self,path):
  76. pd.to_pickle(self.model_info,path)
  77. @classmethod
  78. def load(cls,path):
  79. model_info = pd.read_pickle(path)
  80. load_info_init = {}
  81. load_info_attr = {}
  82. cls_init_param = list(inspect.signature(cls.__init__).parameters.keys())
  83. for load_key,load_value in model_info.get('LOAD_INFO',{}).items():
  84. if load_key in cls_init_param:
  85. load_info_init[load_key] = load_value
  86. else:
  87. load_info_attr[load_key] = load_value
  88. model = cls(**load_info_init)
  89. model.model_info = model_info
  90. for attr_name,attr_value in load_info_attr.items():
  91. setattr(model,attr_name,attr_value)
  92. return model
  93. def save_to_platform(
  94. self,
  95. version_id : int,
  96. model_id : str,
  97. update_method : str,
  98. model_info : dict,
  99. MODEL_FILE_PATH: str,
  100. MODEL_FUNC_PATH: str,
  101. ) -> None:
  102. """
  103. model_info = {
  104. NAME:{
  105. point_id : ...,
  106. point_name : ...,
  107. point_class: ...,
  108. thre_mae : ...,
  109. thre_mape : ...,
  110. thre_days : ...,
  111. }
  112. }
  113. """
  114. self.save(MODEL_FILE_PATH)
  115. model_update_info = {}
  116. for model_name in self.all_model_names:
  117. if model_name not in model_info:
  118. continue
  119. model_update_info[model_name] = {
  120. 'metric' : {
  121. 'MAE' : self.model_info[f'model_train_info_{model_name}']['metric']['MAE'],
  122. 'MAPE': self.model_info[f'model_train_info_{model_name}']['metric']['MAPE'],
  123. },
  124. 'point_id' : model_info[model_name]['point_id'],
  125. 'point_name' : model_info[model_name]['point_name'],
  126. 'point_class': model_info[model_name]['point_class'],
  127. 'thre_mae' : model_info[model_name]['thre_mae'],
  128. 'thre_mape' : model_info[model_name]['thre_mape'],
  129. 'thre_days' : model_info[model_name]['thre_days'],
  130. }
  131. from ._update_utils import update
  132. update(
  133. version_id = version_id,
  134. model_id = model_id,
  135. model_info = model_update_info,
  136. update_method = update_method,
  137. MODEL_FUNC_PATH = MODEL_FUNC_PATH,
  138. MODEL_FILE_PATH = MODEL_FILE_PATH
  139. )
  140. @classmethod
  141. def load_from_platform(
  142. cls,
  143. reload = False,
  144. source = 'file', # file / id
  145. model_id = None,
  146. ):
  147. # 设备模型组件对应模型管理中的文件
  148. if source == 'file':
  149. MODEL_PACKAGE = importlib.import_module(
  150. '...models.model_func', package='.'.join(__name__.split('.')[:-1]))
  151. if reload:
  152. importlib.reload(MODEL_PACKAGE)
  153. model_info = getattr(MODEL_PACKAGE,'model')
  154. # 通过模型的id获取到模型文件
  155. elif source == 'id':
  156. if model_id is None:
  157. raise Exception('必须输入模型的id')
  158. from workflow_utils import get_model_version_file
  159. try:
  160. model_info = get_model_version_file(model_id=model_id,filename='model.pkl')
  161. model_info = pickle.loads(model_info)
  162. except Exception as e:
  163. print(e)
  164. raise Exception('模型文件获取失败')
  165. load_info = model_info.get('LOAD_INFO',{})
  166. model = cls(**load_info)
  167. model.model_info = model_info
  168. return model
  169. def metric(self,y_true,y_pred,show=True):
  170. mask = ~(np.isnan(y_true) | np.isnan(y_pred))
  171. y_true = y_true[mask]
  172. y_pred = y_pred[mask]
  173. r2 = r2_score(y_true,y_pred)
  174. mae = mean_absolute_error(y_true,y_pred)
  175. mape = mean_absolute_percentage_error(y_true,y_pred)
  176. if show:
  177. print(f'R2\t: {r2}\nMAE\t: {mae} \nMAPE\t: {mape}')
  178. return {'R2':r2,'MAE':mae,'MAPE':mape}
  179. def last_metric(self):
  180. y_true = self.last_fit_y_true
  181. y_pred = self.last_fit_y_pred
  182. return self.metric(y_true=y_true,y_pred=y_pred)
  183. def summary(self):
  184. ...
  185. def plot_TVP(self):
  186. plot = (
  187. pd.DataFrame(
  188. {
  189. 'Real' : self.last_fit_y_true.flatten(),
  190. 'Predict': self.last_fit_y_pred.flatten()
  191. }
  192. )
  193. .pipe(gg.ggplot)
  194. + gg.aes(x='Real',y='Predict')
  195. + gg.geom_point()
  196. + gg.geom_abline(slope=1,intercept=0,color='red')
  197. )
  198. return plot
  199. def plot_contour(data,x,y,z,labs=None):
  200. data_pivot = (
  201. data.pivot(index=x,columns=y,values=z)
  202. .sort_index(axis=1,ascending=False)
  203. .sort_index(axis=0,ascending=False)
  204. )
  205. X = np.repeat(data_pivot.index.values.reshape(-1,1),len(data_pivot.columns),axis=1)
  206. Y = np.repeat(data_pivot.columns.values.reshape(1,-1),len(data_pivot.index),axis=0)
  207. Z = data_pivot.values
  208. labs = {} if labs is None else labs
  209. plot = (
  210. data.pipe(gg.ggplot)
  211. + gg.aes(x=x,y=y,fill=z)
  212. + gg.geom_tile()
  213. + gg.coord_cartesian(expand=False)
  214. + gg.theme(legend_position='none')
  215. + gg.labs(**labs)
  216. )
  217. fig = plot.draw()
  218. axs = fig.get_axes()
  219. contour = axs[0].contour(X, Y, Z, levels=10,colors='black',linewidths=1)
  220. axs[0].clabel(contour, inline=True, fontsize=8)
  221. return fig