| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import numpy as np
- import pandas as pd
- import plotnine as gg
- from ._base import BaseModel
class BaseDevice(BaseModel):
    """Device model with a helper for plotting model response curves.

    Relies on ``self.predict_system``, which is not defined in this class
    (presumably provided by ``BaseModel`` or a subclass — confirm). It is
    called with a DataFrame of inputs and must return a DataFrame of
    predictions.
    """

    def __init__(self) -> None:
        super().__init__()

    def curve(
        self,
        input_data: pd.DataFrame,
        x: str,
        y: str,
        color: str = None,
        grid_x: str = None,
        grid_y: str = None,
        *,
        n_points: int = 100,
        quantiles: tuple = (0.25, 0.5, 0.75),
    ) -> gg.ggplot:
        """Plot prediction ``y`` against input ``x`` with other inputs held fixed.

        ``x`` is swept over ``n_points`` evenly spaced values between its
        minimum and maximum in ``input_data``. Each optional conditioning
        column (``color``, ``grid_x``, ``grid_y``) is evaluated at the given
        ``quantiles`` of its observed values (NaNs ignored, repeated levels
        dropped). Every other column of ``input_data`` is held at its median.

        Args:
            input_data: Observed inputs. They define the sweep range, the
                conditioning levels and the fixed values of the other columns.
            x: Input column placed on the horizontal axis.
            y: Prediction column, as returned by ``predict_system``, placed on
                the vertical axis.
            color: Optional input column drawn as one line per level.
            grid_x: Optional input column used for facet rows.
            grid_y: Optional input column used for facet columns.
            n_points: Number of points sampled along ``x`` (default 100).
            quantiles: Quantiles used as the levels of the conditioning
                columns (default: the quartiles).

        Returns:
            A plotnine ``ggplot`` object.

        Raises:
            ValueError: If a named column is missing from ``input_data``,
                ``y`` is missing from the predictions, or one column is
                given for more than one of x/color/grid_x/grid_y.
        """
        if x not in input_data.columns:
            raise ValueError(f'{x} is not in input_data')

        x_values = input_data[x]
        names = [x]
        levels = [np.linspace(x_values.min(), x_values.max(), n_points)]

        # Same validation order as before: color, then grid_x, then grid_y.
        for col in (color, grid_x, grid_y):
            if col is None:
                continue
            if col not in input_data.columns:
                raise ValueError(f'{col} is not in input_data')
            if col in names:
                # MultiIndex.from_product would otherwise fail obscurely.
                raise ValueError(
                    f'{col} is used for more than one of x/color/grid_x/grid_y'
                )
            names.append(col)
            # Series.quantile skips NaN, as min/max/median do, whereas
            # np.quantile would yield NaN levels. unique() drops repeated
            # levels, e.g. for discrete columns.
            levels.append(input_data[col].quantile(list(quantiles)).unique())

        # Every combination of the x sweep and the conditioning levels.
        grid = (
            pd.MultiIndex.from_product(levels, names=names)
            .to_frame(index=False)
        )

        # Hold every other input at its median, then restore input_data's
        # column order so predict_system sees the same layout as the data.
        model_input = grid.copy()
        for col in input_data.columns:
            if col not in model_input.columns:
                model_input[col] = input_data[col].median()
        model_input = model_input[list(input_data.columns)]

        pred_data = self.predict_system(model_input)
        if y not in pred_data.columns:
            raise ValueError(f'{y} is not in Prediction')

        # Drop any sweep/level columns echoed back by predict_system so the
        # concatenated frame keeps unique column names.
        # NOTE(review): concat aligns on the index, so this assumes
        # predict_system returns rows indexed like model_input (0..n-1).
        pred_data = pred_data.drop(columns=grid.columns, errors='ignore')
        curve_data = pd.concat([grid, pred_data], axis=1)

        # Round only the conditioning levels, for readable legend and facet
        # labels. x and the predictions keep full precision, so curves with
        # a small numeric range are not flattened into steps.
        level_cols = names[1:]
        if level_cols:
            curve_data[level_cols] = curve_data[level_cols].round(2)

        if color is not None:
            # A categorical column gives one discrete colour per level (what
            # factor() did), without evaluating the column name as an
            # expression, so names containing spaces or symbols also work.
            curve_data[color] = pd.Categorical(curve_data[color])

        plot = gg.ggplot(curve_data) + gg.aes(x=x, y=y) + gg.geom_line()
        if color is not None:
            plot += gg.aes(color=color, group=color)
            plot += gg.labs(color=color)

        if grid_x is not None or grid_y is not None:
            plot += gg.facet_grid(rows=grid_x, cols=grid_y, labeller='label_both')

        return plot
|