| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- import os
- from typing import Union
- import numpy as np
- import pandas as pd
- from sklearn.metrics import (
- r2_score,
- mean_absolute_error,
- mean_absolute_percentage_error
- )
- try:
- import plotnine as gg
- except:
- pass
- from ._base import BaseModel
class BaseDevice(BaseModel):
    """Shared prediction, evaluation and plotting helpers for device models.

    This class reads the following attributes, which are expected to be
    provided by ``BaseModel`` or a subclass (not visible here — confirm):
    ``model`` (callable), ``model_info`` (mapping holding ``'model_ATD'``),
    ``model_input_data_columns`` (mapping of model kwarg -> column
    selector) and ``components``.

    Subclasses that support valve-resistance calibration must implement
    ``F_air_val_rw`` / ``set_F_air_val_rw`` and set ``val_rw_adj_target``.
    """

    # Name of the ``predict_system`` output column (and matching
    # ``observed_data`` column) that ``find_F_air_val_rw`` fits against.
    # Must be set by the subclass before calibration is used.
    val_rw_adj_target = None

    def __init__(self) -> None:
        super().__init__()

    def predict(self, input_data: pd.DataFrame) -> dict:
        """Run ``self.model`` on ``input_data`` with the stored posterior.

        Each model keyword argument ``k`` receives
        ``input_data.loc[:, self.model_input_data_columns[k]].values``.

        Args:
            input_data: Frame containing every column referenced by
                ``self.model_input_data_columns``.

        Returns:
            The raw result of ``self.model``; ``predict_system`` treats it as
            a nested mapping ``{equipment_name: {output_name: values}}``.
        """
        param_posterior = self.model_info['model_ATD']
        return self.model(
            **{
                k: input_data.loc[:, v].values
                for k, v in self.model_input_data_columns.items()
            },
            engine='numpy',
            components=self.components,
            param=param_posterior,
        )

    def predict_system(self, input_data: pd.DataFrame) -> pd.DataFrame:
        """Predict and flatten the per-equipment outputs into one frame.

        Returns:
            DataFrame whose columns are named
            ``f'{equipment_name}_{output_name}'``.
        """
        pred_res = self.predict(input_data)
        system_output = {
            f'{equp_name}_{output_name}': output_value
            for equp_name, output_info in pred_res.items()
            for output_name, output_value in output_info.items()
        }
        return pd.DataFrame(system_output)

    @staticmethod
    def _strip_mu_suffix(param_name: str) -> str:
        """Map a posterior key like ``'T_mu'`` to its observed column ``'T'``.

        Only a trailing ``'_mu'`` is removed, so names that merely contain
        ``'_mu'`` (e.g. ``'Q_mult'``) are left intact.
        """
        suffix = '_mu'
        if param_name.endswith(suffix):
            return param_name[:-len(suffix)]
        return param_name

    def get_TVP(self, posterior: dict, observed_data: pd.DataFrame) -> pd.DataFrame:
        """Build a long "true vs. predicted" table.

        For every posterior key whose name (with a trailing ``'_mu'``
        removed) is a column of ``observed_data``, emit one row per
        observation.

        Args:
            posterior: Mapping of parameter name to predicted values; each
                value must be broadcastable against ``observed_data.index``.
            observed_data: Frame of measured values, one column per parameter.

        Returns:
            Long DataFrame with columns ``idx``, ``param_name``, ``real``
            and ``pred``.

        Raises:
            ValueError: If no posterior key matches a column of
                ``observed_data``.
        """
        TVP_data = []
        for param_name, pred_value in posterior.items():
            observed_name = self._strip_mu_suffix(param_name)
            if observed_name not in observed_data.columns:
                continue
            TVP_data.append(
                pd.DataFrame(
                    {
                        'idx': observed_data.index,
                        'param_name': observed_name,
                        'real': observed_data.loc[:, observed_name].values,
                        'pred': pred_value,
                    }
                )
            )
        if not TVP_data:
            # pd.concat([]) would raise an opaque "No objects to concatenate".
            raise ValueError(
                'No posterior entry matches a column of observed_data'
            )
        return pd.concat(TVP_data, axis=0)

    def get_metric(self, TVP: pd.DataFrame) -> pd.DataFrame:
        """Compute per-parameter R2, MAE and MAPE from a ``get_TVP`` table.

        Returns:
            DataFrame indexed by ``param_name`` with columns ``R2``, ``MAE``
            and ``MAPE``, sorted by ascending R2 (worst fit first).
        """
        group_by_data = TVP.groupby(['param_name'])[['pred', 'real']]
        TVP_metric = (
            pd.concat(
                [
                    group_by_data.apply(lambda dt: r2_score(dt.real, dt.pred)),
                    group_by_data.apply(lambda dt: mean_absolute_error(dt.real, dt.pred)),
                    group_by_data.apply(lambda dt: mean_absolute_percentage_error(dt.real, dt.pred)),
                ],
                axis=1,
            )
            .set_axis(['R2', 'MAE', 'MAPE'], axis=1)
            .sort_values(by='R2', ascending=True)
        )
        return TVP_metric

    def plot_TVP(self, TVP, save_path=None):
        """Scatter predicted vs. real values, one free-scaled facet per parameter.

        A red ``y = x`` line marks perfect prediction.

        Args:
            TVP: Table as returned by ``get_TVP``.
            save_path: If given, the figure is also saved to this path.

        Returns:
            The plotnine ``ggplot`` object.
        """
        # Local import: the module-level import is optional and may have
        # failed silently, which would otherwise surface as a NameError.
        import plotnine as gg

        plot = (
            TVP
            .pipe(gg.ggplot)
            + gg.aes(x='real', y='pred')
            + gg.geom_point()
            + gg.facet_wrap(facets='param_name', scales='free')
            + gg.geom_abline(intercept=0, slope=1, color='red')
            + gg.theme(figure_size=[10, 10])
        )
        if save_path is not None:
            plot.save(filename=save_path)

        return plot

    def curve(
        self,
        input_data: pd.DataFrame,
        x: str,
        y: str,
        color: str = None,
        grid_x: str = None,
        grid_y: str = None,
    ):
        """Plot a response curve of prediction ``y`` against input ``x``.

        ``x`` is swept over 100 evenly spaced points between its min and max
        in ``input_data``. ``color``, ``grid_x`` and ``grid_y`` (if given)
        each take their 25/50/75 % quantiles, and the curve is drawn for
        every combination. All other input columns are held at their median.

        Args:
            input_data: Frame of model inputs used to derive the sweep ranges.
            x: Input column swept along the horizontal axis.
            y: Column of ``predict_system`` output plotted vertically.
            color: Optional input column mapped to line colour.
            grid_x: Optional input column mapped to facet rows.
            grid_y: Optional input column mapped to facet columns.

        Returns:
            The plotnine ``ggplot`` object.

        Raises:
            ValueError: If ``x``/``color``/``grid_x``/``grid_y`` is not a
                column of ``input_data``, or ``y`` is not a prediction column.
        """
        if x not in input_data.columns:
            raise ValueError(f'{x} is not in input_data')

        product = [np.linspace(input_data.loc[:, x].min(), input_data.loc[:, x].max(), 100)]
        names = [x]

        # Secondary variables are evaluated at their quartiles only, to keep
        # the number of curves small.
        for factor_col in (color, grid_x, grid_y):
            if factor_col is None:
                continue
            if factor_col not in input_data.columns:
                raise ValueError(f'{factor_col} is not in input_data')
            product.append(np.quantile(input_data.loc[:, factor_col], q=[0.25, 0.5, 0.75]))
            names.append(factor_col)

        curve_input = (
            pd.MultiIndex.from_product(product, names=names)
            .to_frame(index=False)
        )

        # The model needs every input column; hold the unswept ones fixed.
        curve_input_all = curve_input.copy(deep=True)
        for col in input_data.columns:
            if col not in curve_input_all:
                curve_input_all[col] = input_data.loc[:, col].median()

        pred_data = self.predict_system(curve_input_all)
        if y not in pred_data.columns:
            raise ValueError(f'{y} is not in Prediction')

        curve_data = pd.concat([curve_input, pred_data], axis=1)

        # Local import for a clear ImportError if plotnine is unavailable.
        import plotnine as gg

        plot = (
            curve_data
            .round(2)
            .pipe(gg.ggplot)
            + gg.aes(x=x, y=y)
            + gg.geom_line()
        )
        if color is not None:
            plot += gg.aes(color=f'factor({color})', group=color)
            plot += gg.labs(color=color)

        if grid_x is not None or grid_y is not None:
            plot += gg.facet_grid(rows=grid_x, cols=grid_y, labeller='label_both')

        return plot

    @property
    def F_air_val_rw(self):
        """Current valve-resistance parameter; must be implemented by subclasses."""
        raise NotImplementedError

    def set_F_air_val_rw(self, value: float):
        """Set the valve-resistance parameter; must be implemented by subclasses."""
        raise NotImplementedError

    def find_F_air_val_rw(
        self,
        input_data: pd.DataFrame,
        observed_data: pd.DataFrame,
        plot: bool = False,
        rw_value_range: Union[None, tuple] = None,
    ):
        """Grid-search ``F_air_val_rw`` minimising MAE on ``val_rw_adj_target``.

        500 evenly spaced candidates are evaluated, by default over
        ``[1e-6, 2 * current value]``. The model's original value is always
        restored afterwards (even if prediction raises), so this method only
        reports the best value and does not apply it.

        Args:
            input_data: Model inputs passed to ``predict_system``.
            observed_data: Measured data containing ``val_rw_adj_target``.
            plot: If True, show the MAE curve with current/best points.
            rw_value_range: Optional ``(low, high)`` search bounds.

        Returns:
            The candidate value with the lowest MAE.

        Raises:
            NotImplementedError: If ``val_rw_adj_target`` has not been set.
        """
        if self.val_rw_adj_target is None:
            # Message: "please set val_rw_adj_target first".
            raise NotImplementedError('请先设置val_rw_adj_target')

        raw_F_air_val_rw = self.F_air_val_rw
        if rw_value_range is None:
            low, high = 1e-6, raw_F_air_val_rw * 2
        else:
            low, high = rw_value_range[0], rw_value_range[1]
        rw_value_range = np.linspace(low, high, 500)

        # Observations do not depend on the candidate value; read them once,
        # before the model state is touched.
        real = observed_data.loc[:, self.val_rw_adj_target].values.flatten()
        mae = []
        try:
            for rw_value in rw_value_range:
                self.set_F_air_val_rw(rw_value)
                pred = self.predict_system(input_data).loc[:, self.val_rw_adj_target].values.flatten()
                mae.append(mean_absolute_error(real, pred))
        finally:
            # Never leave the model at a trial value, even on failure.
            self.set_F_air_val_rw(raw_F_air_val_rw)

        best_idx = int(np.argmin(mae))
        best_rw_value = rw_value_range[best_idx]
        best_rw_mae = mae[best_idx]
        # NOTE: the "current" MAE is taken at the grid point nearest to the
        # current value; it is only exact if the range brackets that value.
        raw_rw_mae = mae[int(np.argmin(np.abs(raw_F_air_val_rw - rw_value_range)))]
        print(f'val_rw:{raw_F_air_val_rw:.2f} -> {best_rw_value:.2f}, MAE:{raw_rw_mae:.2f} -> {best_rw_mae:.2f}')

        if plot:
            import plotnine as gg
            fig = (
                pd.DataFrame(
                    {
                        'val_rw': rw_value_range,
                        'MAE': mae,
                    }
                )
                .pipe(gg.ggplot)
                + gg.aes(x='val_rw', y='MAE')
                + gg.geom_line()
                + gg.geom_point(gg.aes(x=raw_F_air_val_rw, y=raw_rw_mae, color='"Current"'), size=5)
                + gg.geom_point(gg.aes(x=best_rw_value, y=best_rw_mae, color='"Best"'), size=5)
                + gg.labs(color='')
            )
            fig.show()
        return best_rw_value
|