This commit is contained in:
满脸小星星 2026-01-14 09:17:17 +08:00
parent 3aa280090e
commit bcfb3c7a26
3 changed files with 356 additions and 34 deletions

View File

@ -505,15 +505,9 @@ def run_chip_distribution_collection():
# tushare_token=TUSHARE_TOKEN, # tushare_token=TUSHARE_TOKEN,
# mode='full' # mode='full'
# ) # )
collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full', # collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
start_date='2020-01-02', end_date='2020-12-31') # start_date='2025-12-10', end_date='2025-12-30')
# collect_chip_distribution( collect_chip_distribution(db_url, TUSHARE_TOKEN, mode='daily')
# db_url=db_url,
# tushare_token=TUSHARE_TOKEN,
# mode='daily', # 每日增量采集
# date=today,
# batch_size=100 # 每100只股票批量入库一次
# )
logger.info("每日筹码分布数据采集完成Tushare") logger.info("每日筹码分布数据采集完成Tushare")
return jsonify({ return jsonify({
@ -3834,8 +3828,12 @@ def run_tech_fundamental_strategy_batch():
start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date') start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date')
end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date') end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date')
if not start_date or not end_date: # 如果未传入日期,默认使用当天日期
return jsonify({"status": "error", "message": "缺少参数: start_date 和 end_date"}), 400 today = datetime.now().strftime('%Y-%m-%d')
if not start_date:
start_date = today
if not end_date:
end_date = today
# 验证日期 # 验证日期
start_dt = datetime.strptime(start_date, '%Y-%m-%d') start_dt = datetime.strptime(start_date, '%Y-%m-%d')
@ -4130,6 +4128,75 @@ def delete_tag_relation():
}), 500 }), 500
@app.route('/api/tag/clear_tags', methods=['POST'])
def clear_tags():
"""清空 MongoDB 中的行业/概念标签及其标签-股票中间关系不动Redis队列、不删分析结果、保留其他标签类型"""
try:
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
result = tag_relation_api.clear_all_tags_and_relations()
if result.get("success"):
return jsonify({
"status": "success",
"data": result
})
else:
return jsonify({
"status": "error",
"message": result.get("error", "清空失败"),
"data": result
}), 500
except Exception as e:
logger.error(f"清空标签失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"清空标签失败: {str(e)}"
}), 500
@app.route('/api/tag/init_from_mysql', methods=['POST'])
def init_tags_from_mysql():
"""从MySQL的行业/概念表初始化标签及其与个股的关系到MongoDB无参数版"""
try:
# 检查API是否初始化成功
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
# 直接全量初始化:行业+概念,覆盖已有关联
result = tag_relation_api.init_tags_from_mysql(
include_hybk=True,
include_gnbk=True,
overwrite_relations=True
)
if result.get('success'):
return jsonify({
"status": "success",
"data": result
})
else:
return jsonify({
"status": "error",
"message": "初始化存在错误请查看errors字段",
"data": result
}), 500
except Exception as e:
logger.error(f"初始化标签失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"初始化标签失败: {str(e)}"
}), 500
if __name__ == '__main__': if __name__ == '__main__':
# 启动Web服务器 # 启动Web服务器

View File

@ -373,6 +373,165 @@ class TagRelationAPI:
logger.error(f"查询gp_gnbk失败: {str(e)}") logger.error(f"查询gp_gnbk失败: {str(e)}")
return [] return []
def init_tags_from_mysql(
self,
include_hybk: bool = True,
include_gnbk: bool = True,
overwrite_relations: bool = True
) -> Dict[str, Any]:
"""从MySQL的行业/概念表初始化标签及其与个股的关联关系到MongoDB
Args:
include_hybk: 是否初始化行业标签gp_hybk
include_gnbk: 是否初始化概念标签gp_gnbk
overwrite_relations: 是否覆盖已有的标签-个股关联关系
Returns:
Dict[str, Any]: 初始化统计结果
"""
stats = {
"success": True,
"hybk": {
"tag_count": 0,
"relation_count": 0
},
"gnbk": {
"tag_count": 0,
"relation_count": 0
},
"errors": []
}
try:
if include_hybk:
try:
logger.info("开始从 gp_hybk 初始化行业标签及关联关系")
sql = """
SELECT bk_name, gp_code
FROM gp_hybk
WHERE bk_name IS NOT NULL
AND gp_code IS NOT NULL
"""
with self.mysql_engine.connect() as conn:
result = conn.execute(text(sql))
rows = result.fetchall()
# 聚合为 {tag_name: set(gp_code)}
tag_stocks_map: Dict[str, set] = {}
for bk_name, gp_code in rows:
if not bk_name or not gp_code:
continue
tag_stocks_map.setdefault(bk_name, set()).add(gp_code)
for tag_name, stock_codes in tag_stocks_map.items():
try:
# 过滤有效股票(转为 gp_code_two
valid_stocks = self._filter_valid_stocks(list(stock_codes))
if not valid_stocks:
logger.info(f"行业标签'{tag_name}'过滤后无有效股票,跳过")
continue
tag_data = {
'tag_name': tag_name,
'tag_type': '行业标签',
'status': 'completed',
'progress': 100.0,
'source': 'hybk'
}
tag_code = self.service.database.save_tag(tag_data)
if not tag_code:
logger.warning(f"行业标签'{tag_name}'保存失败,跳过")
continue
# 保存关联关系
self.service.database.save_tag_stock_relations(
tag_code,
valid_stocks,
replace_all=overwrite_relations
)
stats["hybk"]["tag_count"] += 1
stats["hybk"]["relation_count"] += len(valid_stocks)
except Exception as e:
err_msg = f"初始化行业标签'{tag_name}'失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
except Exception as e:
err_msg = f"从 gp_hybk 初始化行业标签失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
stats["success"] = False
if include_gnbk:
try:
logger.info("开始从 gp_gnbk 初始化概念标签及关联关系")
sql = """
SELECT bk_name, gp_code
FROM gp_gnbk
WHERE bk_name IS NOT NULL
AND gp_code IS NOT NULL
"""
with self.mysql_engine.connect() as conn:
result = conn.execute(text(sql))
rows = result.fetchall()
# 聚合为 {tag_name: set(gp_code)}
tag_stocks_map: Dict[str, set] = {}
for bk_name, gp_code in rows:
if not bk_name or not gp_code:
continue
tag_stocks_map.setdefault(bk_name, set()).add(gp_code)
for tag_name, stock_codes in tag_stocks_map.items():
try:
# 过滤有效股票(转为 gp_code_two
valid_stocks = self._filter_valid_stocks(list(stock_codes))
if not valid_stocks:
logger.info(f"概念标签'{tag_name}'过滤后无有效股票,跳过")
continue
tag_data = {
'tag_name': tag_name,
'tag_type': '概念标签',
'status': 'completed',
'progress': 100.0,
'source': 'gnbk'
}
tag_code = self.service.database.save_tag(tag_data)
if not tag_code:
logger.warning(f"概念标签'{tag_name}'保存失败,跳过")
continue
# 保存关联关系
self.service.database.save_tag_stock_relations(
tag_code,
valid_stocks,
replace_all=overwrite_relations
)
stats["gnbk"]["tag_count"] += 1
stats["gnbk"]["relation_count"] += len(valid_stocks)
except Exception as e:
err_msg = f"初始化概念标签'{tag_name}'失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
except Exception as e:
err_msg = f"从 gp_gnbk 初始化概念标签失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
stats["success"] = False
return stats
except Exception as e:
err_msg = f"初始化标签失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["success"] = False
stats["errors"].append(err_msg)
return stats
def process_tag( def process_tag(
self, self,
tag_name: str, tag_name: str,
@ -952,6 +1111,17 @@ class TagRelationAPI:
"error": str(e) "error": str(e)
} }
def clear_all_tags_and_relations(self) -> Dict[str, Any]:
"""清空 MongoDB 中行业/概念标签及其中间关系(不操作 Redis 队列)"""
try:
return self.service.database.clear_all_tags_and_relations()
except Exception as e:
logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True)
return {
"success": False,
"error": str(e)
}
def close(self): def close(self):
"""关闭所有连接""" """关闭所有连接"""
try: try:

View File

@ -5,6 +5,7 @@
import logging import logging
import pymongo import pymongo
from pymongo import ReturnDocument
from typing import Dict, Any, Optional, List from typing import Dict, Any, Optional, List
from datetime import datetime from datetime import datetime
@ -47,6 +48,7 @@ class TagRelationDatabase:
# 标签集合和关联集合 # 标签集合和关联集合
self.tags_collection = None self.tags_collection = None
self.tag_stock_relations_collection = None self.tag_stock_relations_collection = None
self.counters_collection = None
self.connect_mongodb() self.connect_mongodb()
@ -66,6 +68,10 @@ class TagRelationDatabase:
# 初始化标签集合和关联集合 # 初始化标签集合和关联集合
self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合 self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合
self.tag_stock_relations_collection = self.db['tag_stock_relations'] self.tag_stock_relations_collection = self.db['tag_stock_relations']
self.counters_collection = self.db['counters']
# 确保索引(唯一约束)
self._ensure_indexes()
# 测试连接 # 测试连接
self.mongo_client.admin.command('ping') self.mongo_client.admin.command('ping')
@ -75,6 +81,40 @@ class TagRelationDatabase:
logger.error(f"MongoDB连接失败: {str(e)}") logger.error(f"MongoDB连接失败: {str(e)}")
raise raise
def _ensure_indexes(self):
"""为标签集合创建唯一索引,防止 tag_name / tag_code 冲突"""
try:
# tag_name 唯一
self.tags_collection.create_index(
[('tag_name', pymongo.ASCENDING)],
unique=True,
name='idx_tag_name_unique'
)
# tag_code 唯一
self.tags_collection.create_index(
[('tag_code', pymongo.ASCENDING)],
unique=True,
name='idx_tag_code_unique'
)
except Exception as e:
logger.warning(f"创建索引时发生异常: {str(e)}")
def _get_next_tag_code(self) -> str:
"""使用 counters 集合原子递增生成 tag_code"""
try:
result = self.counters_collection.find_one_and_update(
{'_id': 'tag_code'},
{'$inc': {'seq': 1}},
upsert=True,
return_document=ReturnDocument.AFTER
)
seq = result.get('seq', 1)
return str(seq)
except Exception as e:
logger.error(f"生成tag_code失败: {str(e)}", exc_info=True)
# 兜底:使用时间戳
return str(int(datetime.now().timestamp() * 1000))
def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool: def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool:
"""保存分析结果到MongoDB """保存分析结果到MongoDB
@ -200,37 +240,47 @@ class TagRelationDatabase:
""" """
try: try:
# 检查标签是否已存在根据tag_name # 检查标签是否已存在根据tag_name
existing_tag = self.tags_collection.find_one({'tag_name': tag_data.get('tag_name')}) tag_name = tag_data.get('tag_name')
if not tag_name:
logger.warning("标签名称为空,无法保存")
return None
existing_tag = self.tags_collection.find_one({'tag_name': tag_name})
if existing_tag: if existing_tag:
# 更新现有标签
tag_code = existing_tag.get('tag_code') tag_code = existing_tag.get('tag_code')
# 保持原 tag_code不覆盖 create_time
update_data = dict(tag_data)
update_data['tag_code'] = tag_code
# 避免与 create_time 冲突
update_data.pop('create_time', None)
self.tags_collection.update_one( self.tags_collection.update_one(
{'tag_name': tag_data.get('tag_name')}, {'tag_name': tag_name},
{'$set': tag_data} {'$set': update_data}
) )
logger.info(f"更新标签: {tag_data.get('tag_name')}, tag_code: {tag_code}") logger.info(f"更新标签: {tag_name}, tag_code: {tag_code}")
return tag_code return tag_code
else: else:
# 插入新标签 # 生成新的 tag_code原子递增
# 如果没有tag_code生成一个 tag_code = tag_data.get('tag_code') or self._get_next_tag_code()
if 'tag_code' not in tag_data or not tag_data.get('tag_code'): tag_data = dict(tag_data)
# 生成tag_code使用当前最大ID+1或者使用时间戳 # 新增时通过 $setOnInsert 写入 tag_code避免冲突
max_tag = self.tags_collection.find_one(sort=[("tag_code", -1)]) tag_data.pop('tag_code', None)
if max_tag and max_tag.get('tag_code'): create_time = tag_data.pop('create_time', datetime.now().isoformat())
try:
max_id = int(max_tag['tag_code'])
tag_code = str(max_id + 1)
except:
tag_code = str(int(datetime.now().timestamp() * 1000))
else:
tag_code = "1"
tag_data['tag_code'] = tag_code
tag_data['create_time'] = datetime.now().isoformat() self.tags_collection.update_one(
result = self.tags_collection.insert_one(tag_data) {'tag_name': tag_name},
logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}") {
return tag_data.get('tag_code') '$set': tag_data,
'$setOnInsert': {
'create_time': create_time,
'tag_code': tag_code
}
},
upsert=True
)
logger.info(f"插入新标签: {tag_name}, tag_code: {tag_code}")
return tag_code
except Exception as e: except Exception as e:
logger.error(f"保存标签失败: {str(e)}", exc_info=True) logger.error(f"保存标签失败: {str(e)}", exc_info=True)
@ -437,6 +487,41 @@ class TagRelationDatabase:
result['error'] = str(e) result['error'] = str(e)
return result return result
def clear_all_tags_and_relations(self) -> Dict[str, Any]:
"""清空行业/概念标签及其与个股的中间关系(不动其他标签、不动分析结果/Redis
Returns:
Dict[str, Any]: 删除统计
"""
result = {
'success': False,
'deleted_tags': 0,
'deleted_relations': 0,
'error': None
}
try:
# 仅删除行业/概念标签
tag_filter = {
'$or': [
{'source': {'$in': ['hybk', 'gnbk']}},
{'tag_type': {'$in': ['行业标签', '概念标签']}}
]
}
tags = list(self.tags_collection.find(tag_filter, {'tag_code': 1}))
tag_codes = [t.get('tag_code') for t in tags if t.get('tag_code')]
tag_res = self.tags_collection.delete_many(tag_filter)
rel_res = self.tag_stock_relations_collection.delete_many({'tag_code': {'$in': tag_codes}}) if tag_codes else type("Dummy", (), {"deleted_count": 0})()
result['deleted_tags'] = tag_res.deleted_count
result['deleted_relations'] = rel_res.deleted_count if hasattr(rel_res, 'deleted_count') else 0
result['success'] = True
return result
except Exception as e:
logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True)
result['error'] = str(e)
return result
def close_connection(self): def close_connection(self):
"""关闭数据库连接""" """关闭数据库连接"""
try: try: