import importlib
import pickle
import inspect
from typing import Union
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.metrics import (
    r2_score,
    mean_absolute_error,
    mean_absolute_percentage_error
)

# plotnine is only needed by the plotting helpers, so a failed import is
# tolerated here and reported when a plot is actually requested.
try:
    import plotnine as gg
    _PLOTNINE_IMPORT_ERROR = None
except Exception as _exc:  # best-effort: any import-time failure disables plotting only
    gg = None
    _PLOTNINE_IMPORT_ERROR = _exc


_MODEL_PREFIX = 'model_'
_TRAIN_INFO_PREFIX = 'model_train_info_'
_RAW_SUFFIX = '_raw'


def _require_plotnine():
    """Raise ImportError if plotnine could not be imported at module load."""
    if gg is None:
        raise ImportError(
            'plotnine is required for plotting but could not be imported'
        ) from _PLOTNINE_IMPORT_ERROR


class BaseModel:
    """Container that records fitted models and their training info in a dict.

    Everything is stored in ``self.model_info``:

    * ``'LOAD_INFO'``: keyword info recorded via :meth:`record_load_info`.
    * ``'model_<name>'``: the model object/dict recorded for ``<name>``.
    * ``'model_train_info_<name>'``: timestamp, metric and per-variable
      min/max (optionally raw data) recorded for ``<name>``.
    """

    def __init__(self) -> None:
        self.last_fit_y_true: Union[np.ndarray, None] = None
        self.last_fit_y_pred: Union[np.ndarray, None] = None
        self.model_info = {'LOAD_INFO': {}}

    def record_load_info(self, **info):
        """Merge ``info`` into ``model_info['LOAD_INFO']``."""
        self.model_info['LOAD_INFO'].update(info)

    def record_model(
        self,
        model_name: str,
        model: dict,
        train_data: dict,
        train_metric: dict,
        keep_raw: bool = False
    ) -> None:
        """Store a model together with a summary of its training data.

        Args:
            model_name: Name used to build the ``model_info`` keys.
            model: Object stored under ``'model_<model_name>'``.
            train_data: Mapping of variable name to values; ``None`` values
                are skipped. For each remaining entry the max and min
                (rounded to 2 decimals) are stored.
            train_metric: Stored as-is under the ``'metric'`` key.
            keep_raw: When true, the raw values are additionally stored under
                ``'<name>_raw'``.
        """
        train_info = {}
        self.model_info[f'{_MODEL_PREFIX}{model_name}'] = model
        self.model_info[f'{_TRAIN_INFO_PREFIX}{model_name}'] = train_info
        train_info['datetime'] = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
        train_info['metric'] = train_metric

        for name, value in train_data.items():
            if value is None:
                continue
            if keep_raw:
                train_info[f'{name}{_RAW_SUFFIX}'] = value
            train_info[f'{name}_max'] = round(np.max(value), 2)
            train_info[f'{name}_min'] = round(np.min(value), 2)

    def _get_train_data(self, model_name) -> dict:
        """Return the raw training data recorded for ``model_name``.

        Only entries stored with ``keep_raw=True`` (keys ending in ``'_raw'``)
        are returned, keyed by the variable name without that suffix. An
        empty dict is returned when the model is unknown.
        """
        if not self.is_model_exist(model_name):
            return {}

        train_info = self.model_info.get(f'{_TRAIN_INFO_PREFIX}{model_name}', {})
        # Slice the suffix off explicitly: str.strip('_raw') would remove any
        # of the characters '_', 'r', 'a', 'w' from BOTH ends of the name.
        return {
            name[:-len(_RAW_SUFFIX)]: value
            for name, value in train_info.items()
            if value is not None and name.endswith(_RAW_SUFFIX)
        }

    @property
    def all_model_names(self):
        """List the names of all recorded models.

        A key counts as a model when it starts with ``'model_'`` but not with
        ``'model_train_info_'``; the name is the key minus the ``'model_'``
        prefix (only the leading prefix is removed, so names that themselves
        contain ``'model_'`` are preserved).
        """
        return [
            key[len(_MODEL_PREFIX):]
            for key in self.model_info
            if key.startswith(_MODEL_PREFIX)
            and not key.startswith(_TRAIN_INFO_PREFIX)
        ]

    def is_model_exist(self, model_name) -> bool:
        """Return whether a model named ``model_name`` has been recorded."""
        return model_name in self.all_model_names

    def is_model_train_data_exist(self, model_name, data_name) -> bool:
        """Return whether training info for ``data_name`` exists for the model.

        Existence is decided by the presence of the ``'<data_name>_max'`` key
        in the model's training info.
        """
        if not self.is_model_exist(model_name):
            return False
        train_info = self.model_info.get(f'{_TRAIN_INFO_PREFIX}{model_name}', {})
        return f'{data_name}_max' in train_info

    def save(self, path):
        """Pickle ``self.model_info`` to ``path`` via ``pandas.to_pickle``."""
        pd.to_pickle(self.model_info, path)

    @classmethod
    def load(cls, path):
        """Rebuild an instance from a file written by :meth:`save`.

        Entries of ``LOAD_INFO`` whose key matches a parameter name of
        ``cls.__init__`` are passed to the constructor; the remaining entries
        are set as attributes on the new instance.
        """
        model_info = pd.read_pickle(path)
        init_param_names = set(inspect.signature(cls.__init__).parameters)

        init_kwargs = {}
        extra_attrs = {}
        for key, value in model_info.get('LOAD_INFO', {}).items():
            target = init_kwargs if key in init_param_names else extra_attrs
            target[key] = value

        model = cls(**init_kwargs)
        model.model_info = model_info
        for attr_name, attr_value in extra_attrs.items():
            setattr(model, attr_name, attr_value)
        return model

    def save_to_platform(
        self,
        version_id: int,
        model_id: str,
        update_method: str,
        model_info: dict,
        MODEL_FILE_PATH: str,
        MODEL_FUNC_PATH: str,
    ) -> None:
        """Save the model file and push per-model info through ``update``.

        The model is first written to ``MODEL_FILE_PATH`` with :meth:`save`.
        For every recorded model that also appears in ``model_info``, the
        stored MAE/MAPE metric and the supplied point/threshold fields are
        collected and handed to ``._update_utils.update``.

        ``model_info`` is expected to look like::

            model_info = {
                NAME: {
                    point_id   : ...,
                    point_name : ...,
                    point_class: ...,
                    thre_mae   : ...,
                    thre_mape  : ...,
                    thre_days  : ...,
                }
            }
        """
        self.save(MODEL_FILE_PATH)

        model_update_info = {}
        for model_name in self.all_model_names:
            if model_name not in model_info:
                continue
            point_info = model_info[model_name]
            metric = self.model_info[f'{_TRAIN_INFO_PREFIX}{model_name}']['metric']
            model_update_info[model_name] = {
                'metric': {
                    'MAE': metric['MAE'],
                    'MAPE': metric['MAPE'],
                },
                'point_id': point_info['point_id'],
                'point_name': point_info['point_name'],
                'point_class': point_info['point_class'],
                'thre_mae': point_info['thre_mae'],
                'thre_mape': point_info['thre_mape'],
                'thre_days': point_info['thre_days'],
            }

        # Imported lazily, as in the original, so the module can be used
        # without the update utilities being importable.
        from ._update_utils import update
        update(
            version_id=version_id,
            model_id=model_id,
            model_info=model_update_info,
            update_method=update_method,
            MODEL_FUNC_PATH=MODEL_FUNC_PATH,
            MODEL_FILE_PATH=MODEL_FILE_PATH
        )

    @classmethod
    def load_from_platform(
        cls,
        reload=False,
        source='file',  # file / id
        model_id=None,
    ):
        """Build an instance from model info provided by the platform.

        Args:
            reload: With ``source='file'``, re-import the model module before
                reading its ``model`` attribute.
            source: ``'file'`` reads the ``model`` attribute of the
                ``models.model_func`` module (resolved relative to this
                package); ``'id'`` fetches ``model.pkl`` for ``model_id`` via
                ``workflow_utils.get_model_version_file`` and unpickles it.
            model_id: Required when ``source='id'``.

        Raises:
            ValueError: If ``source`` is neither ``'file'`` nor ``'id'``.
            Exception: If ``model_id`` is missing for ``source='id'`` or the
                model file cannot be fetched/unpickled.
        """
        # The device-model component uses the file registered in model management.
        if source == 'file':
            MODEL_PACKAGE = importlib.import_module(
                '...models.model_func',
                package='.'.join(__name__.split('.')[:-1]))
            if reload:
                importlib.reload(MODEL_PACKAGE)
            model_info = getattr(MODEL_PACKAGE, 'model')

        # Fetch the model file by the model's id.
        elif source == 'id':
            if model_id is None:
                raise Exception('必须输入模型的id')
            from workflow_utils import get_model_version_file
            try:
                raw = get_model_version_file(model_id=model_id, filename='model.pkl')
                # NOTE(security): pickle.loads executes arbitrary code from the
                # payload; only use with model files from a trusted platform.
                model_info = pickle.loads(raw)
            except Exception as e:
                print(e)
                raise Exception('模型文件获取失败') from e

        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(
                f"unknown source {source!r}; expected 'file' or 'id'"
            )

        load_info = model_info.get('LOAD_INFO', {})
        model = cls(**load_info)
        model.model_info = model_info
        return model

    def metric(self, y_true, y_pred, show=True):
        """Compute R2, MAE and MAPE, ignoring positions where either is NaN.

        Args:
            y_true: True values; must support ``np.isnan`` and boolean-mask
                indexing (e.g. a numpy array).
            y_pred: Predicted values, same requirements as ``y_true``.
            show: When true, print the three metrics.

        Returns:
            Dict with keys ``'R2'``, ``'MAE'`` and ``'MAPE'``.
        """
        valid = ~(np.isnan(y_true) | np.isnan(y_pred))
        y_true = y_true[valid]
        y_pred = y_pred[valid]
        r2 = r2_score(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)
        mape = mean_absolute_percentage_error(y_true, y_pred)
        if show:
            print(f'R2\t: {r2}\nMAE\t: {mae} \nMAPE\t: {mape}')
        return {'R2': r2, 'MAE': mae, 'MAPE': mape}

    def last_metric(self):
        """Compute :meth:`metric` on ``last_fit_y_true`` / ``last_fit_y_pred``."""
        return self.metric(
            y_true=self.last_fit_y_true,
            y_pred=self.last_fit_y_pred,
        )

    def summary(self):
        ...

    def plot_TVP(self):
        """Return a plotnine scatter plot of the last fit's true vs. predicted values.

        A red ``y = x`` reference line is drawn on top of the points.

        Raises:
            ImportError: If plotnine could not be imported.
        """
        _require_plotnine()
        plot = (
            pd.DataFrame(
                {
                    'Real': self.last_fit_y_true.flatten(),
                    'Predict': self.last_fit_y_pred.flatten()
                }
            )
            .pipe(gg.ggplot)
            + gg.aes(x='Real', y='Predict')
            + gg.geom_point()
            + gg.geom_abline(slope=1, intercept=0, color='red')
        )
        return plot


def plot_contour(data, x, y, z, labs=None):
    """Draw a tile plot of ``z`` over ``x``/``y`` with contour lines on top.

    Args:
        data: DataFrame that can be pivoted with ``x`` as index, ``y`` as
            columns and ``z`` as values.
        x: Column name used for the x axis.
        y: Column name used for the y axis.
        z: Column name used for the tile fill and the contour levels.
        labs: Optional keyword arguments forwarded to ``plotnine.labs``.

    Returns:
        The figure returned by ``plot.draw()``, with labelled contour lines
        added to its first axes.

    Raises:
        ImportError: If plotnine could not be imported.
    """
    _require_plotnine()
    data_pivot = (
        data.pivot(index=x, columns=y, values=z)
        .sort_index(axis=1, ascending=False)
        .sort_index(axis=0, ascending=False)
    )
    # Coordinate grids with the same shape as the pivoted value matrix.
    X = np.repeat(data_pivot.index.values.reshape(-1, 1), len(data_pivot.columns), axis=1)
    Y = np.repeat(data_pivot.columns.values.reshape(1, -1), len(data_pivot.index), axis=0)
    Z = data_pivot.values

    labs = {} if labs is None else labs
    plot = (
        data.pipe(gg.ggplot)
        + gg.aes(x=x, y=y, fill=z)
        + gg.geom_tile()
        + gg.coord_cartesian(expand=False)
        + gg.theme(legend_position='none')
        + gg.labs(**labs)
    )
    fig = plot.draw()
    axs = fig.get_axes()
    contour = axs[0].contour(X, Y, Z, levels=10, colors='black', linewidths=1)
    axs[0].clabel(contour, inline=True, fontsize=8)
    return fig