diff --git a/src/app.py b/src/app.py index 2b6a3ad..805d53d 100644 --- a/src/app.py +++ b/src/app.py @@ -505,15 +505,9 @@ def run_chip_distribution_collection(): # tushare_token=TUSHARE_TOKEN, # mode='full' # ) - collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full', - start_date='2020-01-02', end_date='2020-12-31') - # collect_chip_distribution( - # db_url=db_url, - # tushare_token=TUSHARE_TOKEN, - # mode='daily', # 每日增量采集 - # date=today, - # batch_size=100 # 每100只股票批量入库一次 - # ) + # collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full', + # start_date='2025-12-10', end_date='2025-12-30') + collect_chip_distribution(db_url, TUSHARE_TOKEN, mode='daily') logger.info("每日筹码分布数据采集完成(Tushare)") return jsonify({ @@ -3834,8 +3828,12 @@ def run_tech_fundamental_strategy_batch(): start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date') end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date') - if not start_date or not end_date: - return jsonify({"status": "error", "message": "缺少参数: start_date 和 end_date"}), 400 + # 如果未传入日期,默认使用当天日期 + today = datetime.now().strftime('%Y-%m-%d') + if not start_date: + start_date = today + if not end_date: + end_date = today # 验证日期 start_dt = datetime.strptime(start_date, '%Y-%m-%d') @@ -4147,6 +4145,75 @@ def delete_tag_relation(): }), 500 +@app.route('/api/tag/clear_tags', methods=['POST']) +def clear_tags(): + """清空 MongoDB 中的行业/概念标签及其标签-股票中间关系(不动Redis队列、不删分析结果、保留其他标签类型)""" + try: + if tag_relation_api is None: + return jsonify({ + "status": "error", + "message": "标签关联分析API未初始化" + }), 500 + + result = tag_relation_api.clear_all_tags_and_relations() + + if result.get("success"): + return jsonify({ + "status": "success", + "data": result + }) + else: + return jsonify({ + "status": "error", + "message": result.get("error", "清空失败"), + "data": result + }), 500 + except Exception as e: + logger.error(f"清空标签失败: {str(e)}", exc_info=True) + return jsonify({ + "status": "error", + "message": f"清空标签失败: {str(e)}" + }), 500 + + +@app.route('/api/tag/init_from_mysql', methods=['POST']) +def init_tags_from_mysql(): + """从MySQL的行业/概念表初始化标签及其与个股的关系到MongoDB(无参数版)""" + try: + # 检查API是否初始化成功 + if tag_relation_api is None: + return jsonify({ + "status": "error", + "message": "标签关联分析API未初始化" + }), 500 + + # 直接全量初始化:行业+概念,覆盖已有关联 + result = tag_relation_api.init_tags_from_mysql( + include_hybk=True, + include_gnbk=True, + overwrite_relations=True + ) + + if result.get('success'): + return jsonify({ + "status": "success", + "data": result + }) + else: + return jsonify({ + "status": "error", + "message": "初始化存在错误,请查看errors字段", + "data": result + }), 500 + + except Exception as e: + logger.error(f"初始化标签失败: {str(e)}", exc_info=True) + return jsonify({ + "status": "error", + "message": f"初始化标签失败: {str(e)}" + }), 500 + + if __name__ == '__main__': # 启动Web服务器 diff --git a/src/stock_tag_analysis/tag_relation_api.py b/src/stock_tag_analysis/tag_relation_api.py index d941cd7..efa738a 100644 --- a/src/stock_tag_analysis/tag_relation_api.py +++ b/src/stock_tag_analysis/tag_relation_api.py @@ -373,6 +373,165 @@ class TagRelationAPI: logger.error(f"查询gp_gnbk失败: {str(e)}") return [] + def init_tags_from_mysql( + self, + include_hybk: bool = True, + include_gnbk: bool = True, + overwrite_relations: bool = True + ) -> Dict[str, Any]: + """从MySQL的行业/概念表初始化标签及其与个股的关联关系到MongoDB + + Args: + include_hybk: 是否初始化行业标签(gp_hybk) + include_gnbk: 是否初始化概念标签(gp_gnbk) + overwrite_relations: 是否覆盖已有的标签-个股关联关系 + + Returns: + Dict[str, Any]: 初始化统计结果 + """ + stats = { + "success": True, + "hybk": { + "tag_count": 0, + "relation_count": 0 + }, + "gnbk": { + "tag_count": 0, + "relation_count": 0 + }, + "errors": [] + } + + try: + if include_hybk: + try: + logger.info("开始从 gp_hybk 初始化行业标签及关联关系") + sql = """ + SELECT bk_name, gp_code + FROM gp_hybk + WHERE bk_name IS NOT NULL + AND gp_code IS NOT NULL + """ + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql)) + rows = result.fetchall() + + # 聚合为 {tag_name: set(gp_code)} + tag_stocks_map: Dict[str, set] = {} + for bk_name, gp_code in rows: + if not bk_name or not gp_code: + continue + tag_stocks_map.setdefault(bk_name, set()).add(gp_code) + + for tag_name, stock_codes in tag_stocks_map.items(): + try: + # 过滤有效股票(转为 gp_code_two) + valid_stocks = self._filter_valid_stocks(list(stock_codes)) + if not valid_stocks: + logger.info(f"行业标签'{tag_name}'过滤后无有效股票,跳过") + continue + + tag_data = { + 'tag_name': tag_name, + 'tag_type': '行业标签', + 'status': 'completed', + 'progress': 100.0, + 'source': 'hybk' + } + tag_code = self.service.database.save_tag(tag_data) + if not tag_code: + logger.warning(f"行业标签'{tag_name}'保存失败,跳过") + continue + + # 保存关联关系 + self.service.database.save_tag_stock_relations( + tag_code, + valid_stocks, + replace_all=overwrite_relations + ) + + stats["hybk"]["tag_count"] += 1 + stats["hybk"]["relation_count"] += len(valid_stocks) + except Exception as e: + err_msg = f"初始化行业标签'{tag_name}'失败: {str(e)}" + logger.error(err_msg, exc_info=True) + stats["errors"].append(err_msg) + + except Exception as e: + err_msg = f"从 gp_hybk 初始化行业标签失败: {str(e)}" + logger.error(err_msg, exc_info=True) + stats["errors"].append(err_msg) + stats["success"] = False + + if include_gnbk: + try: + logger.info("开始从 gp_gnbk 初始化概念标签及关联关系") + sql = """ + SELECT bk_name, gp_code + FROM gp_gnbk + WHERE bk_name IS NOT NULL + AND gp_code IS NOT NULL + """ + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql)) + rows = result.fetchall() + + # 聚合为 {tag_name: set(gp_code)} + tag_stocks_map: Dict[str, set] = {} + for bk_name, gp_code in rows: + if not bk_name or not gp_code: + continue + tag_stocks_map.setdefault(bk_name, set()).add(gp_code) + + for tag_name, stock_codes in tag_stocks_map.items(): + try: + # 过滤有效股票(转为 gp_code_two) + valid_stocks = self._filter_valid_stocks(list(stock_codes)) + if not valid_stocks: + logger.info(f"概念标签'{tag_name}'过滤后无有效股票,跳过") + continue + + tag_data = { + 'tag_name': tag_name, + 'tag_type': '概念标签', + 'status': 'completed', + 'progress': 100.0, + 'source': 'gnbk' + } + tag_code = self.service.database.save_tag(tag_data) + if not tag_code: + logger.warning(f"概念标签'{tag_name}'保存失败,跳过") + continue + + # 保存关联关系 + self.service.database.save_tag_stock_relations( + tag_code, + valid_stocks, + replace_all=overwrite_relations + ) + + stats["gnbk"]["tag_count"] += 1 + stats["gnbk"]["relation_count"] += len(valid_stocks) + except Exception as e: + err_msg = f"初始化概念标签'{tag_name}'失败: {str(e)}" + logger.error(err_msg, exc_info=True) + stats["errors"].append(err_msg) + + except Exception as e: + err_msg = f"从 gp_gnbk 初始化概念标签失败: {str(e)}" + logger.error(err_msg, exc_info=True) + stats["errors"].append(err_msg) + stats["success"] = False + + return stats + + except Exception as e: + err_msg = f"初始化标签失败: {str(e)}" + logger.error(err_msg, exc_info=True) + stats["success"] = False + stats["errors"].append(err_msg) + return stats + def process_tag( self, tag_name: str, @@ -952,6 +1111,17 @@ class TagRelationAPI: "error": str(e) } + def clear_all_tags_and_relations(self) -> Dict[str, Any]: + """清空 MongoDB 中行业/概念标签及其中间关系(不操作 Redis 队列)""" + try: + return self.service.database.clear_all_tags_and_relations() + except Exception as e: + logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True) + return { + "success": False, + "error": str(e) + } + def close(self): """关闭所有连接""" try: diff --git a/src/stock_tag_analysis/tag_relation_database.py b/src/stock_tag_analysis/tag_relation_database.py index 64fe06a..e460bae 100644 --- a/src/stock_tag_analysis/tag_relation_database.py +++ b/src/stock_tag_analysis/tag_relation_database.py @@ -5,6 +5,7 @@ import logging import pymongo +from pymongo import ReturnDocument from typing import Dict, Any, Optional, List from datetime import datetime @@ -47,6 +48,7 @@ class TagRelationDatabase: # 标签集合和关联集合 self.tags_collection = None self.tag_stock_relations_collection = None + self.counters_collection = None self.connect_mongodb() @@ -66,6 +68,10 @@ class TagRelationDatabase: # 初始化标签集合和关联集合 self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合 self.tag_stock_relations_collection = self.db['tag_stock_relations'] + self.counters_collection = self.db['counters'] + + # 确保索引(唯一约束) + self._ensure_indexes() # 测试连接 self.mongo_client.admin.command('ping') @@ -75,6 +81,40 @@ class TagRelationDatabase: logger.error(f"MongoDB连接失败: {str(e)}") raise + def _ensure_indexes(self): + """为标签集合创建唯一索引,防止 tag_name / tag_code 冲突""" + try: + # tag_name 唯一 + self.tags_collection.create_index( + [('tag_name', pymongo.ASCENDING)], + unique=True, + name='idx_tag_name_unique' + ) + # tag_code 唯一 + self.tags_collection.create_index( + [('tag_code', pymongo.ASCENDING)], + unique=True, + name='idx_tag_code_unique' + ) + except Exception as e: + logger.warning(f"创建索引时发生异常: {str(e)}") + + def _get_next_tag_code(self) -> str: + """使用 counters 集合原子递增生成 tag_code""" + try: + result = self.counters_collection.find_one_and_update( + {'_id': 'tag_code'}, + {'$inc': {'seq': 1}}, + upsert=True, + return_document=ReturnDocument.AFTER + ) + seq = result.get('seq', 1) + return str(seq) + except Exception as e: + logger.error(f"生成tag_code失败: {str(e)}", exc_info=True) + # 兜底:使用时间戳 + return str(int(datetime.now().timestamp() * 1000)) + def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool: """保存分析结果到MongoDB @@ -200,37 +240,47 @@ class TagRelationDatabase: """ try: # 检查标签是否已存在(根据tag_name) - existing_tag = self.tags_collection.find_one({'tag_name': tag_data.get('tag_name')}) + tag_name = tag_data.get('tag_name') + if not tag_name: + logger.warning("标签名称为空,无法保存") + return None + + existing_tag = self.tags_collection.find_one({'tag_name': tag_name}) if existing_tag: - # 更新现有标签 tag_code = existing_tag.get('tag_code') + # 保持原 tag_code,不覆盖 create_time + update_data = dict(tag_data) + update_data['tag_code'] = tag_code + # 避免与 create_time 冲突 + update_data.pop('create_time', None) self.tags_collection.update_one( - {'tag_name': tag_data.get('tag_name')}, - {'$set': tag_data} + {'tag_name': tag_name}, + {'$set': update_data} ) - logger.info(f"更新标签: {tag_data.get('tag_name')}, tag_code: {tag_code}") + logger.info(f"更新标签: {tag_name}, tag_code: {tag_code}") return tag_code else: - # 插入新标签 - # 如果没有tag_code,生成一个 - if 'tag_code' not in tag_data or not tag_data.get('tag_code'): - # 生成tag_code:使用当前最大ID+1,或者使用时间戳 - max_tag = self.tags_collection.find_one(sort=[("tag_code", -1)]) - if max_tag and max_tag.get('tag_code'): - try: - max_id = int(max_tag['tag_code']) - tag_code = str(max_id + 1) - except: - tag_code = str(int(datetime.now().timestamp() * 1000)) - else: - tag_code = "1" - tag_data['tag_code'] = tag_code + # 生成新的 tag_code(原子递增) + tag_code = tag_data.get('tag_code') or self._get_next_tag_code() + tag_data = dict(tag_data) + # 新增时通过 $setOnInsert 写入 tag_code,避免冲突 + tag_data.pop('tag_code', None) + create_time = tag_data.pop('create_time', datetime.now().isoformat()) - tag_data['create_time'] = datetime.now().isoformat() - result = self.tags_collection.insert_one(tag_data) - logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}") - return tag_data.get('tag_code') + self.tags_collection.update_one( + {'tag_name': tag_name}, + { + '$set': tag_data, + '$setOnInsert': { + 'create_time': create_time, + 'tag_code': tag_code + } + }, + upsert=True + ) + logger.info(f"插入新标签: {tag_name}, tag_code: {tag_code}") + return tag_code except Exception as e: logger.error(f"保存标签失败: {str(e)}", exc_info=True) @@ -437,6 +487,41 @@ class TagRelationDatabase: result['error'] = str(e) return result + def clear_all_tags_and_relations(self) -> Dict[str, Any]: + """清空行业/概念标签及其与个股的中间关系(不动其他标签、不动分析结果/Redis) + + Returns: + Dict[str, Any]: 删除统计 + """ + result = { + 'success': False, + 'deleted_tags': 0, + 'deleted_relations': 0, + 'error': None + } + try: + # 仅删除行业/概念标签 + tag_filter = { + '$or': [ + {'source': {'$in': ['hybk', 'gnbk']}}, + {'tag_type': {'$in': ['行业标签', '概念标签']}} + ] + } + tags = list(self.tags_collection.find(tag_filter, {'tag_code': 1})) + tag_codes = [t.get('tag_code') for t in tags if t.get('tag_code')] + + tag_res = self.tags_collection.delete_many(tag_filter) + rel_res = self.tag_stock_relations_collection.delete_many({'tag_code': {'$in': tag_codes}}) if tag_codes else type("Dummy", (), {"deleted_count": 0})() + + result['deleted_tags'] = tag_res.deleted_count + result['deleted_relations'] = rel_res.deleted_count if hasattr(rel_res, 'deleted_count') else 0 + result['success'] = True + return result + except Exception as e: + logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True) + result['error'] = str(e) + return result + def close_connection(self): """关闭数据库连接""" try: