# data_cleaner.py
  1. import warnings
  2. from typing import Union
  3. from datetime import datetime
  4. import sys
  5. from io import StringIO
  6. import numpy as np
  7. import pandas as pd
  8. from statsmodels.formula.api import rlm
  9. from scipy.stats import iqr
  10. from .data_summary import summary_dataframe
  11. class DataCleaner:
  12. def __init__(self,data:pd.DataFrame,print_process=True) -> None:
  13. self.raw_data = data
  14. self.data = data.copy()
  15. self.drop_index = np.array([False]*len(self.raw_data))
  16. self.print_process = print_process
  17. self.log_info = []
  18. raw_info = capture_print_output(
  19. summary_dataframe,
  20. interfere = not print_process,
  21. df = self.raw_data,
  22. df_name = '原始数据'
  23. )
  24. self.log_info.append(raw_info)
  25. def rm_na_and_inf(self):
  26. # 删除缺失数据
  27. is_na_data = self.data.isna().any(axis=1).values
  28. is_inf_data = np.any(np.isinf(self.data.values),axis=1)
  29. drop_index = is_na_data | is_inf_data
  30. self.drop_index = self.drop_index | drop_index
  31. self._count_removed_data(index=drop_index,method='rm_na_and_inf')
  32. return self
  33. def rm_constant(
  34. self,
  35. window :int = 10,
  36. exclude_value :list = None,
  37. include_cols :list = '__ALL__',
  38. include_by_re :bool = False,
  39. exclude_cols :list = None
  40. ):
  41. # 删除常数
  42. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  43. drop_index_matrix = (data.rolling(window=window).std()==0)
  44. if exclude_value is not None:
  45. for each_value in exclude_value:
  46. keep_index_matrix = data.values == each_value
  47. drop_index_matrix[keep_index_matrix] = False
  48. drop_index = drop_index_matrix.any(axis=1)
  49. self.drop_index = self.drop_index | drop_index
  50. self._count_removed_data(index=drop_index,method='rm_constant',index_matrix=drop_index_matrix,var_name=data.columns)
  51. return self
  52. def rm_rolling_fluct(
  53. self,
  54. window :int = 10,
  55. unit :Union[str,None] = 'min',
  56. fun :str = 'ptp',
  57. thre :float = 0,
  58. include_cols :list = '__ALL__',
  59. include_by_re :bool = False,
  60. exclude_cols :list = None
  61. ):
  62. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  63. if unit is None:
  64. roll_window = window
  65. else:
  66. roll_window = str(window) + unit
  67. roll_data = data.rolling(window=roll_window,min_periods=1,center=True)
  68. if fun == 'ptp':
  69. res = roll_data.max() - roll_data.min()
  70. elif fun == 'pct':
  71. res = (roll_data.max() - roll_data.min())/roll_data.min()
  72. drop_index_matrix = res>thre
  73. drop_index = drop_index_matrix.any(axis=1)
  74. self.drop_index = self.drop_index | drop_index
  75. self._count_removed_data(index=drop_index,method='rm_rolling_fluct',index_matrix=drop_index_matrix,var_name=data.columns)
  76. return self
  77. def rm_outlier_rolling_mean(
  78. self,
  79. window :int = 10,
  80. thre :float = 0.02,
  81. include_cols :list = '__ALL__',
  82. include_by_re:bool = False,
  83. exclude_cols :list = None
  84. ):
  85. # 删除时序异常
  86. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  87. data = data.reset_index(drop=True)
  88. windows_mean = data.rolling(window=window,min_periods=1).mean()
  89. drop_index = (((data - windows_mean)/data).abs()>thre).any(axis=1).values
  90. self.drop_index = drop_index | self.drop_index
  91. self._count_removed_data(index=drop_index,method='rm_outlier_mean')
  92. return self
  93. def rm_diff(
  94. self,
  95. thre : float,
  96. shift : int = 1,
  97. include_cols : list = '__ALL__',
  98. include_by_re: bool = False,
  99. exclude_cols : list = None
  100. ):
  101. # shift 等于1时为后一项减前一项
  102. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  103. data_diff = data.diff(periods=shift,axis=0)
  104. drop_index_matrix = data_diff.abs() > thre
  105. drop_index = drop_index_matrix.any(axis=1).values
  106. self.drop_index = drop_index | self.drop_index
  107. self._count_removed_data(index=drop_index,method='rm_diff',index_matrix=drop_index_matrix,var_name=data.columns)
  108. return self
  109. def rm_zero(
  110. self,
  111. include_cols :list = '__ALL__',
  112. include_by_re:bool = False,
  113. exclude_cols :list = None
  114. ):
  115. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  116. drop_index = (data==0).any(axis=1).values
  117. self.drop_index = drop_index | self.drop_index
  118. self._count_removed_data(index=drop_index,method='rm_zero')
  119. return self
  120. def rm_negative(
  121. self,
  122. keep_zero :bool = False,
  123. include_cols :list = '__ALL__',
  124. include_by_re:bool = False,
  125. exclude_cols :list = None
  126. ):
  127. # 删除负数
  128. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  129. if keep_zero is True:
  130. drop_index = (data<0).any(axis=1).values
  131. else:
  132. drop_index = (data<=0).any(axis=1).values
  133. self.drop_index = drop_index | self.drop_index
  134. self._count_removed_data(index=drop_index,method='rm_negative')
  135. return self
  136. def rm_rule(self,remove_rule:str):
  137. # 基于规则删除数据
  138. data = self.data.copy()
  139. drop_index = np.array(data.eval(remove_rule))
  140. self.drop_index = drop_index | self.drop_index
  141. self._count_removed_data(index=drop_index,method=f'rm_rule({remove_rule})')
  142. return self
  143. def rm_regression_outlier(
  144. self,
  145. formula : str,
  146. rm_resid_IQR: float = 1.5,
  147. rm_dir : str = 'both',
  148. exclude_rule: Union[str,list,None] = None,
  149. min_sample : int = 30,
  150. ):
  151. #! 顺序敏感
  152. RAW_INDEX = np.arange(len(self.data))
  153. # 排除以外的数据,不参与计算
  154. if exclude_rule is None:
  155. exclude_rule = []
  156. if isinstance(exclude_rule,str):
  157. exclude_rule = [exclude_rule]
  158. exclued_index = np.array([False]*len(self.raw_data))
  159. for rule in exclude_rule:
  160. exclued_index = exclued_index | np.array(self.data.eval(rule))
  161. exclued_index = pd.Series(data=exclued_index,index=RAW_INDEX)
  162. exclude_index_drop = pd.Series(self.drop_index,index=RAW_INDEX).loc[exclued_index.values]
  163. # 待清洗的数据
  164. data_clean = self.data.assign(RAW_INDEX_=RAW_INDEX).loc[~(self.drop_index|exclued_index.values)]
  165. filter_index = data_clean.RAW_INDEX_.values
  166. if len(data_clean) < min_sample:
  167. return self
  168. with warnings.catch_warnings():
  169. warnings.simplefilter('ignore')
  170. mod = rlm(formula,data=data_clean).fit(maxiter=500)
  171. resid = np.array(mod.resid)
  172. IQR = iqr(resid)
  173. if rm_dir == 'both':
  174. drop_index = (resid < (np.quantile(resid,q=0.25)-rm_resid_IQR*IQR)) | (resid > (np.quantile(resid,q=0.75)+rm_resid_IQR*IQR))
  175. elif rm_dir == 'lower':
  176. drop_index = resid < (np.quantile(resid,q=0.25)-rm_resid_IQR*IQR)
  177. elif rm_dir == 'upper':
  178. drop_index = resid > (np.quantile(resid,q=0.75)+rm_resid_IQR*IQR)
  179. else:
  180. raise ValueError('rm_dir must be one of "both","lower","upper"')
  181. drop_index_incomplete = pd.Series(data=drop_index,index=filter_index).combine_first(exclude_index_drop)
  182. drop_index_complete = drop_index_incomplete.reindex(RAW_INDEX).fillna(False).values
  183. self.drop_index = drop_index_complete | self.drop_index
  184. self._count_removed_data(index=drop_index,method=f'rm_reg({formula})')
  185. return self
  186. def rm_date_range(self,start:datetime,end:datetime,col=None):
  187. start = pd.Timestamp(start)
  188. end = pd.Timestamp(end)
  189. if col is None:
  190. ts = pd.to_datetime(self.raw_data.index)
  191. else:
  192. ts = pd.to_datetime(self.raw_data.loc[:,col])
  193. drop_index = (ts>=start) & (ts<=end)
  194. self.drop_index = drop_index | self.drop_index
  195. self._count_removed_data(index=drop_index,method=f'rm_date_range({start}~{end})')
  196. return self
  197. def rm_outrange(
  198. self,
  199. method :str = 'quantile',
  200. upper :float = 0.99,
  201. lower :float = 0.01,
  202. include_cols :list = '__ALL__',
  203. include_by_re :bool = False,
  204. exclude_cols :list = None
  205. ):
  206. data = self._get_data_by_cols(include_cols,include_by_re,exclude_cols)
  207. if method == 'quantile':
  208. q_upper = np.quantile(data.values,q=upper,axis=0)
  209. q_lower = np.quantile(data.values,q=lower,axis=0)
  210. elif method == 'raw':
  211. q_upper = upper
  212. q_lower = lower
  213. else:
  214. raise Exception('WRONG method')
  215. drop_index_matrix = (data > q_upper) | (data < q_lower)
  216. drop_index = drop_index_matrix.any(axis=1)
  217. self.drop_index = self.drop_index | drop_index
  218. self._count_removed_data(index=drop_index,method='rm_outrange',index_matrix=drop_index_matrix,var_name=data.columns)
  219. return self
  220. def save_log(self,path:str='./log.txt'):
  221. with open(path, "w", encoding="utf-8") as f:
  222. for line in self.log_info:
  223. f.write(line + "\n")
  224. return self
  225. def get_data(self,fill=None,get_drop=False,save_log=None) -> pd.DataFrame:
  226. index = self.drop_index if not get_drop else ~self.drop_index
  227. if fill is None:
  228. # 保留非删除数据
  229. result_data = self.raw_data.loc[~index,:]
  230. else:
  231. # 填充非删除数据
  232. result_data = self.raw_data.copy()
  233. result_data.loc[index,:] = fill
  234. res_info = capture_print_output(
  235. summary_dataframe,
  236. interfere = not self.print_process,
  237. df = result_data,
  238. df_name = '结果数据'
  239. )
  240. self.log_info.append(res_info)
  241. if save_log is not None:
  242. self.save_log(save_log)
  243. return result_data
  244. def _get_data_by_cols(
  245. self,
  246. include_cols :list = '__ALL__',
  247. include_by_re:bool = False,
  248. exclude_cols :list = None,
  249. ) -> pd.DataFrame:
  250. data = self.data.copy()
  251. if include_by_re is True:
  252. if isinstance(include_cols,str):
  253. cols = data.loc[:,data.columns.str.contains(include_cols,regex=True)].columns
  254. else:
  255. raise Exception('WRONG')
  256. elif include_by_re is False:
  257. if include_cols == '__ALL__':
  258. cols = data.columns
  259. elif isinstance(include_cols,str):
  260. cols = [include_cols]
  261. elif isinstance(include_cols,list):
  262. cols = data.loc[:,include_cols].columns
  263. else:
  264. raise Exception('WRONG')
  265. if exclude_cols is not None:
  266. cols = cols.difference(other=exclude_cols)
  267. return data.loc[:,cols]
  268. def _count_removed_data(self,index,method,index_matrix=None,var_name=None):
  269. count = index.sum()
  270. pct = round(count / len(index) * 100,2)
  271. info = f'remove {count}({pct}%) by {method}'
  272. self.log_info.append(info)
  273. if self.print_process:
  274. print(info)
  275. if index_matrix is not None and var_name is not None:
  276. var_drop_count = np.sum(index_matrix,axis=0)
  277. for var,drop_count in zip(var_name,var_drop_count):
  278. if drop_count == 0:
  279. continue
  280. info = f'{var}:{drop_count}'
  281. self.log_info.append(info)
  282. if self.print_process:
  283. print(info)
  284. def capture_print_output(func, *args, interfere=False, **kwargs):
  285. """
  286. 捕获函数的打印输出并返回为字符串,可选择是否影响原函数的正常打印
  287. 参数:
  288. func: 要执行的函数
  289. *args: 传递给函数的位置参数
  290. interfere: 是否干扰原函数打印 (默认False)
  291. **kwargs: 传递给函数的关键字参数
  292. 返回:
  293. 函数执行过程中所有打印输出的字符串
  294. """
  295. # 创建一个字符串缓冲区来捕获输出
  296. new_stdout = StringIO()
  297. # 保存原来的标准输出
  298. old_stdout = sys.stdout
  299. if interfere:
  300. # 简单重定向模式 - 会干扰原打印
  301. sys.stdout = new_stdout
  302. else:
  303. # 不干扰模式 - 使用Tee类同时输出
  304. class Tee:
  305. def __init__(self, old, new):
  306. self.old = old
  307. self.new = new
  308. def write(self, text):
  309. self.old.write(text) # 保持原打印
  310. self.new.write(text) # 捕获到字符串
  311. def flush(self):
  312. self.old.flush()
  313. self.new.flush()
  314. sys.stdout = Tee(old_stdout, new_stdout)
  315. try:
  316. # 执行函数
  317. func(*args, **kwargs)
  318. # 获取捕获的输出
  319. output = new_stdout.getvalue()
  320. finally:
  321. # 恢复原来的标准输出
  322. sys.stdout = old_stdout
  323. return output