# fit_utils.py (1.2 KB)

import numbers

import numpy as np
import pandas as pd
import pymc as pm
  4. def record(name,var):
  5. pm.Deterministic(f'{name}_mu',var)
  6. def observe(name,var,observed,sigma=1):
  7. if isinstance(observed,pd.DataFrame):
  8. if name not in observed.columns:
  9. raise Exception(f'observed data中未找到{name}')
  10. observed = observed.loc[:,name].values
  11. mu = pm.Deterministic(f'{name}_mu',var)
  12. if isinstance(sigma,(int,float)):
  13. sigma = pm.HalfNormal(f'{name}_sigma',sigma=sigma)
  14. else:
  15. sigma = sigma
  16. pm.Normal(name,mu=mu,sigma=sigma,observed=observed)
  17. def reorder_posterior(prior:dict,posterior:dict):
  18. param_posterior_reorder = {'F_air':{}}
  19. for equp_name in prior.keys():
  20. param_posterior_reorder.setdefault(equp_name,{})
  21. for param_name,param_value in posterior.items():
  22. if '__' in param_name:
  23. continue
  24. if param_name == 'F_air_val_rw':
  25. param_value = param_value[-1]
  26. if param_name.startswith(equp_name):
  27. param_name_adj = param_name.replace(f'{equp_name}_','')
  28. param_posterior_reorder[equp_name][param_name_adj] = param_value
  29. return param_posterior_reorder