Merge remote-tracking branch 'origin/master'

This commit is contained in:
liao 2026-01-14 14:13:06 +08:00
commit cfc06ff720
3 changed files with 356 additions and 34 deletions

View File

@ -505,15 +505,9 @@ def run_chip_distribution_collection():
# tushare_token=TUSHARE_TOKEN,
# mode='full'
# )
collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
start_date='2020-01-02', end_date='2020-12-31')
# collect_chip_distribution(
# db_url=db_url,
# tushare_token=TUSHARE_TOKEN,
# mode='daily', # 每日增量采集
# date=today,
# batch_size=100 # 每100只股票批量入库一次
# )
# collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
# start_date='2025-12-10', end_date='2025-12-30')
collect_chip_distribution(db_url, TUSHARE_TOKEN, mode='daily')
logger.info("每日筹码分布数据采集完成Tushare")
return jsonify({
@ -3834,8 +3828,12 @@ def run_tech_fundamental_strategy_batch():
start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date')
end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date')
if not start_date or not end_date:
return jsonify({"status": "error", "message": "缺少参数: start_date 和 end_date"}), 400
# 如果未传入日期,默认使用当天日期
today = datetime.now().strftime('%Y-%m-%d')
if not start_date:
start_date = today
if not end_date:
end_date = today
# 验证日期
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
@ -4147,6 +4145,75 @@ def delete_tag_relation():
}), 500
@app.route('/api/tag/clear_tags', methods=['POST'])
def clear_tags():
"""清空 MongoDB 中的行业/概念标签及其标签-股票中间关系不动Redis队列、不删分析结果、保留其他标签类型"""
try:
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
result = tag_relation_api.clear_all_tags_and_relations()
if result.get("success"):
return jsonify({
"status": "success",
"data": result
})
else:
return jsonify({
"status": "error",
"message": result.get("error", "清空失败"),
"data": result
}), 500
except Exception as e:
logger.error(f"清空标签失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"清空标签失败: {str(e)}"
}), 500
@app.route('/api/tag/init_from_mysql', methods=['POST'])
def init_tags_from_mysql():
"""从MySQL的行业/概念表初始化标签及其与个股的关系到MongoDB无参数版"""
try:
# 检查API是否初始化成功
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
# 直接全量初始化:行业+概念,覆盖已有关联
result = tag_relation_api.init_tags_from_mysql(
include_hybk=True,
include_gnbk=True,
overwrite_relations=True
)
if result.get('success'):
return jsonify({
"status": "success",
"data": result
})
else:
return jsonify({
"status": "error",
"message": "初始化存在错误请查看errors字段",
"data": result
}), 500
except Exception as e:
logger.error(f"初始化标签失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"初始化标签失败: {str(e)}"
}), 500
if __name__ == '__main__':
# 启动Web服务器

View File

@ -373,6 +373,165 @@ class TagRelationAPI:
logger.error(f"查询gp_gnbk失败: {str(e)}")
return []
def init_tags_from_mysql(
self,
include_hybk: bool = True,
include_gnbk: bool = True,
overwrite_relations: bool = True
) -> Dict[str, Any]:
"""从MySQL的行业/概念表初始化标签及其与个股的关联关系到MongoDB
Args:
include_hybk: 是否初始化行业标签gp_hybk
include_gnbk: 是否初始化概念标签gp_gnbk
overwrite_relations: 是否覆盖已有的标签-个股关联关系
Returns:
Dict[str, Any]: 初始化统计结果
"""
stats = {
"success": True,
"hybk": {
"tag_count": 0,
"relation_count": 0
},
"gnbk": {
"tag_count": 0,
"relation_count": 0
},
"errors": []
}
try:
if include_hybk:
try:
logger.info("开始从 gp_hybk 初始化行业标签及关联关系")
sql = """
SELECT bk_name, gp_code
FROM gp_hybk
WHERE bk_name IS NOT NULL
AND gp_code IS NOT NULL
"""
with self.mysql_engine.connect() as conn:
result = conn.execute(text(sql))
rows = result.fetchall()
# 聚合为 {tag_name: set(gp_code)}
tag_stocks_map: Dict[str, set] = {}
for bk_name, gp_code in rows:
if not bk_name or not gp_code:
continue
tag_stocks_map.setdefault(bk_name, set()).add(gp_code)
for tag_name, stock_codes in tag_stocks_map.items():
try:
# 过滤有效股票(转为 gp_code_two
valid_stocks = self._filter_valid_stocks(list(stock_codes))
if not valid_stocks:
logger.info(f"行业标签'{tag_name}'过滤后无有效股票,跳过")
continue
tag_data = {
'tag_name': tag_name,
'tag_type': '行业标签',
'status': 'completed',
'progress': 100.0,
'source': 'hybk'
}
tag_code = self.service.database.save_tag(tag_data)
if not tag_code:
logger.warning(f"行业标签'{tag_name}'保存失败,跳过")
continue
# 保存关联关系
self.service.database.save_tag_stock_relations(
tag_code,
valid_stocks,
replace_all=overwrite_relations
)
stats["hybk"]["tag_count"] += 1
stats["hybk"]["relation_count"] += len(valid_stocks)
except Exception as e:
err_msg = f"初始化行业标签'{tag_name}'失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
except Exception as e:
err_msg = f"从 gp_hybk 初始化行业标签失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
stats["success"] = False
if include_gnbk:
try:
logger.info("开始从 gp_gnbk 初始化概念标签及关联关系")
sql = """
SELECT bk_name, gp_code
FROM gp_gnbk
WHERE bk_name IS NOT NULL
AND gp_code IS NOT NULL
"""
with self.mysql_engine.connect() as conn:
result = conn.execute(text(sql))
rows = result.fetchall()
# 聚合为 {tag_name: set(gp_code)}
tag_stocks_map: Dict[str, set] = {}
for bk_name, gp_code in rows:
if not bk_name or not gp_code:
continue
tag_stocks_map.setdefault(bk_name, set()).add(gp_code)
for tag_name, stock_codes in tag_stocks_map.items():
try:
# 过滤有效股票(转为 gp_code_two
valid_stocks = self._filter_valid_stocks(list(stock_codes))
if not valid_stocks:
logger.info(f"概念标签'{tag_name}'过滤后无有效股票,跳过")
continue
tag_data = {
'tag_name': tag_name,
'tag_type': '概念标签',
'status': 'completed',
'progress': 100.0,
'source': 'gnbk'
}
tag_code = self.service.database.save_tag(tag_data)
if not tag_code:
logger.warning(f"概念标签'{tag_name}'保存失败,跳过")
continue
# 保存关联关系
self.service.database.save_tag_stock_relations(
tag_code,
valid_stocks,
replace_all=overwrite_relations
)
stats["gnbk"]["tag_count"] += 1
stats["gnbk"]["relation_count"] += len(valid_stocks)
except Exception as e:
err_msg = f"初始化概念标签'{tag_name}'失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
except Exception as e:
err_msg = f"从 gp_gnbk 初始化概念标签失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["errors"].append(err_msg)
stats["success"] = False
return stats
except Exception as e:
err_msg = f"初始化标签失败: {str(e)}"
logger.error(err_msg, exc_info=True)
stats["success"] = False
stats["errors"].append(err_msg)
return stats
def process_tag(
self,
tag_name: str,
@ -952,6 +1111,17 @@ class TagRelationAPI:
"error": str(e)
}
def clear_all_tags_and_relations(self) -> Dict[str, Any]:
"""清空 MongoDB 中行业/概念标签及其中间关系(不操作 Redis 队列)"""
try:
return self.service.database.clear_all_tags_and_relations()
except Exception as e:
logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True)
return {
"success": False,
"error": str(e)
}
def close(self):
"""关闭所有连接"""
try:

View File

@ -5,6 +5,7 @@
import logging
import pymongo
from pymongo import ReturnDocument
from typing import Dict, Any, Optional, List
from datetime import datetime
@ -47,6 +48,7 @@ class TagRelationDatabase:
# 标签集合和关联集合
self.tags_collection = None
self.tag_stock_relations_collection = None
self.counters_collection = None
self.connect_mongodb()
@ -66,6 +68,10 @@ class TagRelationDatabase:
# 初始化标签集合和关联集合
self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合
self.tag_stock_relations_collection = self.db['tag_stock_relations']
self.counters_collection = self.db['counters']
# 确保索引(唯一约束)
self._ensure_indexes()
# 测试连接
self.mongo_client.admin.command('ping')
@ -75,6 +81,40 @@ class TagRelationDatabase:
logger.error(f"MongoDB连接失败: {str(e)}")
raise
def _ensure_indexes(self):
"""为标签集合创建唯一索引,防止 tag_name / tag_code 冲突"""
try:
# tag_name 唯一
self.tags_collection.create_index(
[('tag_name', pymongo.ASCENDING)],
unique=True,
name='idx_tag_name_unique'
)
# tag_code 唯一
self.tags_collection.create_index(
[('tag_code', pymongo.ASCENDING)],
unique=True,
name='idx_tag_code_unique'
)
except Exception as e:
logger.warning(f"创建索引时发生异常: {str(e)}")
def _get_next_tag_code(self) -> str:
"""使用 counters 集合原子递增生成 tag_code"""
try:
result = self.counters_collection.find_one_and_update(
{'_id': 'tag_code'},
{'$inc': {'seq': 1}},
upsert=True,
return_document=ReturnDocument.AFTER
)
seq = result.get('seq', 1)
return str(seq)
except Exception as e:
logger.error(f"生成tag_code失败: {str(e)}", exc_info=True)
# 兜底:使用时间戳
return str(int(datetime.now().timestamp() * 1000))
def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool:
"""保存分析结果到MongoDB
@ -200,37 +240,47 @@ class TagRelationDatabase:
"""
try:
# 检查标签是否已存在根据tag_name
existing_tag = self.tags_collection.find_one({'tag_name': tag_data.get('tag_name')})
tag_name = tag_data.get('tag_name')
if not tag_name:
logger.warning("标签名称为空,无法保存")
return None
existing_tag = self.tags_collection.find_one({'tag_name': tag_name})
if existing_tag:
# 更新现有标签
tag_code = existing_tag.get('tag_code')
# 保持原 tag_code不覆盖 create_time
update_data = dict(tag_data)
update_data['tag_code'] = tag_code
# 避免与 create_time 冲突
update_data.pop('create_time', None)
self.tags_collection.update_one(
{'tag_name': tag_data.get('tag_name')},
{'$set': tag_data}
{'tag_name': tag_name},
{'$set': update_data}
)
logger.info(f"更新标签: {tag_data.get('tag_name')}, tag_code: {tag_code}")
logger.info(f"更新标签: {tag_name}, tag_code: {tag_code}")
return tag_code
else:
# 插入新标签
# 如果没有tag_code生成一个
if 'tag_code' not in tag_data or not tag_data.get('tag_code'):
# 生成tag_code使用当前最大ID+1或者使用时间戳
max_tag = self.tags_collection.find_one(sort=[("tag_code", -1)])
if max_tag and max_tag.get('tag_code'):
try:
max_id = int(max_tag['tag_code'])
tag_code = str(max_id + 1)
except:
tag_code = str(int(datetime.now().timestamp() * 1000))
else:
tag_code = "1"
tag_data['tag_code'] = tag_code
# 生成新的 tag_code原子递增
tag_code = tag_data.get('tag_code') or self._get_next_tag_code()
tag_data = dict(tag_data)
# 新增时通过 $setOnInsert 写入 tag_code避免冲突
tag_data.pop('tag_code', None)
create_time = tag_data.pop('create_time', datetime.now().isoformat())
tag_data['create_time'] = datetime.now().isoformat()
result = self.tags_collection.insert_one(tag_data)
logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}")
return tag_data.get('tag_code')
self.tags_collection.update_one(
{'tag_name': tag_name},
{
'$set': tag_data,
'$setOnInsert': {
'create_time': create_time,
'tag_code': tag_code
}
},
upsert=True
)
logger.info(f"插入新标签: {tag_name}, tag_code: {tag_code}")
return tag_code
except Exception as e:
logger.error(f"保存标签失败: {str(e)}", exc_info=True)
@ -437,6 +487,41 @@ class TagRelationDatabase:
result['error'] = str(e)
return result
def clear_all_tags_and_relations(self) -> Dict[str, Any]:
"""清空行业/概念标签及其与个股的中间关系(不动其他标签、不动分析结果/Redis
Returns:
Dict[str, Any]: 删除统计
"""
result = {
'success': False,
'deleted_tags': 0,
'deleted_relations': 0,
'error': None
}
try:
# 仅删除行业/概念标签
tag_filter = {
'$or': [
{'source': {'$in': ['hybk', 'gnbk']}},
{'tag_type': {'$in': ['行业标签', '概念标签']}}
]
}
tags = list(self.tags_collection.find(tag_filter, {'tag_code': 1}))
tag_codes = [t.get('tag_code') for t in tags if t.get('tag_code')]
tag_res = self.tags_collection.delete_many(tag_filter)
rel_res = self.tag_stock_relations_collection.delete_many({'tag_code': {'$in': tag_codes}}) if tag_codes else type("Dummy", (), {"deleted_count": 0})()
result['deleted_tags'] = tag_res.deleted_count
result['deleted_relations'] = rel_res.deleted_count if hasattr(rel_res, 'deleted_count') else 0
result['success'] = True
return result
except Exception as e:
logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True)
result['error'] = str(e)
return result
def close_connection(self):
"""关闭数据库连接"""
try: