from typing import Optional

import numpy as np
import pandas as pd
import plotnine as gg

from ._base import BaseModel


class BaseDevice(BaseModel):
    """Base device model that adds a ``curve`` helper for plotting predictions."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _require_column(input_data: pd.DataFrame, column: str) -> None:
        """Raise ``ValueError`` if *column* is not a column of *input_data*."""
        if column not in input_data.columns:
            raise ValueError(f'{column} is not in input_data')

    @staticmethod
    def _quartiles(series: pd.Series) -> np.ndarray:
        """Return the distinct 25th/50th/75th percentiles of *series*."""
        # Series.quantile skips NaN (np.quantile would propagate it, turning the
        # whole level into NaN); pd.unique drops the repeated values produced by
        # low-cardinality columns, which would otherwise duplicate grid rows.
        return pd.unique(series.quantile([0.25, 0.5, 0.75]).to_numpy())

    def curve(
        self,
        input_data: pd.DataFrame,
        x: str,
        y: str,
        color: Optional[str] = None,
        grid_x: Optional[str] = None,
        grid_y: Optional[str] = None,
    ) -> gg.ggplot:
        """Plot the predicted ``y`` as ``x`` sweeps its observed range.

        ``x`` takes 100 evenly spaced values between its min and max in
        ``input_data``. Each of ``color``, ``grid_x`` and ``grid_y`` (when
        given) takes the distinct quartiles (25/50/75 %) of its column. Every
        other column of ``input_data`` is held at its median. The resulting
        grid is passed to ``self.predict_system``.

        Args:
            input_data: Reference data; defines ranges, quartiles and medians.
            x: Column swept along the horizontal axis.
            y: Column of the prediction plotted on the vertical axis.
            color: Optional column whose quartiles become separate coloured lines.
            grid_x: Optional column whose quartiles define the facet rows.
            grid_y: Optional column whose quartiles define the facet columns.

        Returns:
            A plotnine ``ggplot`` object.

        Raises:
            ValueError: If a requested column is missing from ``input_data``,
                if the swept columns are not distinct, or if ``y`` is not a
                column of the prediction.
        """
        self._require_column(input_data, x)
        product = [np.linspace(input_data[x].min(), input_data[x].max(), 100)]
        names = [x]
        for column in (color, grid_x, grid_y):
            if column is None:
                continue
            self._require_column(input_data, column)
            product.append(self._quartiles(input_data[column]))
            names.append(column)

        # A repeated name would give the MultiIndex duplicate level names.
        if len(set(names)) != len(names):
            raise ValueError(
                f'x, color, grid_x and grid_y must be distinct columns, got {names}'
            )

        # Cartesian product of all swept values, one row per combination.
        curve_input = (
            pd.MultiIndex.from_product(product, names=names)
            .to_frame(index=False)
        )

        # Hold every remaining input at its median. Build those columns in one
        # go (avoids frame fragmentation), then restore input_data's column
        # order in case the predictor depends on it.
        medians = input_data.drop(columns=names).median()
        held = pd.DataFrame(dict(medians.items()), index=curve_input.index)
        curve_input_all = pd.concat([curve_input, held], axis=1)[
            list(input_data.columns)
        ]

        pred_data = self.predict_system(curve_input_all)
        if y not in pred_data.columns:
            raise ValueError(f'{y} is not in Prediction')

        curve_data = pd.concat([curve_input, pred_data], axis=1)

        # Rounding also shortens the quartile values shown in legend/facet labels.
        plot = (
            curve_data.round(2).pipe(gg.ggplot)
            + gg.aes(x=x, y=y)
            + gg.geom_line()
        )
        if color is not None:
            # factor() makes the few quartile levels discrete colours.
            plot += gg.aes(color=f'factor({color})', group=color)
            plot += gg.labs(color=color)
        if grid_x is not None or grid_y is not None:
            plot += gg.facet_grid(rows=grid_x, cols=grid_y, labeller='label_both')
        return plot