fit_utils.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import numpy as np
  2. import pandas as pd
  3. import pymc as pm
  4. from sklearn.metrics import r2_score,mean_absolute_error,mean_absolute_percentage_error
  5. def record(name,var):
  6. pm.Deterministic(f'{name}_mu',var)
  7. def observe(name,var,observed,sigma=1):
  8. if isinstance(observed,pd.DataFrame):
  9. observed = observed.loc[:,name].values
  10. mu = pm.Deterministic(f'{name}_mu',var)
  11. sigma = pm.HalfNormal(f'{name}_sigma',sigma=sigma)
  12. pm.Normal(name,mu=mu,sigma=sigma,observed=observed)
  13. def reorder_posterior(prior:dict,posterior:dict):
  14. param_posterior_reorder = {'F_air':{}}
  15. for equp_name in prior.keys():
  16. param_posterior_reorder.setdefault(equp_name,{})
  17. for param_name,param_value in posterior.items():
  18. if '__' in param_name:
  19. continue
  20. if param_name == 'F_air_val_rw':
  21. param_value = np.median(param_value[-5:])
  22. if param_name.startswith(equp_name):
  23. param_name_adj = param_name.replace(f'{equp_name}_','')
  24. param_posterior_reorder[equp_name][param_name_adj] = param_value
  25. return param_posterior_reorder
  26. def get_fitted_result(
  27. posterior : dict,
  28. observed_data: pd.DataFrame,
  29. plot_TVP : bool
  30. ) -> tuple:
  31. # 样本内预测数据
  32. TVP_data = []
  33. for param_name in posterior.keys():
  34. if param_name.replace('_mu','') not in observed_data.columns:
  35. continue
  36. TVP_data.append(
  37. pd.DataFrame(
  38. {
  39. 'param_name': param_name.replace('_mu',''),
  40. 'real' : observed_data.loc[:,param_name.replace('_mu','')].values,
  41. 'pred' : posterior[param_name]
  42. }
  43. )
  44. )
  45. TVP_data = pd.concat(TVP_data,axis=0)
  46. group_by_data = TVP_data.groupby(['param_name'])[['pred','real']]
  47. TVP_metric = (
  48. pd.concat(
  49. [
  50. group_by_data.apply(lambda dt:r2_score(dt.real,dt.pred)),
  51. group_by_data.apply(lambda dt:mean_absolute_error(dt.real,dt.pred)),
  52. group_by_data.apply(lambda dt:mean_absolute_percentage_error(dt.real,dt.pred)),
  53. ],
  54. axis=1
  55. )
  56. .set_axis(['R2','MAE','MAPE'],axis=1)
  57. .sort_values(by='R2',ascending=True)
  58. )
  59. if plot_TVP:
  60. import plotnine as gg
  61. plot = (
  62. TVP_data
  63. .pipe(gg.ggplot)
  64. + gg.aes(x='real',y='pred')
  65. + gg.geom_point()
  66. + gg.facet_wrap(facets='param_name',scales='free')
  67. + gg.geom_abline(intercept=0,slope=1,color='red')
  68. + gg.theme(figure_size=[10,10])
  69. )
  70. plot.show()
  71. return TVP_data,TVP_metric