| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- import importlib
- import pickle
- from typing import Union
- from datetime import datetime
- import numpy as np
- import pandas as pd
- from sklearn.metrics import (
- r2_score,
- mean_absolute_error,
- mean_absolute_percentage_error
- )
# plotnine is an optional dependency that is only needed by the plotting
# helpers, so a failed import must not prevent this module from loading.
try:
    import plotnine as gg
except Exception:
    # Deliberately best-effort: any ordinary import-time failure is tolerated.
    # ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt and
    # SystemExit propagate instead of being silently swallowed.
    pass
class BaseModel:
    """Registry of fitted models plus bookkeeping about how they were trained.

    All state that gets persisted lives in ``self.model_info``, a plain dict:

    * ``'LOAD_INFO'`` - keyword arguments recorded by ``record_load_info``;
      ``load`` / ``load_from_platform`` pass them back to the constructor.
    * ``'model_<name>'`` - the model object stored by ``record_model``.
    * ``'model_train_info_<name>'`` - training timestamp, metrics and the
      max/min (optionally the raw values) of every training series.
    """

    _MODEL_PREFIX = 'model_'
    _TRAIN_INFO_PREFIX = 'model_train_info_'
    _RAW_SUFFIX = '_raw'

    def __init__(self) -> None:
        # Targets / predictions of the most recent fit. This class never
        # assigns them itself; presumably subclasses do so when fitting.
        self.last_fit_y_true: Union[np.ndarray, None] = None
        self.last_fit_y_pred: Union[np.ndarray, None] = None
        self.model_info = {'LOAD_INFO': {}}

    @classmethod
    def _train_info_key(cls, model_name) -> str:
        """Return the ``model_info`` key holding training info for a model."""
        return f'{cls._TRAIN_INFO_PREFIX}{model_name}'

    @classmethod
    def _from_model_info(cls, model_info):
        """Build an instance from a previously saved ``model_info`` dict."""
        load_info = model_info.get('LOAD_INFO', {})
        model = cls(**load_info)
        model.model_info = model_info
        return model

    def _require_last_fit(self):
        """Return ``(y_true, y_pred)`` of the last fit.

        Raises:
            ValueError: if no fit results have been stored yet.
        """
        if self.last_fit_y_true is None or self.last_fit_y_pred is None:
            raise ValueError(
                'no fit results available: last_fit_y_true / last_fit_y_pred are not set'
            )
        return self.last_fit_y_true, self.last_fit_y_pred

    def record_load_info(self, **info) -> None:
        """Set every keyword as an instance attribute and remember it.

        The remembered values are stored under ``model_info['LOAD_INFO']`` so
        that ``load`` can hand them back to the constructor.
        """
        for attr_name, attr_value in info.items():
            setattr(self, attr_name, attr_value)
        # setdefault: a model_info restored from a file may lack 'LOAD_INFO'.
        self.model_info.setdefault('LOAD_INFO', {}).update(info)

    def record_model(
        self,
        model_name: str,
        model: dict,
        train_data: dict,
        train_metric: dict,
        keep_raw: bool = False,
    ) -> None:
        """Store a model together with a summary of its training data.

        Args:
            model_name: name the model is registered under.
            model: the model object to store.
            train_data: mapping ``series name -> values``; entries whose value
                is ``None`` are ignored.
            train_metric: metrics stored as-is under ``'metric'``.
            keep_raw: when true, the raw values are stored as well (under
                ``'<series name>_raw'``) in addition to their max/min.
        """
        train_info = {
            'datetime': datetime.now().strftime('%Y/%m/%d %H:%M:%S'),
            'metric': train_metric,
        }
        for name, value in train_data.items():
            if value is None:
                continue
            if keep_raw:
                train_info[f'{name}{self._RAW_SUFFIX}'] = value
            train_info[f'{name}_max'] = round(np.max(value), 2)
            train_info[f'{name}_min'] = round(np.min(value), 2)

        self.model_info[f'{self._MODEL_PREFIX}{model_name}'] = model
        self.model_info[self._train_info_key(model_name)] = train_info

    def _get_train_data(self, model_name) -> dict:
        """Return the raw training series kept for ``model_name``.

        The result maps ``series name -> raw values`` and is empty when the
        model is unknown or was recorded without ``keep_raw``.
        """
        if not self.is_model_exist(model_name):
            return {}
        train_info = self.model_info.get(self._train_info_key(model_name), {})
        suffix = self._RAW_SUFFIX
        # Slice the suffix off explicitly: str.strip('_raw') would remove any
        # of the characters '_', 'r', 'a', 'w' from BOTH ends of the name.
        return {
            name[:-len(suffix)]: value
            for name, value in train_info.items()
            if value is not None and name.endswith(suffix)
        }

    @property
    def all_model_names(self):
        """Names of all recorded models, in insertion order."""
        names = []
        for key in self.model_info:
            if key.startswith(self._TRAIN_INFO_PREFIX):
                continue
            if not key.startswith(self._MODEL_PREFIX):
                continue
            # Remove only the leading prefix so that names which themselves
            # contain 'model_' are returned unchanged.
            names.append(key[len(self._MODEL_PREFIX):])
        return names

    def is_model_exist(self, model_name) -> bool:
        """Return whether a model with this name has been recorded."""
        return model_name in self.all_model_names

    def is_model_train_data_exist(self, model_name, data_name) -> bool:
        """Return whether training series ``data_name`` was recorded for the model."""
        if not self.is_model_exist(model_name):
            return False
        train_info = self.model_info.get(self._train_info_key(model_name), {})
        return f'{data_name}_max' in train_info

    def save(self, path):
        """Pickle ``model_info`` to ``path``."""
        pd.to_pickle(self.model_info, path)

    @classmethod
    def load(cls, path):
        """Rebuild an instance from a file written by ``save``.

        The constructor is called with the keyword arguments stored under
        ``'LOAD_INFO'``.
        """
        # NOTE(security): this unpickles the file, which can execute arbitrary
        # code. Only load files from a trusted source.
        model_info = pd.read_pickle(path)
        return cls._from_model_info(model_info)

    def save_to_platform(
        self,
        version_id: int,
        model_id: str,
        update_method: str,
        model_info: dict,
        MODEL_FILE_PATH: str,
        MODEL_FUNC_PATH: str,
    ) -> None:
        """Save the model file, then pass per-model metadata to ``update``.

        Only models that are both recorded here and present in ``model_info``
        are included. Expected shape of ``model_info``::

            model_info = {
                NAME: {
                    point_id   : ...,
                    point_name : ...,
                    point_class: ...,
                    thre_mae   : ...,
                    thre_mape  : ...,
                    thre_days  : ...,
                }
            }

        Raises:
            KeyError: if a selected entry lacks one of the fields above, or
                the recorded training metrics lack ``'MAE'`` / ``'MAPE'``.
        """
        self.save(MODEL_FILE_PATH)

        point_fields = (
            'point_id', 'point_name', 'point_class',
            'thre_mae', 'thre_mape', 'thre_days',
        )
        model_update_info = {}
        for model_name in self.all_model_names:
            if model_name not in model_info:
                continue
            train_metric = self.model_info[self._train_info_key(model_name)]['metric']
            entry = {
                'metric': {
                    'MAE': train_metric['MAE'],
                    'MAPE': train_metric['MAPE'],
                },
            }
            point_info = model_info[model_name]
            for field in point_fields:
                entry[field] = point_info[field]
            model_update_info[model_name] = entry

        # Imported lazily, after the file has been written, as in the
        # original flow.
        from ._update_utils import update
        update(
            version_id=version_id,
            model_id=model_id,
            model_info=model_update_info,
            update_method=update_method,
            MODEL_FUNC_PATH=MODEL_FUNC_PATH,
            MODEL_FILE_PATH=MODEL_FILE_PATH,
        )

    @classmethod
    def load_from_platform(
        cls,
        reload=False,
        source='file',  # file / id
        model_id=None,
    ):
        """Rebuild an instance from model data provided by the platform.

        Args:
            reload: with ``source='file'``, re-import the model module before
                reading it.
            source: ``'file'`` reads the ``model`` attribute of the
                ``models.model_func`` module (resolved relative to this
                package); ``'id'`` downloads ``model.pkl`` for ``model_id``.
            model_id: required when ``source='id'``.

        Raises:
            ValueError: if ``source`` is neither ``'file'`` nor ``'id'``.
            Exception: if ``model_id`` is missing or the download/unpickling
                fails (the underlying error is chained as the cause).
        """
        if source == 'file':
            # The device-model component uses the file kept in model management.
            model_package = importlib.import_module(
                '...models.model_func',
                package='.'.join(__name__.split('.')[:-1]),
            )
            if reload:
                importlib.reload(model_package)
            model_info = getattr(model_package, 'model')

        elif source == 'id':
            # Fetch the model file through the model's id.
            if model_id is None:
                raise Exception('必须输入模型的id')
            from workflow_utils import get_model_version_file
            try:
                raw_bytes = get_model_version_file(model_id=model_id, filename='model.pkl')
                # NOTE(security): pickle.loads executes arbitrary code from
                # the payload; the platform storage must be trusted.
                model_info = pickle.loads(raw_bytes)
            except Exception as e:
                print(e)
                raise Exception('模型文件获取失败') from e

        else:
            # Previously an unknown source fell through to an unbound local.
            raise ValueError(f"source must be 'file' or 'id', got {source!r}")

        return cls._from_model_info(model_info)

    def metric(self, y_true, y_pred, show=True):
        """Compute R2, MAE and MAPE, ignoring positions where either value is NaN.

        Inputs are converted to flat float arrays, so lists as well as
        ``(n,)`` / ``(n, 1)`` arrays are accepted.

        Returns:
            dict with keys ``'R2'``, ``'MAE'`` and ``'MAPE'``.
        """
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_pred = np.asarray(y_pred, dtype=float).ravel()
        valid = ~(np.isnan(y_true) | np.isnan(y_pred))
        y_true = y_true[valid]
        y_pred = y_pred[valid]

        r2 = r2_score(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)
        mape = mean_absolute_percentage_error(y_true, y_pred)
        if show:
            print(f'R2\t: {r2}\nMAE\t: {mae} \nMAPE\t: {mape}')
        return {'R2': r2, 'MAE': mae, 'MAPE': mape}

    def last_metric(self):
        """Return ``metric`` evaluated on the results of the last fit.

        Raises:
            ValueError: if no fit results have been stored yet.
        """
        y_true, y_pred = self._require_last_fit()
        return self.metric(y_true=y_true, y_pred=y_pred)

    def summary(self):
        """Placeholder with no behaviour in this class."""
        ...

    def plot_TVP(self):
        """Return a true-vs-predicted scatter plot of the last fit.

        A red ``y = x`` reference line is drawn on top of the points.

        Raises:
            ValueError: if no fit results have been stored yet.
            ImportError: if plotnine is not installed.
        """
        y_true, y_pred = self._require_last_fit()
        # The module-level plotnine import is best-effort; importing here
        # turns a missing dependency into a clear ImportError.
        import plotnine as gg

        frame = pd.DataFrame(
            {
                'Real': np.asarray(y_true).flatten(),
                'Predict': np.asarray(y_pred).flatten(),
            }
        )
        plot = (
            gg.ggplot(frame)
            + gg.aes(x='Real', y='Predict')
            + gg.geom_point()
            + gg.geom_abline(slope=1, intercept=0, color='red')
        )
        return plot
def plot_contour(data, x, y, z, labs=None):
    """Draw a tile heat map of ``z`` over ``x`` / ``y`` with contour lines on top.

    Args:
        data: long-format DataFrame with one row per ``(x, y)`` pair; it is
            pivoted on those columns, so the pairs must be unique.
        x: name of the column used for the horizontal axis.
        y: name of the column used for the vertical axis.
        z: name of the column holding the values (tile fill and contour height).
        labs: optional dict of keyword arguments forwarded to ``plotnine.labs``.

    Returns:
        The figure returned by drawing the plot, with 10 labelled black
        contour lines added to its first axes.

    Raises:
        ImportError: if plotnine is not installed.
    """
    # The module-level plotnine import is best-effort; importing here turns a
    # missing dependency into a clear ImportError instead of a NameError.
    import plotnine as gg

    grid = (
        data.pivot(index=x, columns=y, values=z)
        .sort_index(axis=1, ascending=False)
        .sort_index(axis=0, ascending=False)
    )
    # 'ij' indexing keeps rows aligned with the pivot index (x) and columns
    # aligned with the pivot columns (y), matching the layout of grid.values.
    X, Y = np.meshgrid(grid.index.values, grid.columns.values, indexing='ij')
    Z = grid.values

    label_kwargs = {} if labs is None else labs
    plot = (
        gg.ggplot(data)
        + gg.aes(x=x, y=y, fill=z)
        + gg.geom_tile()
        + gg.coord_cartesian(expand=False)
        + gg.theme(legend_position='none')
        + gg.labs(**label_kwargs)
    )
    fig = plot.draw()
    first_ax = fig.get_axes()[0]
    contour = first_ax.contour(X, Y, Z, levels=10, colors='black', linewidths=1)
    first_ax.clabel(contour, inline=True, fontsize=8)
    return fig
|