_base_device.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import numpy as np
  2. import pandas as pd
  3. import plotnine as gg
  4. from ._base import BaseModel
  5. class BaseDevice(BaseModel):
  6. def __init__(self) -> None:
  7. super().__init__()
  8. def curve(
  9. self,
  10. input_data: pd.DataFrame,
  11. x : str,
  12. y : str,
  13. color : str = None,
  14. grid_x : str = None,
  15. grid_y : str = None,
  16. ):
  17. if x not in input_data.columns:
  18. raise Exception(f'{x} is not in input_data')
  19. product = [np.linspace(input_data.loc[:,x].min(),input_data.loc[:,x].max(),100)]
  20. names = [x]
  21. if color is not None:
  22. if color not in input_data.columns:
  23. raise Exception(f'{color} is not in input_data')
  24. product.append(np.quantile(input_data.loc[:,color],q=[0.25,0.5,0.75]))
  25. names.append(color)
  26. if grid_x is not None:
  27. if grid_x not in input_data.columns:
  28. raise Exception(f'{grid_x} is not in input_data')
  29. product.append(np.quantile(input_data.loc[:,grid_x],q=[0.25,0.5,0.75]))
  30. names.append(grid_x)
  31. if grid_y is not None:
  32. if grid_y not in input_data.columns:
  33. raise Exception(f'{grid_y} is not in input_data')
  34. product.append(np.quantile(input_data.loc[:,grid_y],q=[0.25,0.5,0.75]))
  35. names.append(grid_y)
  36. curve_input = (
  37. pd.MultiIndex.from_product(
  38. product,
  39. names=names
  40. )
  41. .to_frame(index=False)
  42. )
  43. curve_input_all = curve_input.copy(deep=True)
  44. for col in input_data.columns:
  45. if col not in curve_input_all:
  46. curve_input_all.loc[:,col] = input_data.loc[:,col].median()
  47. pred_data = self.predict_system(curve_input_all)
  48. if y not in pred_data.columns:
  49. raise Exception(f'{y} is not in Prediction')
  50. curve_data = pd.concat([curve_input,pred_data],axis=1)
  51. plot = (
  52. curve_data
  53. .round(2)
  54. .pipe(gg.ggplot)
  55. + gg.aes(x=x,y=y)
  56. + gg.geom_line()
  57. )
  58. if color is not None:
  59. plot += gg.aes(color=f'factor({color})',group=color)
  60. plot += gg.labs(color=color)
  61. if grid_x is not None or grid_y is not None:
  62. plot += gg.facet_grid(rows=grid_x,cols=grid_y,labeller='label_both')
  63. return plot