_update_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. import json
  2. from datetime import datetime, timedelta
  3. from workflowlib import requests
  4. from workflow_utils import update_model_version_v3
  5. def update(
  6. version_id : int,
  7. model_id : str,
  8. model_info : dict,
  9. update_method : str,
  10. MODEL_FUNC_PATH: str,
  11. MODEL_FILE_PATH: str
  12. ):
  13. """
  14. :param model_info:
  15. {
  16. 'E':{
  17. 'metric' : {
  18. 'MAE' : ...,
  19. 'MAPE': ...,
  20. },
  21. 'point_id' : 'abc',
  22. 'point_name' : 'abc',
  23. 'point_class' : 'abc'
  24. 'thre_mae' : '123',
  25. 'thre_mape' : '123',
  26. 'thre_days' : '123',
  27. },
  28. 'AP':{
  29. 'metric' : {
  30. 'MAE' : ...,
  31. 'MAPE': ...,
  32. },
  33. 'point_id' : 'abc',
  34. 'point_name' : 'abc',
  35. 'point_class' : 'abc'
  36. 'thre_mae' : '123',
  37. 'thre_mape' : '123',
  38. 'thre_days' : '123',
  39. },
  40. },
  41. :param MODEL_FUNC_PATH:
  42. :param MODEL_FILE_PATH:
  43. :return:
  44. """
  45. factors = ['MAE', 'MAPE'] # 定义指标名称列表
  46. # 模型管理设置点位
  47. points_set_dict = {}
  48. for key, value in model_info.items():
  49. points_set_dict[key] = {
  50. 'point_id' : value['point_id'],
  51. 'point_name' : value['point_name'],
  52. 'point_class': value['point_class']
  53. }
  54. update_model_points(points_set_dict, model_id)
  55. # 利用模型ID获取模型信息
  56. model_info_res = get_model_info(model_id=model_id)
  57. device_name = model_info_res['device_name'] # 设备名称用于更新日志
  58. # 模型训练指标写于日志
  59. metrics_log = {}
  60. for key, value in model_info.items():
  61. metrics_log[key] = value['metric'] # 模型训练后的指标
  62. # 点位列表
  63. point_ids = []
  64. for key, value in model_info.items():
  65. point_ids.append(value['point_id'])
  66. # 键值E/AP:点位,指标阈值
  67. points_metrics = {}
  68. for key, value in model_info.items():
  69. points_metrics[key] = {
  70. 'point_id': value['point_id'],
  71. 'MAE': value['thre_mae'],
  72. 'MAPE': value['thre_mape']
  73. }
  74. # print("points_metrics", points_metrics)
  75. metric_json = get_metrics_json(points_metrics) # 正确的函数调用
  76. # json_data = json.dumps(metric_json, indent=4)
  77. # print("json_data:\n", json_data)
  78. # 配置项中每个目标变量超出阈值的天数,{'E': '123', 'AP': '123'}
  79. thre_days_dict = {}
  80. for key, value in model_info.items():
  81. thre_days_dict[key] = value['thre_days']
  82. if update_method == 'any_metric':
  83. print("【更新模式:基于监控指标进行更新】")
  84. # 从获取的模型信息中读取版本列表,找到旧版本ID,用于读取监控指标
  85. all_old_version_id = model_info_res['version_list']
  86. old_version_id = next((_['id'] for _ in all_old_version_id), version_id)
  87. all_keys_need_update = True
  88. for key, days in thre_days_dict.items():
  89. print("{:=^50s}".format(f"目标变量 {key} "))
  90. print(f"处理目标变量 {key} 的过去时间:")
  91. past_times_list = get_past_times(days)
  92. # print(f"过去时间表: {past_times_list}")
  93. # 获取监控指标
  94. monitor_results = get_monitor_metric(past_times_list, old_version_id, factors, point_ids)
  95. # print(f"目标点位的监控指标结果:{monitor_results}")
  96. # 检查阈值是否超出
  97. is_update_needed = check_threshold_update_model(monitor_results, points_metrics[key])
  98. print(f"目标 {key} 是否需要更新模型:{is_update_needed}")
  99. # 如果任何一个key的指标超出阈值,则将all_keys_need_update置为False
  100. if is_update_needed:
  101. all_keys_need_update = False
  102. if all_keys_need_update:
  103. print("【所有目标变量的所有指标均未超出阈值,不需要更新模型】")
  104. return None
  105. else:
  106. print("【存在目标变量指标超出阈值,需要更新模型】")
  107. version_update(model_id, MODEL_FUNC_PATH, MODEL_FILE_PATH, metric_json, version_id, device_name,
  108. metrics_log)
  109. elif update_method == 'update':
  110. print("【更新模式:强制更新】")
  111. version_update(model_id, MODEL_FUNC_PATH, MODEL_FILE_PATH, metric_json, version_id, device_name, metrics_log)
  112. else:
  113. raise ValueError
  114. # 通过model_id读取模型版本界面信息
  115. def get_model_info(model_id):
  116. res = requests.get(url=f"http://m2-backend-svc:8000/api/ai/model/get_details/{model_id}")
  117. res = res.json()['result']
  118. # print(f"模型版本信息: ", res)
  119. return res
  120. # 计算过去时间
  121. def get_past_times(thre_days: int) -> list:
  122. """
  123. 根据传入的天数计算过去几天的起始时间(从凌晨开始)。
  124. Args:
  125. thre_days (int): 要计算的天数范围。
  126. Returns:
  127. list: 包含过去几天的 datetime 对象列表。
  128. """
  129. thre_days = int(thre_days)
  130. nowtime = datetime.now()
  131. past_times = []
  132. for day in range(1, thre_days + 1):
  133. past_time = (nowtime - timedelta(days=day))
  134. past_time = past_time.replace(hour=0, minute=0, second=0, microsecond=0)
  135. past_times.append(past_time)
  136. print(f"根据配置 {thre_days} 天计算的时间:", past_times)
  137. return past_times
  138. # 获取监控指标
  139. def get_monitor_metric(past_times: list, version_id: str, factors: list, point_ids: list) -> dict:
  140. """
  141. 获取时间内单体模型监控界面的指定模型指标。
  142. Args:
  143. past_times (list): 时间列表,包含时间段的 datetime 对象。
  144. version_id (str): 模型版本ID。
  145. factors (list): 指标列表,例如 ['MAE', 'MAPE']。
  146. point_ids (list): 目标变量点位编号列表,例如 ['_E', '_AP']。
  147. Returns:
  148. dict: 包含每个 factor 的监控结果,格式如下:
  149. {
  150. 'MAE': [值1, 值2, ...],
  151. 'MAPE': [值1, 值2, ...],
  152. }
  153. """
  154. url = "http://m2-backend-svc:8000/api/ai/monitor/get_single_factor_sequence"
  155. # monitor_metric = defaultdict(list)
  156. monitor_metric = {}
  157. for time_point in past_times:
  158. formatted_time = time_point.strftime("%Y-%m-%d %H:%M:%S")
  159. for point_id in point_ids:
  160. if point_id not in monitor_metric:
  161. monitor_metric[point_id] = {}
  162. for factor in factors:
  163. if factor not in monitor_metric[point_id]:
  164. monitor_metric[point_id][factor] = []
  165. data = {
  166. "factor": factor,
  167. "point_id": point_id,
  168. "time_begin": formatted_time,
  169. "time_end": formatted_time,
  170. "version_id": version_id,
  171. "type": "DAILY"
  172. }
  173. try:
  174. response = requests.post(url=url, data=json.dumps(data))
  175. response.raise_for_status() # 捕获 HTTP 请求异常
  176. result = response.json().get('results', [])
  177. if result:
  178. # monitor_metric[factor].append(result[0][1]) # 假定结果中目标值在 result[0][1]
  179. monitor_metric[point_id][factor].append(result[0][1])
  180. else:
  181. # monitor_metric[factor].append(0) # 填充0
  182. monitor_metric[point_id][factor].append(0)
  183. except (KeyError, IndexError, Exception) as e:
  184. print(f"Error fetching data for {factor}, {point_id}, {formatted_time}: {e}")
  185. # monitor_metric[factor].append(0)
  186. monitor_metric[point_id][factor].append(0)
  187. return dict(monitor_metric)
  188. # 检查指标是否超出阈值
  189. def check_threshold_update_model(monitor_metric, points_metrics):
  190. """
  191. 判断任意一种指标在配置时间内是否超出阈值
  192. Args:
  193. monitor_metric (dict): 模型的监控指标数据
  194. points_metrics (dict): 指标阈值,如 {'point_id': 'abc', 'MAE': '123', 'MAPE': '123'}
  195. Returns:
  196. bool: 若任意一个指标的值超过阈值,则返回 True,否则返回 False
  197. """
  198. # print("监控指标结果:", monitor_metric)
  199. # print("点位指标阈值:", points_metrics)
  200. point_id = points_metrics['point_id']
  201. any_greater = {}
  202. if point_id in monitor_metric:
  203. print(f"目标点位 {point_id} 的监控指标结果:{monitor_metric[point_id]}")
  204. for factor, values in monitor_metric[point_id].items():
  205. threshold = float(points_metrics.get(factor, float('inf')))
  206. # 判断是否有指标超出阈值
  207. any_greater[factor] = any(
  208. value > threshold for value in values if value is not None
  209. )
  210. any_true_greater = any(value for value in any_greater.values())
  211. print("监控模型指标是否超出阈值: ", any_greater)
  212. return any_true_greater
  213. def get_metrics_json(points_metrics):
  214. metric_json = []
  215. for key, metrics in points_metrics.items():
  216. # 提取阈值 MAE 和 MAPE
  217. thre_mae = metrics.get('MAE', None)
  218. thre_mape = metrics.get('MAPE', None)
  219. # 更新 factors 列表中的 upr_limit
  220. factors = [
  221. {"factor": "MAE", "lwr_limit": None, "upr_limit": float(thre_mae) if thre_mae is not None else None,
  222. "trained_value": None},
  223. {"factor": "MBE", "lwr_limit": None, "upr_limit": None, "trained_value": None},
  224. {"factor": "MSE", "lwr_limit": None, "upr_limit": None, "trained_value": None},
  225. {"factor": "MdAE", "lwr_limit": None, "upr_limit": None, "trained_value": None},
  226. {"factor": "std_MAE", "lwr_limit": None, "upr_limit": None, "trained_value": None},
  227. {"factor": "MAPE", "lwr_limit": None, "upr_limit": float(thre_mape) if thre_mape is not None else None,
  228. "trained_value": None},
  229. {"factor": "std_MAPE", "lwr_limit": None, "upr_limit": None, "trained_value": None},
  230. ]
  231. metric_json.append({"factors": factors, "point_id": metrics.get("point_id")})
  232. return metric_json
  233. def version_update(model_id, mod_func_path, mod_file_path, metric_json, new_version_id, device_name, metrics_log):
  234. """
  235. 上传新模型版本,并生成上传日志
  236. :param model_id: 模型ID
  237. :param mod_func_path: 模型文件地址
  238. :param mod_file_path: 模型文件地址
  239. :param metric_json: 模型指标metric json
  240. :param new_version_id: 新版本ID
  241. :param device_name: 设备名称
  242. :param metrics_log: 上传日志用到的指标log
  243. :return:
  244. """
  245. filename_list = [
  246. {"filename": mod_file_path},
  247. {"filename": mod_func_path}
  248. ]
  249. # 上传模型新版本
  250. update_model_version_v3(model_id, new_version_id, filename_list, workflow_id=None, factors=metric_json)
  251. # 上传日志
  252. device_name_value = f"设备名称:{device_name}, 模型文件:{model_id}, 更新后的指标:{metrics_log}"
  253. # print(device_name_value)
  254. r = requests.post(
  255. "http://m2-backend-svc:8000/api/ai/sys_opt_log/create_one",
  256. json={
  257. "上传日志" : "上传日志",
  258. "user_id": 10,
  259. "type" : "模型自动迭代操作",
  260. "log" : device_name_value
  261. }
  262. )
  263. return
  264. def update_model_points(points_set_dict, model_id):
  265. """
  266. points_set_dict: 包含点位信息的字典
  267. model_id: 模型id
  268. """
  269. url = f"http://m2-backend-svc:8000/api/ai/model/get_details/{model_id}"
  270. update_url = f"http://m2-backend-svc:8000/api/ai/model/update_info/{model_id}"
  271. try:
  272. print("{:=^50s}".format("设置模型文件点位"))
  273. r = requests.get(url=url)
  274. print(f"上传模型点位请求响应:{r}。")
  275. r.raise_for_status()
  276. result = r.json().get('result', {})
  277. device_data = result.get('device_data', {})
  278. device_id = result.get('device_id', {})
  279. # 定义 points 列表
  280. points = []
  281. for key, value in points_set_dict.items():
  282. points.append({
  283. "point_id" : value['point_id'],
  284. "point_class": value['point_class'],
  285. "name" : value['point_name'],
  286. "device_id" : device_id
  287. })
  288. if device_data['point_list'] is None:
  289. device_data['point_list'] = points
  290. else:
  291. for point in points:
  292. existing_point = next(
  293. (p for p in device_data['point_list'] if p['point_class'] == point['point_class']), None)
  294. if existing_point:
  295. existing_point['point_id'] = point['point_id']
  296. print(f"点位 {point['point_class']} 已存在,更新 point_id 为 {point['point_id']}。")
  297. else:
  298. device_data['point_list'].append(point)
  299. print(f"添加新点位 {point}。")
  300. update_r = requests.post(update_url, json=result) # 上传result
  301. if update_r.status_code == 200:
  302. print(f'模型点位保存成功!')
  303. else:
  304. print(f'保存模型点位时出错:', update_r.status_code)
  305. except Exception as e:
  306. print(f"请求错误:{e}")
  307. result = {}
  308. return