_base_device.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import os
  2. from typing import Union
  3. import numpy as np
  4. import pandas as pd
  5. from sklearn.metrics import (
  6. r2_score,
  7. mean_absolute_error,
  8. mean_absolute_percentage_error
  9. )
  10. try:
  11. import plotnine as gg
  12. except:
  13. pass
  14. from ._base import BaseModel
  class BaseDevice(BaseModel):
      """Common prediction, validation and plotting helpers for device models.

      The methods below read ``self.model``, ``self.model_info``,
      ``self.model_input_data_columns`` and ``self.components``; where those
      are populated (``BaseModel`` or a subclass) is not visible here.
      Subclasses that want ``find_F_air_val_rw`` must implement the
      ``F_air_val_rw`` property / ``set_F_air_val_rw`` pair and set
      ``val_rw_adj_target``.
      """

      # Column of the ``predict_system`` output (and of the observed data) that
      # ``find_F_air_val_rw`` fits against; ``None`` means "not configured".
      val_rw_adj_target = None

      def __init__(self) -> None:
          super().__init__()

      def predict(self, input_data: pd.DataFrame) -> dict:
          """Run ``self.model`` on ``input_data`` with the stored posterior parameters.

          Each model keyword argument is filled with the values of the
          ``input_data`` column mapped to it in ``self.model_input_data_columns``;
          the model is called with ``engine='numpy'``, ``self.components`` and
          ``self.model_info['model_ATD']`` as ``param``.

          Returns:
              Whatever ``self.model`` returns. ``predict_system`` consumes it as a
              nested mapping ``{equipment_name: {output_name: values}}``.
          """
          param_posterior = self.model_info['model_ATD']
          model_kwargs = {
              arg_name: input_data.loc[:, column].values
              for arg_name, column in self.model_input_data_columns.items()
          }
          return self.model(
              **model_kwargs,
              engine='numpy',
              components=self.components,
              param=param_posterior,
          )

      def predict_system(self, input_data: pd.DataFrame) -> pd.DataFrame:
          """Flatten ``predict`` output into one DataFrame column per equipment output.

          Columns are named ``f'{equipment_name}_{output_name}'``; if two pairs
          produce the same name, the later one wins.
          """
          pred_res = self.predict(input_data)
          return pd.DataFrame(
              {
                  f'{equp_name}_{output_name}': output_value
                  for equp_name, output_info in pred_res.items()
                  for output_name, output_value in output_info.items()
              }
          )

      def get_TVP(self, posterior: dict, observed_data: pd.DataFrame) -> pd.DataFrame:
          """Build a long-format "true vs. predicted" table.

          For every ``posterior`` entry whose name, after stripping a trailing
          ``'_mu'``, is a column of ``observed_data``, emit rows with columns
          ``idx`` (observed index), ``param_name``, ``real`` and ``pred``.
          NOTE(review): each posterior value is assumed to be aligned
          row-for-row with ``observed_data`` -- confirm with the caller.

          Returns:
              The concatenated table, or an empty frame with the same columns
              when no posterior entry matches an observed column.
          """
          frames = []
          for param_name, pred_values in posterior.items():
              # Strip only a *trailing* '_mu'; str.replace would also mangle names
              # that merely contain '_mu' (e.g. 'F_mult_mu' -> 'Flt').
              obs_name = param_name.removesuffix('_mu')
              if obs_name not in observed_data.columns:
                  continue
              frames.append(
                  pd.DataFrame(
                      {
                          'idx': observed_data.index,
                          'param_name': obs_name,
                          'real': observed_data.loc[:, obs_name].values,
                          'pred': pred_values,
                      }
                  )
              )
          if not frames:
              # pd.concat([]) raises; an empty table is easier for callers to handle.
              return pd.DataFrame(columns=['idx', 'param_name', 'real', 'pred'])
          return pd.concat(frames, axis=0)

      def get_metric(self, TVP: pd.DataFrame) -> pd.DataFrame:
          """Score each parameter of a ``get_TVP`` table.

          Returns:
              DataFrame indexed by ``param_name`` with columns ``R2``, ``MAE`` and
              ``MAPE``, sorted by ``R2`` ascending so the worst fits come first.
              An empty input yields an empty frame with those columns.
          """
          if TVP.empty:
              return pd.DataFrame(columns=['R2', 'MAE', 'MAPE'])

          def _scores(group: pd.DataFrame) -> pd.Series:
              # One pass per group instead of three separate grouped applies.
              return pd.Series(
                  {
                      'R2': r2_score(group.real, group.pred),
                      'MAE': mean_absolute_error(group.real, group.pred),
                      'MAPE': mean_absolute_percentage_error(group.real, group.pred),
                  }
              )

          return (
              TVP.groupby(['param_name'])[['pred', 'real']]
              .apply(_scores)
              .sort_values(by='R2', ascending=True)
          )

      def plot_TVP(self, TVP, save_path=None):
          """Scatter ``real`` vs ``pred`` per parameter with a red y = x reference line.

          Args:
              TVP: table as produced by ``get_TVP``.
              save_path: if given, the figure is also written there.

          Returns:
              The plotnine ``ggplot`` object.
          """
          # plotnine is optional at module level; importing here gives a clear
          # ImportError instead of a NameError on an undefined ``gg``.
          import plotnine as gg

          fig = (
              TVP.pipe(gg.ggplot)
              + gg.aes(x='real', y='pred')
              + gg.geom_point()
              + gg.facet_wrap(facets='param_name', scales='free')
              + gg.geom_abline(intercept=0, slope=1, color='red')
              + gg.theme(figure_size=[10, 10])
          )
          if save_path is not None:
              fig.save(filename=save_path)
          return fig

      def plot_check(self, cur_input_data: pd.DataFrame) -> dict:
          """Hook for device-specific diagnostic plots; the base class returns ``{}``."""
          return {}

      def curve(
          self,
          input_data: pd.DataFrame,
          x: str,
          y: str,
          color: str = None,
          facte_x: str = None,
          facte_y: str = None,
          space_x: np.ndarray = None,
          space_color: np.ndarray = None,
          diff_y: bool = False,
      ):
          """Plot how prediction ``y`` responds to input ``x``.

          ``x`` is swept over ``space_x`` (default: 100 evenly spaced points
          between its observed min and max). ``color`` takes ``space_color``
          (default: its 25/50/75 % quantiles); ``facte_x``/``facte_y`` (facet
          rows/columns) take their 25/50/75 % quantiles. Every other input column
          is held at its median.

          Args:
              diff_y: plot the step-to-step difference of ``y`` along the ``x``
                  grid (computed within each color/facet group) instead of ``y``.

          Returns:
              The plotnine ``ggplot`` object.

          Raises:
              ValueError: if ``x``/``color``/``facte_x``/``facte_y`` is not an
                  ``input_data`` column, or ``y`` is not a prediction column.
          """
          import plotnine as gg

          if x not in input_data.columns:
              raise ValueError(f'{x} is not in input_data')
          if space_x is None:
              x_grid = np.linspace(input_data.loc[:, x].min(), input_data.loc[:, x].max(), 100)
          else:
              x_grid = space_x
          axes = [x_grid]
          names = [x]
          groupby_key = []
          # Only ``color`` accepts explicit values; facets always use quartiles.
          for col, values in ((color, space_color), (facte_x, None), (facte_y, None)):
              if col is None:
                  continue
              if col not in input_data.columns:
                  raise ValueError(f'{col} is not in input_data')
              if values is None:
                  values = np.quantile(input_data.loc[:, col], q=[0.25, 0.5, 0.75])
              axes.append(values)
              names.append(col)
              groupby_key.append(col)

          # Cartesian grid of the swept columns; x is the outermost level, so rows
          # of each color/facet group stay in x-grid order.
          curve_input = pd.MultiIndex.from_product(axes, names=names).to_frame(index=False)
          curve_input_all = curve_input.copy(deep=True)
          for col in input_data.columns:
              if col not in curve_input_all:
                  curve_input_all[col] = input_data.loc[:, col].median()

          pred_data = self.predict_system(curve_input_all)
          if y not in pred_data.columns:
              raise ValueError(f'{y} is not in Prediction')
          curve_data = pd.concat([curve_input, pred_data], axis=1)

          if diff_y:
              # groupby(...)[y].diff() keeps the original row index and all other
              # columns (notably x); the former groupby.apply + reset_index dropped x.
              if groupby_key:
                  curve_data[y] = curve_data.groupby(groupby_key)[y].diff()
              else:
                  curve_data[y] = curve_data[y].diff()
              # Remove only the first row of each group, whose difference is NaN.
              curve_data = curve_data.dropna(subset=[y])

          fig = curve_data.pipe(gg.ggplot) + gg.aes(x=x, y=y) + gg.geom_line()
          if color is not None:
              fig += gg.aes(color=f'factor({color})', group=color)
              fig += gg.labs(color=color)
          if facte_x is not None or facte_y is not None:
              fig += gg.facet_grid(rows=facte_x, cols=facte_y, labeller='label_both')
          if diff_y:
              fig += gg.geom_hline(yintercept=0, linetype='--')
          return fig

      @property
      def F_air_val_rw(self):
          """Current value of the parameter tuned by ``find_F_air_val_rw``; subclasses implement."""
          raise NotImplementedError

      def set_F_air_val_rw(self, value: float):
          """Set the parameter tuned by ``find_F_air_val_rw``; subclasses implement."""
          raise NotImplementedError

      def find_F_air_val_rw(
          self,
          input_data: pd.DataFrame,
          observed_data: pd.DataFrame,
          plot: bool = False,
          rw_value_range: Union[None, tuple] = None,
          n_points: int = 500,
      ):
          """Grid-search ``F_air_val_rw`` to minimise the MAE of ``val_rw_adj_target``.

          Candidates are ``n_points`` evenly spaced values over ``rw_value_range``
          (``(low, high)``), defaulting to ``[1e-6, 2 * current value]``. The
          current value is always restored afterwards, even if a prediction
          raises. A one-line summary is printed; ``plot=True`` also shows the MAE
          curve with the current and best values marked.

          Returns:
              The candidate value with the lowest MAE.

          Raises:
              NotImplementedError: if ``val_rw_adj_target`` is not set.
              ValueError: if ``n_points`` is smaller than 1.
          """
          if self.val_rw_adj_target is None:
              # Message: "please set val_rw_adj_target first".
              raise NotImplementedError('请先设置val_rw_adj_target')
          if n_points < 1:
              raise ValueError(f'n_points must be >= 1, got {n_points}')
          target = self.val_rw_adj_target
          raw_F_air_val_rw = self.F_air_val_rw
          if rw_value_range is None:
              candidates = np.linspace(1e-6, raw_F_air_val_rw * 2, n_points)
          else:
              candidates = np.linspace(rw_value_range[0], rw_value_range[1], n_points)
          # The observed target does not depend on the candidate; read it once.
          real = observed_data.loc[:, target].values.flatten()

          def _mae_at(value):
              # Mutates model state; the caller restores it in ``finally``.
              self.set_F_air_val_rw(value)
              pred = self.predict_system(input_data).loc[:, target].values.flatten()
              return mean_absolute_error(real, pred)

          try:
              mae = [_mae_at(value) for value in candidates]
              # Evaluate the current value exactly: it need not lie on the grid
              # (e.g. when a custom rw_value_range excludes it).
              raw_rw_mae = _mae_at(raw_F_air_val_rw)
          finally:
              self.set_F_air_val_rw(raw_F_air_val_rw)

          best_idx = int(np.argmin(mae))
          best_rw_value = candidates[best_idx]
          best_rw_mae = mae[best_idx]
          print(f'val_rw:{raw_F_air_val_rw:.2f} -> {best_rw_value:.2f}, MAE:{raw_rw_mae:.2f} -> {best_rw_mae:.2f}')

          if plot:
              import plotnine as gg

              fig = (
                  pd.DataFrame({'val_rw': candidates, 'MAE': mae}).pipe(gg.ggplot)
                  + gg.aes(x='val_rw', y='MAE')
                  + gg.geom_line()
                  + gg.geom_point(gg.aes(x=raw_F_air_val_rw, y=raw_rw_mae, color='"Current"'), size=5)
                  + gg.geom_point(gg.aes(x=best_rw_value, y=best_rw_mae, color='"Best"'), size=5)
                  + gg.labs(color='')
              )
              fig.show()
          return best_rw_value