commit;
This commit is contained in:
parent
3aa280090e
commit
bcfb3c7a26
89
src/app.py
89
src/app.py
|
|
@ -505,15 +505,9 @@ def run_chip_distribution_collection():
|
|||
# tushare_token=TUSHARE_TOKEN,
|
||||
# mode='full'
|
||||
# )
|
||||
collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
|
||||
start_date='2020-01-02', end_date='2020-12-31')
|
||||
# collect_chip_distribution(
|
||||
# db_url=db_url,
|
||||
# tushare_token=TUSHARE_TOKEN,
|
||||
# mode='daily', # 每日增量采集
|
||||
# date=today,
|
||||
# batch_size=100 # 每100只股票批量入库一次
|
||||
# )
|
||||
# collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
|
||||
# start_date='2025-12-10', end_date='2025-12-30')
|
||||
collect_chip_distribution(db_url, TUSHARE_TOKEN, mode='daily')
|
||||
|
||||
logger.info("每日筹码分布数据采集完成(Tushare)")
|
||||
return jsonify({
|
||||
|
|
@ -3834,8 +3828,12 @@ def run_tech_fundamental_strategy_batch():
|
|||
start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date')
|
||||
end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date')
|
||||
|
||||
if not start_date or not end_date:
|
||||
return jsonify({"status": "error", "message": "缺少参数: start_date 和 end_date"}), 400
|
||||
# 如果未传入日期,默认使用当天日期
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
if not start_date:
|
||||
start_date = today
|
||||
if not end_date:
|
||||
end_date = today
|
||||
|
||||
# 验证日期
|
||||
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
|
||||
|
|
@ -4130,6 +4128,75 @@ def delete_tag_relation():
|
|||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/tag/clear_tags', methods=['POST'])
def clear_tags():
    """Clear the industry/concept tags and their tag-stock relations.

    Delegates to ``tag_relation_api.clear_all_tags_and_relations()`` and maps
    its result onto a JSON response.

    Returns:
        HTTP 200 with ``{"status": "success", "data": ...}`` when the call
        reports success. HTTP 500 with an error payload when the tag relation
        API is not initialised, when the call reports failure, or when an
        exception is raised.
    """
    try:
        # The API object may be None if its initialisation failed at startup.
        if tag_relation_api is None:
            return jsonify({
                "status": "error",
                "message": "标签关联分析API未初始化"
            }), 500

        outcome = tag_relation_api.clear_all_tags_and_relations()

        # Guard clause: report a failed clear and pass the raw result along.
        if not outcome.get("success"):
            return jsonify({
                "status": "error",
                "message": outcome.get("error", "清空失败"),
                "data": outcome
            }), 500

        return jsonify({
            "status": "success",
            "data": outcome
        })
    except Exception as exc:
        logger.error(f"清空标签失败: {exc}", exc_info=True)
        return jsonify({
            "status": "error",
            "message": f"清空标签失败: {exc}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/tag/init_from_mysql', methods=['POST'])
def init_tags_from_mysql():
    """Initialise tags and their stock relations from the MySQL tables.

    No request parameters are read: the endpoint always asks
    ``tag_relation_api.init_tags_from_mysql`` for a full initialisation of
    both industry and concept tags, overwriting existing relations.

    Returns:
        HTTP 200 with ``{"status": "success", "data": ...}`` when the call
        reports success. HTTP 500 with an error payload when the tag relation
        API is not initialised, when the result is not successful (details are
        in the ``errors`` field of ``data``), or when an exception is raised.
    """
    try:
        # The API object may be None if its initialisation failed at startup.
        if tag_relation_api is None:
            return jsonify({
                "status": "error",
                "message": "标签关联分析API未初始化"
            }), 500

        # Full initialisation: industry + concept, replacing existing relations.
        init_options = {
            "include_hybk": True,
            "include_gnbk": True,
            "overwrite_relations": True,
        }
        outcome = tag_relation_api.init_tags_from_mysql(**init_options)

        if not outcome.get('success'):
            return jsonify({
                "status": "error",
                "message": "初始化存在错误,请查看errors字段",
                "data": outcome
            }), 500

        return jsonify({
            "status": "success",
            "data": outcome
        })

    except Exception as exc:
        logger.error(f"初始化标签失败: {exc}", exc_info=True)
        return jsonify({
            "status": "error",
            "message": f"初始化标签失败: {exc}"
        }), 500
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# 启动Web服务器
|
||||
|
|
|
|||
|
|
@ -373,6 +373,165 @@ class TagRelationAPI:
|
|||
logger.error(f"查询gp_gnbk失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def init_tags_from_mysql(
    self,
    include_hybk: bool = True,
    include_gnbk: bool = True,
    overwrite_relations: bool = True
) -> Dict[str, Any]:
    """Initialise tags and tag-stock relations in MongoDB from MySQL.

    Industry tags are read from ``gp_hybk`` and concept tags from ``gp_gnbk``.
    Both sources go through the same pipeline (see ``_init_tags_from_table``);
    a failure in one source is recorded and does not stop the other.

    Args:
        include_hybk: Whether to initialise industry tags (``gp_hybk``).
        include_gnbk: Whether to initialise concept tags (``gp_gnbk``).
        overwrite_relations: Forwarded as ``replace_all`` to
            ``save_tag_stock_relations`` for every tag.

    Returns:
        A statistics dict with keys ``success``, ``hybk`` and ``gnbk`` (each
        holding ``tag_count`` / ``relation_count``) and ``errors``.
        ``success`` becomes False only when a whole source fails; a failure
        on a single tag is appended to ``errors`` but leaves ``success``
        unchanged.
    """
    stats: Dict[str, Any] = {
        "success": True,
        "hybk": {
            "tag_count": 0,
            "relation_count": 0
        },
        "gnbk": {
            "tag_count": 0,
            "relation_count": 0
        },
        "errors": []
    }

    # (stats key and stored `source` value, MySQL table, stored `tag_type`)
    sources = []
    if include_hybk:
        sources.append(("hybk", "gp_hybk", "行业标签"))
    if include_gnbk:
        sources.append(("gnbk", "gp_gnbk", "概念标签"))

    for source, table, tag_type in sources:
        try:
            self._init_tags_from_table(
                source, table, tag_type, overwrite_relations, stats
            )
        except Exception as e:
            # Source-level failure (e.g. the query itself): mark the whole
            # run as unsuccessful but still attempt the remaining source.
            err_msg = f"从 {table} 初始化{tag_type}失败: {str(e)}"
            logger.error(err_msg, exc_info=True)
            stats["errors"].append(err_msg)
            stats["success"] = False

    return stats


def _init_tags_from_table(
    self,
    source: str,
    table: str,
    tag_type: str,
    overwrite_relations: bool,
    stats: Dict[str, Any]
) -> None:
    """Load one MySQL board table and save its tags and relations.

    Args:
        source: Key of the counters in ``stats``; also stored as the tag's
            ``source`` field.
        table: MySQL table to read ``bk_name`` / ``gp_code`` pairs from.
        tag_type: Value stored in the tag's ``tag_type`` field and used in
            log messages.
        overwrite_relations: Forwarded as ``replace_all`` when saving
            relations.
        stats: Statistics dict updated in place (counters and ``errors``).

    Raises:
        Exception: Anything raised while querying MySQL propagates to the
            caller. Per-tag failures are caught and recorded in ``stats``.
    """
    logger.info(f"开始从 {table} 初始化{tag_type}及关联关系")

    # `table` only ever comes from the fixed tuples in init_tags_from_mysql,
    # never from external input, so interpolating it here is safe.
    sql = f"""
        SELECT bk_name, gp_code
        FROM {table}
        WHERE bk_name IS NOT NULL
          AND gp_code IS NOT NULL
    """
    with self.mysql_engine.connect() as conn:
        rows = conn.execute(text(sql)).fetchall()

    # Aggregate into {tag_name: set(gp_code)}; the set removes duplicates.
    tag_stocks_map: Dict[str, set] = {}
    for bk_name, gp_code in rows:
        # The SQL filters NULLs; this additionally drops empty strings.
        if not bk_name or not gp_code:
            continue
        tag_stocks_map.setdefault(bk_name, set()).add(gp_code)

    for tag_name, stock_codes in tag_stocks_map.items():
        try:
            # Keep only the stocks accepted by the class's own filter.
            valid_stocks = self._filter_valid_stocks(list(stock_codes))
            if not valid_stocks:
                logger.info(f"{tag_type}'{tag_name}'过滤后无有效股票,跳过")
                continue

            tag_data = {
                'tag_name': tag_name,
                'tag_type': tag_type,
                'status': 'completed',
                'progress': 100.0,
                'source': source
            }
            tag_code = self.service.database.save_tag(tag_data)
            if not tag_code:
                logger.warning(f"{tag_type}'{tag_name}'保存失败,跳过")
                continue

            self.service.database.save_tag_stock_relations(
                tag_code,
                valid_stocks,
                replace_all=overwrite_relations
            )

            stats[source]["tag_count"] += 1
            stats[source]["relation_count"] += len(valid_stocks)
        except Exception as e:
            # One bad tag must not abort the rest of the table.
            err_msg = f"初始化{tag_type}'{tag_name}'失败: {str(e)}"
            logger.error(err_msg, exc_info=True)
            stats["errors"].append(err_msg)
|
||||
|
||||
def process_tag(
|
||||
self,
|
||||
tag_name: str,
|
||||
|
|
@ -952,6 +1111,17 @@ class TagRelationAPI:
|
|||
"error": str(e)
|
||||
}
|
||||
|
||||
def clear_all_tags_and_relations(self) -> Dict[str, Any]:
    """Clear tags and their relations via the database layer.

    Thin wrapper around
    ``self.service.database.clear_all_tags_and_relations()``.

    Returns:
        Whatever the database layer returns. If the call raises, the error is
        logged and ``{"success": False, "error": <message>}`` is returned
        instead of propagating the exception.
    """
    try:
        database = self.service.database
        return database.clear_all_tags_and_relations()
    except Exception as exc:
        logger.error(f"清空标签及关系失败: {exc}", exc_info=True)
        failure: Dict[str, Any] = {"success": False}
        failure["error"] = str(exc)
        return failure
|
||||
|
||||
def close(self):
|
||||
"""关闭所有连接"""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import logging
|
||||
import pymongo
|
||||
from pymongo import ReturnDocument
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -47,6 +48,7 @@ class TagRelationDatabase:
|
|||
# 标签集合和关联集合
|
||||
self.tags_collection = None
|
||||
self.tag_stock_relations_collection = None
|
||||
self.counters_collection = None
|
||||
|
||||
self.connect_mongodb()
|
||||
|
||||
|
|
@ -66,6 +68,10 @@ class TagRelationDatabase:
|
|||
# 初始化标签集合和关联集合
|
||||
self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合
|
||||
self.tag_stock_relations_collection = self.db['tag_stock_relations']
|
||||
self.counters_collection = self.db['counters']
|
||||
|
||||
# 确保索引(唯一约束)
|
||||
self._ensure_indexes()
|
||||
|
||||
# 测试连接
|
||||
self.mongo_client.admin.command('ping')
|
||||
|
|
@ -75,6 +81,40 @@ class TagRelationDatabase:
|
|||
logger.error(f"MongoDB连接失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def _ensure_indexes(self):
    """Create unique indexes on the tag collection (best effort).

    Unique indexes are requested on ``tag_name`` and ``tag_code`` of
    ``self.tags_collection``. Each index is created under its own handler so
    that a failure on one does not prevent the other from being attempted.
    Failures are logged as warnings and never raised.
    """
    unique_indexes = (
        ('tag_name', 'idx_tag_name_unique'),
        ('tag_code', 'idx_tag_code_unique'),
    )
    for field, index_name in unique_indexes:
        try:
            self.tags_collection.create_index(
                [(field, pymongo.ASCENDING)],
                unique=True,
                name=index_name
            )
        except Exception as e:
            # Deliberately non-fatal: the caller continues without this index.
            logger.warning(f"创建索引时发生异常: {str(e)}")
|
||||
|
||||
def _get_next_tag_code(self) -> str:
    """Return the next tag code from the counter document.

    Increments the ``seq`` field of the ``{'_id': 'tag_code'}`` document in
    ``self.counters_collection`` (created on first use via upsert) and
    returns the post-increment value as a string.

    Returns:
        The new sequence number as a string. If anything in the lookup fails,
        the error is logged and the current time in milliseconds is returned
        instead.
    """
    # NOTE(review): the counter starts from 1 regardless of tag codes that
    # may already exist in the tag collection - confirm this cannot collide.
    counter_key = {'_id': 'tag_code'}
    increment = {'$inc': {'seq': 1}}
    try:
        counter_doc = self.counters_collection.find_one_and_update(
            counter_key,
            increment,
            upsert=True,
            return_document=ReturnDocument.AFTER
        )
        return str(counter_doc.get('seq', 1))
    except Exception as exc:
        logger.error(f"生成tag_code失败: {exc}", exc_info=True)
        # Fallback: millisecond timestamp.
        millis = int(datetime.now().timestamp() * 1000)
        return str(millis)
|
||||
|
||||
def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool:
|
||||
"""保存分析结果到MongoDB
|
||||
|
||||
|
|
@ -200,37 +240,47 @@ class TagRelationDatabase:
|
|||
"""
|
||||
try:
|
||||
# 检查标签是否已存在(根据tag_name)
|
||||
existing_tag = self.tags_collection.find_one({'tag_name': tag_data.get('tag_name')})
|
||||
tag_name = tag_data.get('tag_name')
|
||||
if not tag_name:
|
||||
logger.warning("标签名称为空,无法保存")
|
||||
return None
|
||||
|
||||
existing_tag = self.tags_collection.find_one({'tag_name': tag_name})
|
||||
|
||||
if existing_tag:
|
||||
# 更新现有标签
|
||||
tag_code = existing_tag.get('tag_code')
|
||||
# 保持原 tag_code,不覆盖 create_time
|
||||
update_data = dict(tag_data)
|
||||
update_data['tag_code'] = tag_code
|
||||
# 避免与 create_time 冲突
|
||||
update_data.pop('create_time', None)
|
||||
self.tags_collection.update_one(
|
||||
{'tag_name': tag_data.get('tag_name')},
|
||||
{'$set': tag_data}
|
||||
{'tag_name': tag_name},
|
||||
{'$set': update_data}
|
||||
)
|
||||
logger.info(f"更新标签: {tag_data.get('tag_name')}, tag_code: {tag_code}")
|
||||
logger.info(f"更新标签: {tag_name}, tag_code: {tag_code}")
|
||||
return tag_code
|
||||
else:
|
||||
# 插入新标签
|
||||
# 如果没有tag_code,生成一个
|
||||
if 'tag_code' not in tag_data or not tag_data.get('tag_code'):
|
||||
# 生成tag_code:使用当前最大ID+1,或者使用时间戳
|
||||
max_tag = self.tags_collection.find_one(sort=[("tag_code", -1)])
|
||||
if max_tag and max_tag.get('tag_code'):
|
||||
try:
|
||||
max_id = int(max_tag['tag_code'])
|
||||
tag_code = str(max_id + 1)
|
||||
except:
|
||||
tag_code = str(int(datetime.now().timestamp() * 1000))
|
||||
else:
|
||||
tag_code = "1"
|
||||
tag_data['tag_code'] = tag_code
|
||||
# 生成新的 tag_code(原子递增)
|
||||
tag_code = tag_data.get('tag_code') or self._get_next_tag_code()
|
||||
tag_data = dict(tag_data)
|
||||
# 新增时通过 $setOnInsert 写入 tag_code,避免冲突
|
||||
tag_data.pop('tag_code', None)
|
||||
create_time = tag_data.pop('create_time', datetime.now().isoformat())
|
||||
|
||||
tag_data['create_time'] = datetime.now().isoformat()
|
||||
result = self.tags_collection.insert_one(tag_data)
|
||||
logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}")
|
||||
return tag_data.get('tag_code')
|
||||
self.tags_collection.update_one(
|
||||
{'tag_name': tag_name},
|
||||
{
|
||||
'$set': tag_data,
|
||||
'$setOnInsert': {
|
||||
'create_time': create_time,
|
||||
'tag_code': tag_code
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
logger.info(f"插入新标签: {tag_name}, tag_code: {tag_code}")
|
||||
return tag_code
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存标签失败: {str(e)}", exc_info=True)
|
||||
|
|
@ -437,6 +487,41 @@ class TagRelationDatabase:
|
|||
result['error'] = str(e)
|
||||
return result
|
||||
|
||||
def clear_all_tags_and_relations(self) -> Dict[str, Any]:
    """Delete industry/concept tags and their tag-stock relations.

    A tag is selected when its ``source`` is ``hybk``/``gnbk`` or its
    ``tag_type`` is one of the industry/concept labels. Only relations whose
    ``tag_code`` belongs to a selected tag are removed.

    Relations are deleted before the tags. If the run fails part-way, the
    tags (and therefore their codes) still exist, so a retry can find and
    remove the remaining relations instead of leaving them orphaned.

    Returns:
        Dict with ``success``, ``deleted_tags``, ``deleted_relations`` and
        ``error``. On failure ``success`` is False, ``error`` holds the
        message, and the counters reflect whatever was deleted before the
        failure. Exceptions are logged, not raised.
    """
    result: Dict[str, Any] = {
        'success': False,
        'deleted_tags': 0,
        'deleted_relations': 0,
        'error': None
    }
    try:
        # Restrict the operation to industry/concept tags.
        tag_filter = {
            '$or': [
                {'source': {'$in': ['hybk', 'gnbk']}},
                {'tag_type': {'$in': ['行业标签', '概念标签']}}
            ]
        }
        # Collect the codes first; documents without a tag_code are ignored.
        tag_codes = [
            doc['tag_code']
            for doc in self.tags_collection.find(tag_filter, {'tag_code': 1})
            if doc.get('tag_code')
        ]

        # Skip the relation delete when there is nothing to match.
        if tag_codes:
            rel_res = self.tag_stock_relations_collection.delete_many(
                {'tag_code': {'$in': tag_codes}}
            )
            result['deleted_relations'] = rel_res.deleted_count

        tag_res = self.tags_collection.delete_many(tag_filter)
        result['deleted_tags'] = tag_res.deleted_count

        result['success'] = True
        return result
    except Exception as e:
        logger.error(f"清空标签及关系失败: {str(e)}", exc_info=True)
        result['error'] = str(e)
        return result
|
||||
|
||||
def close_connection(self):
|
||||
"""关闭数据库连接"""
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue