commit;
This commit is contained in:
parent
2a03d088ef
commit
982e92c403
|
|
@ -3,6 +3,7 @@ werkzeug==2.0.3
|
|||
flask-cors==3.0.10
|
||||
sqlalchemy==2.0.40
|
||||
pymysql==1.0.3
|
||||
psycopg2-binary==2.9.9
|
||||
tqdm>=4.65.0
|
||||
easy-spider-tool>=0.0.4
|
||||
easy-twitter-crawler>=0.0.4
|
||||
|
|
|
|||
109
src/app.py
109
src/app.py
|
|
@ -4230,6 +4230,115 @@ def init_tags_from_mysql():
|
|||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/tag/calculate_scores', methods=['POST', 'GET'])
|
||||
def calculate_tag_hot_scores():
|
||||
"""计算标签平均评分并保存到MySQL数据库接口
|
||||
|
||||
该接口会:
|
||||
1. 从MongoDB的stock_pool_tags集合获取所有标签
|
||||
2. 通过标签查找tag_stock_relations集合,获取关联的股票代码
|
||||
3. 从PostgreSQL的t_signal_daily_results表,计算每个股票近5天的平均total_score
|
||||
4. 用平均加权方式计算每个标签下的平均total_score
|
||||
5. 保存到MySQL 16.150数据库的tag_hot_scores表
|
||||
|
||||
返回:
|
||||
{
|
||||
"status": "success",
|
||||
"data": {
|
||||
"success": true,
|
||||
"message": "标签评分计算并保存成功",
|
||||
"processed_count": 150,
|
||||
"batch_no": 1,
|
||||
"tag_count": 150
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 检查API是否初始化成功
|
||||
if tag_relation_api is None:
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": "标签关联分析API未初始化"
|
||||
}), 500
|
||||
|
||||
# 调用计算方法
|
||||
result = tag_relation_api.calculate_and_save_tag_hot_scores()
|
||||
|
||||
# 返回结果
|
||||
if result.get('success'):
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"data": result
|
||||
})
|
||||
else:
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": result.get('message', '计算标签评分失败'),
|
||||
"data": result
|
||||
}), 500
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算标签评分失败: {str(e)}", exc_info=True)
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"计算标签评分失败: {str(e)}"
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/stock/calculate_scores', methods=['POST', 'GET'])
|
||||
def calculate_stock_hot_scores():
|
||||
"""计算所有股票平均评分并保存到MySQL数据库接口
|
||||
|
||||
该接口会:
|
||||
1. 从153数据库的gp_code_all表获取所有股票的gp_code_two
|
||||
2. 从PostgreSQL的t_signal_daily_results表,计算每个股票近5天的平均total_score
|
||||
3. 计算每个股票在所有股票中的排名
|
||||
4. 保存到MySQL 16.150数据库的stock_hot_scores表
|
||||
|
||||
返回:
|
||||
{
|
||||
"status": "success",
|
||||
"data": {
|
||||
"success": true,
|
||||
"message": "股票评分计算并保存成功",
|
||||
"processed_count": 5000,
|
||||
"batch_no": 1,
|
||||
"stock_count": 5000
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 检查API是否初始化成功
|
||||
if tag_relation_api is None:
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": "标签关联分析API未初始化"
|
||||
}), 500
|
||||
|
||||
# 调用计算方法
|
||||
result = tag_relation_api.calculate_and_save_stock_hot_scores()
|
||||
|
||||
# 返回结果
|
||||
if result.get('success'):
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"data": result
|
||||
})
|
||||
else:
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": result.get('message', '计算股票评分失败'),
|
||||
"data": result
|
||||
}), 500
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算股票评分失败: {str(e)}", exc_info=True)
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"计算股票评分失败: {str(e)}"
|
||||
}), 500
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# 启动Web服务器
|
||||
|
|
|
|||
|
|
@ -1121,6 +1121,57 @@ class TagRelationAPI:
|
|||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def calculate_and_save_tag_hot_scores(self) -> Dict[str, Any]:
|
||||
"""计算标签平均评分并保存到MySQL数据库
|
||||
|
||||
逻辑:
|
||||
1. 从MongoDB的stock_pool_tags集合获取所有标签
|
||||
2. 通过标签查找tag_stock_relations集合,获取关联的股票代码
|
||||
3. 从PostgreSQL的t_signal_daily_results表,计算每个股票近5天的平均total_score
|
||||
4. 用平均加权方式计算每个标签下的平均total_score
|
||||
5. 保存到MySQL 16.150数据库的tag_hot_scores表
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 执行结果,包含成功状态、处理数量等信息
|
||||
"""
|
||||
try:
|
||||
logger.info("开始计算标签平均评分...")
|
||||
result = self.service.calculate_and_save_tag_hot_scores()
|
||||
logger.info(f"标签评分计算完成: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"计算标签评分失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"计算标签评分失败: {str(e)}",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
def calculate_and_save_stock_hot_scores(self) -> Dict[str, Any]:
|
||||
"""计算所有股票的平均评分并保存到MySQL数据库
|
||||
|
||||
逻辑:
|
||||
1. 从153数据库的gp_code_all表获取所有股票的gp_code_two
|
||||
2. 从PostgreSQL的t_signal_daily_results表,计算每个股票近5天的平均total_score
|
||||
3. 计算每个股票在所有股票中的排名
|
||||
4. 保存到MySQL 16.150数据库的stock_hot_scores表
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 执行结果,包含成功状态、处理数量等信息
|
||||
"""
|
||||
try:
|
||||
logger.info("开始计算所有股票平均评分...")
|
||||
result = self.service.calculate_and_save_stock_hot_scores()
|
||||
logger.info(f"股票评分计算完成: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"计算股票评分失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"计算股票评分失败: {str(e)}",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""关闭所有连接"""
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ import logging
|
|||
import sys
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from sqlalchemy import create_engine, text
|
||||
import pandas as pd
|
||||
|
||||
# 处理相对导入和绝对导入
|
||||
try:
|
||||
|
|
@ -187,6 +190,543 @@ class TagRelationService:
|
|||
"""
|
||||
return self.database.get_tag_analyses(tag_name)
|
||||
|
||||
def calculate_and_save_tag_hot_scores(self) -> Dict[str, Any]:
|
||||
"""
|
||||
计算标签平均评分并保存到MySQL数据库
|
||||
|
||||
逻辑:
|
||||
1. 从MongoDB的stock_pool_tags集合获取所有标签
|
||||
2. 通过标签查找tag_stock_relations集合,获取关联的股票代码
|
||||
3. 从PostgreSQL的t_signal_daily_results表,计算每个股票近5天的平均total_score
|
||||
4. 用平均加权方式计算每个标签下的平均total_score
|
||||
5. 保存到MySQL 16.150数据库的tag_hot_scores表
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 执行结果,包含成功状态、处理数量等信息
|
||||
"""
|
||||
try:
|
||||
# 导入配置
|
||||
from src.valuation_analysis.config import DB_URL_TAG_SCORE, PG_URL_SIGNAL
|
||||
|
||||
# 1. 从MongoDB获取所有标签和股票关联关系
|
||||
logger.info("开始获取标签和股票关联关系...")
|
||||
|
||||
# 1.1 直接查询tag_stock_relations集合的所有数据
|
||||
relations_collection = self.database.tag_stock_relations_collection
|
||||
all_relations = list(relations_collection.find({}, {'tag_code': 1, 'stock_code': 1, '_id': 0}))
|
||||
logger.info(f"从tag_stock_relations获取到 {len(all_relations)} 条关联关系")
|
||||
|
||||
if not all_relations:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未找到任何标签与股票的关联关系",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
# 1.2 从stock_pool_tags获取所有标签,建立tag_code -> tag_name的映射
|
||||
tags_collection = self.database.tags_collection
|
||||
all_tags = list(tags_collection.find({}, {'tag_code': 1, 'tag_name': 1, '_id': 0}))
|
||||
tag_code_to_name = {tag['tag_code']: tag['tag_name'] for tag in all_tags if 'tag_code' in tag and 'tag_name' in tag}
|
||||
logger.info(f"从stock_pool_tags获取到 {len(tag_code_to_name)} 个标签")
|
||||
|
||||
# 1.3 通过tag_code关联,构建tag_name -> [stock_codes]的映射
|
||||
tag_stock_map = {} # {tag_name: [stock_codes]}
|
||||
tag_code_stock_map = {} # {tag_code: [stock_codes]} 临时存储
|
||||
|
||||
# 先按tag_code分组股票代码
|
||||
for relation in all_relations:
|
||||
tag_code = relation.get('tag_code')
|
||||
stock_code = relation.get('stock_code')
|
||||
|
||||
if tag_code and stock_code:
|
||||
if tag_code not in tag_code_stock_map:
|
||||
tag_code_stock_map[tag_code] = []
|
||||
tag_code_stock_map[tag_code].append(stock_code)
|
||||
|
||||
# 通过tag_code映射到tag_name
|
||||
for tag_code, stock_codes in tag_code_stock_map.items():
|
||||
tag_name = tag_code_to_name.get(tag_code)
|
||||
if tag_name:
|
||||
tag_stock_map[tag_name] = stock_codes
|
||||
logger.debug(f"标签 {tag_name} (tag_code: {tag_code}) 关联 {len(stock_codes)} 只股票")
|
||||
|
||||
logger.info(f"构建完成,共 {len(tag_stock_map)} 个标签有股票关联")
|
||||
|
||||
if not tag_stock_map:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未找到任何有效的标签与股票关联关系",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
# 3. 连接PostgreSQL,查询每个股票的近5天平均total_score
|
||||
logger.info("连接PostgreSQL查询股票评分...")
|
||||
pg_engine = create_engine(PG_URL_SIGNAL)
|
||||
|
||||
try:
|
||||
# 收集所有需要查询的股票代码(去重)
|
||||
all_stock_codes = set()
|
||||
for stock_codes in tag_stock_map.values():
|
||||
all_stock_codes.update(stock_codes)
|
||||
|
||||
logger.info(f"需要查询 {len(all_stock_codes)} 只股票的评分数据")
|
||||
|
||||
# 股票代码格式转换函数:688213.SH -> SH688213
|
||||
def convert_stock_code_format(code: str) -> str:
|
||||
"""将股票代码从 688213.SH 格式转换为 SH688213 格式"""
|
||||
if '.' in code:
|
||||
parts = code.split('.')
|
||||
if len(parts) == 2:
|
||||
return f"{parts[1]}{parts[0]}"
|
||||
return code
|
||||
|
||||
# 转换所有股票代码格式(用于PostgreSQL查询)
|
||||
converted_codes = {code: convert_stock_code_format(code) for code in all_stock_codes}
|
||||
converted_codes_set = set(converted_codes.values())
|
||||
|
||||
logger.info(f"转换后的股票代码格式示例: {list(converted_codes.items())[:5] if converted_codes else []}")
|
||||
|
||||
# 先查询最近5天的最大trade_date
|
||||
logger.info("查询最近5天的交易日期...")
|
||||
with pg_engine.connect() as conn:
|
||||
date_query = text("""
|
||||
SELECT DISTINCT trade_date
|
||||
FROM t_signal_daily_results
|
||||
WHERE trade_date IS NOT NULL
|
||||
ORDER BY trade_date DESC
|
||||
LIMIT 5
|
||||
""")
|
||||
date_result = conn.execute(date_query).fetchall()
|
||||
recent_dates = [row[0] for row in date_result if row[0]]
|
||||
|
||||
if not recent_dates:
|
||||
logger.warning("未找到最近的交易日期数据")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未找到最近的交易日期数据",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
logger.info(f"找到最近5天的交易日期: {recent_dates}")
|
||||
|
||||
# 全量查询这些日期下的所有数据
|
||||
logger.info("全量查询最近5天的信号数据...")
|
||||
with pg_engine.connect() as conn:
|
||||
# 构建日期占位符
|
||||
date_placeholders = ','.join([f':date_{i}' for i in range(len(recent_dates))])
|
||||
data_query = text(f"""
|
||||
SELECT
|
||||
ts_code,
|
||||
trade_date,
|
||||
total_score
|
||||
FROM t_signal_daily_results
|
||||
WHERE trade_date IN ({date_placeholders})
|
||||
AND ts_code IS NOT NULL
|
||||
AND total_score IS NOT NULL
|
||||
""")
|
||||
|
||||
# 构建参数字典
|
||||
date_params = {f'date_{i}': date for i, date in enumerate(recent_dates)}
|
||||
|
||||
# 执行查询并转换为DataFrame
|
||||
result = conn.execute(data_query, date_params)
|
||||
df = pd.DataFrame(result.fetchall(), columns=['ts_code', 'trade_date', 'total_score'])
|
||||
|
||||
if df.empty:
|
||||
logger.warning("未查询到任何信号数据")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未查询到任何信号数据",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
logger.info(f"查询到 {len(df)} 条信号数据记录,涉及 {df['ts_code'].nunique()} 只股票")
|
||||
|
||||
# 使用DataFrame计算每个股票近5天的平均total_score
|
||||
# 只处理我们需要的股票代码
|
||||
df_filtered = df[df['ts_code'].isin(converted_codes_set)]
|
||||
|
||||
if df_filtered.empty:
|
||||
logger.warning("过滤后没有匹配的股票代码数据")
|
||||
logger.debug(f"PostgreSQL中的股票代码示例: {df['ts_code'].unique()[:10] if not df.empty else []}")
|
||||
logger.debug(f"需要查询的股票代码示例: {list(converted_codes_set)[:10] if converted_codes_set else []}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "过滤后没有匹配的股票代码数据",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
logger.info(f"过滤后得到 {len(df_filtered)} 条记录,涉及 {df_filtered['ts_code'].nunique()} 只股票")
|
||||
|
||||
# 按股票代码分组,计算平均分
|
||||
stock_hot_scores_df = df_filtered.groupby('ts_code').agg({
|
||||
'total_score': ['mean', 'count']
|
||||
}).reset_index()
|
||||
|
||||
stock_hot_scores_df.columns = ['ts_code', 'avg_score', 'day_count']
|
||||
|
||||
# 转换为字典,key为原始格式(688213.SH),value为平均分
|
||||
stock_hot_scores = {}
|
||||
for _, row in stock_hot_scores_df.iterrows():
|
||||
pg_code = row['ts_code'] # SH688213格式
|
||||
avg_score = float(row['avg_score']) if pd.notna(row['avg_score']) else None
|
||||
day_count = int(row['day_count'])
|
||||
|
||||
# 找到对应的原始格式代码
|
||||
original_code = None
|
||||
for orig_code, conv_code in converted_codes.items():
|
||||
if conv_code == pg_code:
|
||||
original_code = orig_code
|
||||
break
|
||||
|
||||
if original_code and avg_score is not None:
|
||||
stock_hot_scores[original_code] = {
|
||||
'avg_score': avg_score,
|
||||
'day_count': day_count
|
||||
}
|
||||
|
||||
logger.info(f"计算得到 {len(stock_hot_scores)} 只股票的平均评分")
|
||||
finally:
|
||||
# 关闭PostgreSQL连接
|
||||
pg_engine.dispose()
|
||||
|
||||
# 4. 计算每个标签的平均评分(加权平均)
|
||||
tag_hot_scores = []
|
||||
for tag_name, stock_codes in tag_stock_map.items():
|
||||
# 获取该标签下所有有评分的股票
|
||||
valid_scores = []
|
||||
for stock_code in stock_codes:
|
||||
if stock_code in stock_hot_scores:
|
||||
valid_scores.append(stock_hot_scores[stock_code]['avg_score'])
|
||||
|
||||
if valid_scores:
|
||||
# 计算平均分(简单平均)
|
||||
tag_avg_score = sum(valid_scores) / len(valid_scores)
|
||||
tag_hot_scores.append({
|
||||
'tag_name': tag_name,
|
||||
'score': round(tag_avg_score, 4),
|
||||
'stock_count': len(valid_scores),
|
||||
'total_stock_count': len(stock_codes)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"标签 {tag_name} 没有有效的股票评分数据")
|
||||
|
||||
if not tag_hot_scores:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "没有计算出任何标签评分",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
# 5. 计算总排名(按分数降序)
|
||||
tag_hot_scores.sort(key=lambda x: x['score'], reverse=True)
|
||||
for idx, tag_score in enumerate(tag_hot_scores, 1):
|
||||
tag_score['rank'] = idx
|
||||
|
||||
# 6. 获取当前批次号(上一次批次+1)
|
||||
mysql_engine = create_engine(DB_URL_TAG_SCORE)
|
||||
with mysql_engine.connect() as conn:
|
||||
# 查询最大批次号
|
||||
max_batch_query = text("SELECT MAX(batch_no) as max_batch FROM tag_hot_scores")
|
||||
result = conn.execute(max_batch_query).fetchone()
|
||||
max_batch = result[0] if result and result[0] is not None else 0
|
||||
current_batch = max_batch + 1
|
||||
|
||||
# 7. 保存到MySQL数据库
|
||||
logger.info(f"开始保存标签评分到MySQL,批次号: {current_batch}")
|
||||
insert_query = text("""
|
||||
INSERT INTO tag_hot_scores (tag_name, score, total_rank, create_time, batch_no)
|
||||
VALUES (:tag_name, :score, :total_rank, :create_time, :batch_no)
|
||||
""")
|
||||
|
||||
saved_count = 0
|
||||
create_time = datetime.now()
|
||||
|
||||
with mysql_engine.connect() as conn:
|
||||
trans = conn.begin()
|
||||
try:
|
||||
for tag_score in tag_hot_scores:
|
||||
conn.execute(insert_query, {
|
||||
'tag_name': tag_score['tag_name'],
|
||||
'score': tag_score['score'],
|
||||
'total_rank': tag_score['rank'],
|
||||
'create_time': create_time,
|
||||
'batch_no': current_batch
|
||||
})
|
||||
saved_count += 1
|
||||
|
||||
trans.commit()
|
||||
logger.info(f"成功保存 {saved_count} 条标签评分记录,批次号: {current_batch}")
|
||||
except Exception as e:
|
||||
trans.rollback()
|
||||
raise e
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "标签评分计算并保存成功",
|
||||
"processed_count": saved_count,
|
||||
"batch_no": current_batch,
|
||||
"tag_count": len(tag_hot_scores)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算并保存标签评分失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"计算并保存标签评分失败: {str(e)}",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
def calculate_and_save_stock_hot_scores(self) -> Dict[str, Any]:
|
||||
"""
|
||||
计算所有股票的平均评分并保存到MySQL数据库
|
||||
|
||||
逻辑:
|
||||
1. 从153数据库的gp_code_all表获取所有股票的gp_code
|
||||
2. 从PostgreSQL的t_signal_daily_results表,计算每个股票近5天的平均total_score
|
||||
3. 计算每个股票在所有股票中的排名
|
||||
4. 保存到MySQL 16.150数据库的stock_hot_scores表
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 执行结果,包含成功状态、处理数量等信息
|
||||
"""
|
||||
try:
|
||||
# 导入配置
|
||||
from src.valuation_analysis.config import DB_URL_TAG_SCORE, PG_URL_SIGNAL, DB_URL_153
|
||||
|
||||
# 1. 从153数据库获取所有股票的gp_code
|
||||
logger.info("从153数据库获取所有股票列表...")
|
||||
db153_engine = create_engine(DB_URL_153)
|
||||
|
||||
try:
|
||||
with db153_engine.connect() as conn:
|
||||
stock_query = text("""
|
||||
SELECT DISTINCT gp_code
|
||||
FROM gp_code_all
|
||||
WHERE gp_code IS NOT NULL
|
||||
AND gp_code != ''
|
||||
""")
|
||||
result = conn.execute(stock_query)
|
||||
stock_list = [row[0] for row in result if row[0]]
|
||||
|
||||
logger.info(f"获取到 {len(stock_list)} 只股票")
|
||||
|
||||
if not stock_list:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未找到任何股票",
|
||||
"processed_count": 0
|
||||
}
|
||||
finally:
|
||||
db153_engine.dispose()
|
||||
|
||||
# 2. 股票代码格式转换函数:SH688213 -> SH688213(PostgreSQL格式)
|
||||
# 实际上gp_code已经是SH688213格式,不需要转换
|
||||
# 但为了保险起见,我们确保格式正确
|
||||
def ensure_pg_format(code: str) -> str:
|
||||
"""确保股票代码是PostgreSQL格式(SH688213)"""
|
||||
if code and len(code) > 0:
|
||||
return code.upper()
|
||||
return code
|
||||
|
||||
converted_codes = {code: ensure_pg_format(code) for code in stock_list}
|
||||
converted_codes_set = set(converted_codes.values())
|
||||
|
||||
# 3. 连接PostgreSQL,查询每个股票的近5天平均total_score
|
||||
logger.info("连接PostgreSQL查询股票评分...")
|
||||
pg_engine = create_engine(PG_URL_SIGNAL)
|
||||
|
||||
try:
|
||||
# 先查询最近5天的最大trade_date
|
||||
logger.info("查询最近5天的交易日期...")
|
||||
with pg_engine.connect() as conn:
|
||||
date_query = text("""
|
||||
SELECT DISTINCT trade_date
|
||||
FROM t_signal_daily_results
|
||||
WHERE trade_date IS NOT NULL
|
||||
ORDER BY trade_date DESC
|
||||
LIMIT 5
|
||||
""")
|
||||
date_result = conn.execute(date_query).fetchall()
|
||||
recent_dates = [row[0] for row in date_result if row[0]]
|
||||
|
||||
if not recent_dates:
|
||||
logger.warning("未找到最近的交易日期数据")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未找到最近的交易日期数据",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
logger.info(f"找到最近5天的交易日期: {recent_dates}")
|
||||
|
||||
# 全量查询这些日期下的所有数据
|
||||
logger.info("全量查询最近5天的信号数据...")
|
||||
with pg_engine.connect() as conn:
|
||||
# 构建日期占位符
|
||||
date_placeholders = ','.join([f':date_{i}' for i in range(len(recent_dates))])
|
||||
data_query = text(f"""
|
||||
SELECT
|
||||
ts_code,
|
||||
trade_date,
|
||||
total_score
|
||||
FROM t_signal_daily_results
|
||||
WHERE trade_date IN ({date_placeholders})
|
||||
AND ts_code IS NOT NULL
|
||||
AND total_score IS NOT NULL
|
||||
""")
|
||||
|
||||
# 构建参数字典
|
||||
date_params = {f'date_{i}': date for i, date in enumerate(recent_dates)}
|
||||
|
||||
# 执行查询并转换为DataFrame
|
||||
result = conn.execute(data_query, date_params)
|
||||
df = pd.DataFrame(result.fetchall(), columns=['ts_code', 'trade_date', 'total_score'])
|
||||
|
||||
if df.empty:
|
||||
logger.warning("未查询到任何信号数据")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "未查询到任何信号数据",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
logger.info(f"查询到 {len(df)} 条信号数据记录,涉及 {df['ts_code'].nunique()} 只股票")
|
||||
|
||||
# 使用DataFrame计算每个股票近5天的平均total_score
|
||||
# 只处理我们需要的股票代码
|
||||
df_filtered = df[df['ts_code'].isin(converted_codes_set)]
|
||||
|
||||
if df_filtered.empty:
|
||||
logger.warning("过滤后没有匹配的股票代码数据")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "过滤后没有匹配的股票代码数据",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
logger.info(f"过滤后得到 {len(df_filtered)} 条记录,涉及 {df_filtered['ts_code'].nunique()} 只股票")
|
||||
|
||||
# 按股票代码分组,计算平均分
|
||||
stock_hot_scores_df = df_filtered.groupby('ts_code').agg({
|
||||
'total_score': ['mean', 'count']
|
||||
}).reset_index()
|
||||
|
||||
stock_hot_scores_df.columns = ['ts_code', 'avg_score', 'day_count']
|
||||
|
||||
# 转换为字典,key为股票代码(gp_code格式)
|
||||
stock_hot_scores = {}
|
||||
for _, row in stock_hot_scores_df.iterrows():
|
||||
pg_code = row['ts_code'] # SH688213格式
|
||||
avg_score = float(row['avg_score']) if pd.notna(row['avg_score']) else None
|
||||
day_count = int(row['day_count'])
|
||||
|
||||
# 找到对应的原始格式代码(实际上pg_code就是gp_code格式)
|
||||
original_code = None
|
||||
for orig_code, conv_code in converted_codes.items():
|
||||
if conv_code == pg_code:
|
||||
original_code = orig_code
|
||||
break
|
||||
|
||||
# 如果没找到,直接使用pg_code(可能格式一致)
|
||||
if not original_code:
|
||||
original_code = pg_code
|
||||
|
||||
if avg_score is not None:
|
||||
stock_hot_scores[original_code] = {
|
||||
'avg_score': avg_score,
|
||||
'day_count': day_count
|
||||
}
|
||||
|
||||
logger.info(f"计算得到 {len(stock_hot_scores)} 只股票的平均评分")
|
||||
finally:
|
||||
# 关闭PostgreSQL连接
|
||||
pg_engine.dispose()
|
||||
|
||||
if not stock_hot_scores:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "没有计算出任何股票评分",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
# 4. 计算总排名(按分数降序)
|
||||
stock_hot_scores_list = [
|
||||
{
|
||||
'stock_code': code,
|
||||
'score': data['avg_score'],
|
||||
'day_count': data['day_count']
|
||||
}
|
||||
for code, data in stock_hot_scores.items()
|
||||
]
|
||||
|
||||
# 按分数降序排序
|
||||
stock_hot_scores_list.sort(key=lambda x: x['score'], reverse=True)
|
||||
|
||||
# 计算排名(相同分数排名相同,下一个排名跳过相同数量)
|
||||
current_rank = 1
|
||||
prev_score = None
|
||||
for idx, stock_score in enumerate(stock_hot_scores_list):
|
||||
current_score = stock_score['score']
|
||||
if prev_score is not None and current_score < prev_score:
|
||||
current_rank = idx + 1
|
||||
stock_score['rank'] = current_rank
|
||||
prev_score = current_score
|
||||
|
||||
# 5. 获取当前批次号(上一次批次+1)
|
||||
mysql_engine = create_engine(DB_URL_TAG_SCORE)
|
||||
with mysql_engine.connect() as conn:
|
||||
# 查询最大批次号
|
||||
max_batch_query = text("SELECT MAX(batch_no) as max_batch FROM stock_hot_scores")
|
||||
result = conn.execute(max_batch_query).fetchone()
|
||||
max_batch = result[0] if result and result[0] is not None else 0
|
||||
current_batch = max_batch + 1
|
||||
|
||||
# 6. 保存到MySQL数据库
|
||||
logger.info(f"开始保存股票评分到MySQL,批次号: {current_batch}")
|
||||
insert_query = text("""
|
||||
INSERT INTO stock_hot_scores (stock_code, score, total_rank, create_time, batch_no)
|
||||
VALUES (:stock_code, :score, :total_rank, :create_time, :batch_no)
|
||||
""")
|
||||
|
||||
saved_count = 0
|
||||
create_time = datetime.now()
|
||||
|
||||
with mysql_engine.connect() as conn:
|
||||
trans = conn.begin()
|
||||
try:
|
||||
for stock_score in stock_hot_scores_list:
|
||||
conn.execute(insert_query, {
|
||||
'stock_code': stock_score['stock_code'],
|
||||
'score': stock_score['score'],
|
||||
'total_rank': stock_score['rank'],
|
||||
'create_time': create_time,
|
||||
'batch_no': current_batch
|
||||
})
|
||||
saved_count += 1
|
||||
|
||||
trans.commit()
|
||||
logger.info(f"成功保存 {saved_count} 条股票评分记录,批次号: {current_batch}")
|
||||
except Exception as e:
|
||||
trans.rollback()
|
||||
raise e
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "股票评分计算并保存成功",
|
||||
"processed_count": saved_count,
|
||||
"batch_no": current_batch,
|
||||
"stock_count": len(stock_hot_scores_list)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算并保存股票评分失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"计算并保存股票评分失败: {str(e)}",
|
||||
"processed_count": 0
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
self.database.close_connection()
|
||||
|
|
|
|||
|
|
@ -77,4 +77,37 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|||
LOG_FILE = ROOT_DIR / "logs" / "valuation_analysis.log"
|
||||
|
||||
# 确保日志目录存在
|
||||
os.makedirs(LOG_FILE.parent, exist_ok=True)
|
||||
os.makedirs(LOG_FILE.parent, exist_ok=True)
|
||||
|
||||
# MySQL配置(16.150数据库,用于标签评分存储)
|
||||
DB_CONFIG_TAG_SCORE = {
|
||||
'host': '192.168.16.150',
|
||||
'port': 3306,
|
||||
'user': 'fac_pattern',
|
||||
'password': 'Chlry$%.8pattern',
|
||||
'database': 'factordb_mysql'
|
||||
}
|
||||
# 创建数据库连接URL(用于标签评分存储)
|
||||
DB_URL_TAG_SCORE = f"mysql+pymysql://{DB_CONFIG_TAG_SCORE['user']}:{DB_CONFIG_TAG_SCORE['password']}@{DB_CONFIG_TAG_SCORE['host']}:{DB_CONFIG_TAG_SCORE['port']}/{DB_CONFIG_TAG_SCORE['database']}"
|
||||
|
||||
# PostgreSQL配置(16.150数据库,用于信号数据查询)
|
||||
PG_CONFIG_SIGNAL = {
|
||||
'host': '192.168.16.150',
|
||||
'port': 5432,
|
||||
'user': 'fac_pattern',
|
||||
'password': 'Chlry$%.8pattern',
|
||||
'database': 'factordb'
|
||||
}
|
||||
# 创建PostgreSQL连接URL(用于信号数据查询)
|
||||
PG_URL_SIGNAL = f"postgresql://{PG_CONFIG_SIGNAL['user']}:{PG_CONFIG_SIGNAL['password']}@{PG_CONFIG_SIGNAL['host']}:{PG_CONFIG_SIGNAL['port']}/{PG_CONFIG_SIGNAL['database']}"
|
||||
|
||||
# MySQL配置(16.153数据库,用于股票列表查询)
|
||||
DB_CONFIG_153 = {
|
||||
'host': '192.168.16.153',
|
||||
'port': 3307,
|
||||
'user': 'fac_pattern',
|
||||
'password': 'Chlry$%.8_app',
|
||||
'database': 'my_quant_db'
|
||||
}
|
||||
# 创建数据库连接URL(用于股票列表查询)
|
||||
DB_URL_153 = f"mysql+pymysql://{DB_CONFIG_153['user']}:{DB_CONFIG_153['password']}@{DB_CONFIG_153['host']}:{DB_CONFIG_153['port']}/{DB_CONFIG_153['database']}"
|
||||
|
|
@ -1370,7 +1370,7 @@ class IndustryAnalyzer:
|
|||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"筛选拥挤度缓存时出错: {e}")
|
||||
return result
|
||||
return result
|
||||
|
||||
def _get_all_industries_concepts(self, days: int = 120) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue