# _base_device.py (5.1 KB)
  1. import numpy as np
  2. import pandas as pd
  3. from sklearn.metrics import (
  4. r2_score,
  5. mean_absolute_error,
  6. mean_absolute_percentage_error
  7. )
  8. try:
  9. import plotnine as gg
  10. except:
  11. pass
  12. from ._base import BaseModel
  13. class BaseDevice(BaseModel):
  14. def __init__(self) -> None:
  15. super().__init__()
  16. def predict(self,input_data:pd.DataFrame) -> dict:
  17. param_posterior = self.model_info['model_ATD']
  18. res = self.model(
  19. **{k:input_data.loc[:,v].values for k,v in self.model_input_data_columns.items()},
  20. engine = 'numpy',
  21. components = self.components,
  22. param = param_posterior
  23. )
  24. return res
  25. def predict_system(self,input_data : pd.DataFrame) -> pd.DataFrame:
  26. pred_res = self.predict(input_data)
  27. system_output = {}
  28. for equp_name,output_info in pred_res.items():
  29. for output_name,output_value in output_info.items():
  30. system_output[f'{equp_name}_{output_name}'] = output_value
  31. system_output = pd.DataFrame(system_output)
  32. return system_output
  33. def get_TVP(self,posterior:dict,observed_data:pd.DataFrame):
  34. TVP_data = []
  35. for param_name in posterior.keys():
  36. if param_name.replace('_mu','') not in observed_data.columns:
  37. continue
  38. TVP_data.append(
  39. pd.DataFrame(
  40. {
  41. 'idx' : observed_data.index,
  42. 'param_name': param_name.replace('_mu',''),
  43. 'real' : observed_data.loc[:,param_name.replace('_mu','')].values,
  44. 'pred' : posterior[param_name]
  45. }
  46. )
  47. )
  48. TVP_data = pd.concat(TVP_data,axis=0)
  49. return TVP_data
  50. def get_metric(self,TVP:pd.DataFrame):
  51. group_by_data = TVP.groupby(['param_name'])[['pred','real']]
  52. TVP_metric = (
  53. pd.concat(
  54. [
  55. group_by_data.apply(lambda dt:r2_score(dt.real,dt.pred)),
  56. group_by_data.apply(lambda dt:mean_absolute_error(dt.real,dt.pred)),
  57. group_by_data.apply(lambda dt:mean_absolute_percentage_error(dt.real,dt.pred)),
  58. ],
  59. axis=1
  60. )
  61. .set_axis(['R2','MAE','MAPE'],axis=1)
  62. .sort_values(by='R2',ascending=True)
  63. )
  64. return TVP_metric
  65. def plot_TVP(self,TVP):
  66. plot = (
  67. TVP
  68. .pipe(gg.ggplot)
  69. + gg.aes(x='real',y='pred')
  70. + gg.geom_point()
  71. + gg.facet_wrap(facets='param_name',scales='free')
  72. + gg.geom_abline(intercept=0,slope=1,color='red')
  73. + gg.theme(figure_size=[10,10])
  74. )
  75. return plot
  76. def curve(
  77. self,
  78. input_data: pd.DataFrame,
  79. x : str,
  80. y : str,
  81. color : str = None,
  82. grid_x : str = None,
  83. grid_y : str = None,
  84. ):
  85. if x not in input_data.columns:
  86. raise Exception(f'{x} is not in input_data')
  87. product = [np.linspace(input_data.loc[:,x].min(),input_data.loc[:,x].max(),100)]
  88. names = [x]
  89. if color is not None:
  90. if color not in input_data.columns:
  91. raise Exception(f'{color} is not in input_data')
  92. product.append(np.quantile(input_data.loc[:,color],q=[0.25,0.5,0.75]))
  93. names.append(color)
  94. if grid_x is not None:
  95. if grid_x not in input_data.columns:
  96. raise Exception(f'{grid_x} is not in input_data')
  97. product.append(np.quantile(input_data.loc[:,grid_x],q=[0.25,0.5,0.75]))
  98. names.append(grid_x)
  99. if grid_y is not None:
  100. if grid_y not in input_data.columns:
  101. raise Exception(f'{grid_y} is not in input_data')
  102. product.append(np.quantile(input_data.loc[:,grid_y],q=[0.25,0.5,0.75]))
  103. names.append(grid_y)
  104. curve_input = (
  105. pd.MultiIndex.from_product(
  106. product,
  107. names=names
  108. )
  109. .to_frame(index=False)
  110. )
  111. curve_input_all = curve_input.copy(deep=True)
  112. for col in input_data.columns:
  113. if col not in curve_input_all:
  114. curve_input_all.loc[:,col] = input_data.loc[:,col].median()
  115. pred_data = self.predict_system(curve_input_all)
  116. if y not in pred_data.columns:
  117. raise Exception(f'{y} is not in Prediction')
  118. curve_data = pd.concat([curve_input,pred_data],axis=1)
  119. plot = (
  120. curve_data
  121. .round(2)
  122. .pipe(gg.ggplot)
  123. + gg.aes(x=x,y=y)
  124. + gg.geom_line()
  125. )
  126. if color is not None:
  127. plot += gg.aes(color=f'factor({color})',group=color)
  128. plot += gg.labs(color=color)
  129. if grid_x is not None or grid_y is not None:
  130. plot += gg.facet_grid(rows=grid_x,cols=grid_y,labeller='label_both')
  131. return plot