满脸小星星 2026-01-16 15:42:04 +08:00
parent cfc06ff720
commit cf35164e99
10 changed files with 1268 additions and 19 deletions

View File

@@ -24,3 +24,4 @@ pyyaml==6.0.3
 xtquant==250516.1.1
 tushare>=1.4.24
 akshare==1.17.82
+seaborn==0.13.2

View File

@@ -682,18 +682,34 @@ def precalculate_industry_crowding_batch():
         from src.valuation_analysis.industry_analysis import IndustryAnalyzer
         analyzer = IndustryAnalyzer()
-        # Fixed industry and concept board lists
-        industries = ["煤炭开采", "焦炭加工", "油气开采", "石油化工", "油服工程", "日用化工", "化纤", "化学原料", "化学制品", "塑料", "橡胶", "农用化工", "非金属材料", "冶钢原料", "普钢", "特钢", "工业金属", "贵金属", "能源金属", "稀有金属", "金属新材料", "水泥", "玻璃玻纤", "装饰建材", "种植业", "养殖业", "林业", "渔业", "饲料", "农产品加工", "动物保健", "酿酒", "饮料乳品", "调味品", "休闲食品", "食品加工", "纺织制造", "服装家纺", "饰品", "造纸", "包装印刷", "家居用品", "文娱用品", "白色家电", "黑色家电", "小家电", "厨卫电器", "家电零部件", "一般零售", "商业物业经营", "专业连锁", "贸易", "电子商务", "乘用车", "商用车", "汽车零部件", "汽车服务", "摩托车及其他", "化学制药", "生物制品", "中药", "医药商业", "医疗器械", "医疗服务", "医疗美容", "电机制造", "电池", "电网设备", "光伏设备", "风电设备", "其他发电设备", "地面兵装", "航空装备", "航天装备", "航海装备", "军工电子", "轨交设备", "通用设备", "专用设备", "工程机械", "自动化设备", "半导体", "消费电子", "光学光电", "元器件", "其他电子", "通信设备", "通信工程", "电信服务", "IT设备", "软件服务", "云服务", "产业互联网", "游戏", "广告营销", "影视院线", "数字媒体", "出版业", "广播电视", "全国性银行", "地方性银行", "证券", "保险", "多元金融", "房屋建设", "基础建设", "专业工程", "工程咨询服务", "装修装饰", "房地产开发", "房产服务", "体育", "教育培训", "酒店餐饮", "旅游", "专业服务", "公路铁路", "航空机场", "航运港口", "物流", "电力", "燃气", "水务", "环保设备", "环境治理", "环境监测", "综合类"]
-        concepts = ["通达信88", "海峡西岸", "海南自贸", "一带一路", "上海自贸", "雄安新区", "粤港澳", "ST板块", "次新股", "含H股", "含B股", "含GDR", "含可转债", "国防军工", "军民融合", "大飞机", "稀缺资源", "5G概念", "碳中和", "黄金概念", "物联网", "创投概念", "航运概念", "铁路基建", "高端装备", "核电核能", "光伏", "风电", "锂电池概念", "燃料电池", "HJT电池", "固态电池", "钠电池", "钒电池", "TOPCon电池", "钙钛矿电池", "BC电池", "氢能源", "稀土永磁", "盐湖提锂", "锂矿", "水利建设", "卫星导航", "可燃冰", "页岩气", "生物疫苗", "基因概念", "维生素", "仿制药", "创新药", "免疫治疗", "CXO概念", "节能环保", "食品安全", "白酒概念", "代糖概念", "猪肉", "鸡肉", "水产品", "碳纤维", "石墨烯", "3D打印", "苹果概念", "阿里概念", "腾讯概念", "小米概念", "百度概念", "华为鸿蒙", "华为海思", "华为汽车", "华为算力", "特斯拉概念", "消费电子概念", "汽车电子", "无线耳机", "生物质能", "地热能", "充电桩", "新能源车", "换电概念", "高压快充", "草甘膦", "安防服务", "垃圾分类", "核污染防治", "风沙治理", "乡村振兴", "土地流转", "体育概念", "博彩概念", "赛马概念", "分散染料", "聚氨酯", "云计算", "边缘计算", "网络游戏", "信息安全", "国产软件", "大数据", "数据中心", "芯片", "MCU芯片", "汽车芯片", "存储芯片", "互联金融", "婴童概念", "养老概念", "网红经济", "民营医院", "特高压", "智能电网", "智能穿戴", "智能交通", "智能家居", "智能医疗", "智慧城市", "智慧政务", "机器人概念", "机器视觉", "超导概念", "职业教育", "物业管理概念", "虚拟现实", "数字孪生", "钛金属", "钴金属", "镍金属", "氟概念", "磷概念", "无人机", "PPP概念", "新零售", "跨境电商", "量子科技", "无人驾驶", "ETC概念", "胎压监测", "OLED概念", "MiniLED", "MicroLED", "超清视频", "区块链", "数字货币", "人工智能", "租购同权", "工业互联", "知识产权", "工业大麻", "工业气体", "人造肉", "预制菜", "种业", "化肥概念", "操作系统", "光刻机", "第三代半导体", "远程办公", "口罩防护", "医废处理", "虫害防治", "超级电容", "C2M概念", "地摊经济", "冷链物流", "抖音概念", "降解塑料", "医美概念", "人脑工程", "烟草概念", "新型烟草", "有机硅概念", "新冠检测", "BIPV概念", "地下管网", "储能", "新材料", "工业母机", "一体压铸", "汽车热管理", "汽车拆解", "NMN概念", "国资云", "元宇宙概念", "NFT概念", "云游戏", "天然气", "绿色电力", "培育钻石", "信创", "幽门螺杆菌", "电子纸", "新冠药概念", "免税概念", "PVDF概念", "装配式建筑", "绿色建筑", "东数西算", "跨境支付CIPS", "中俄贸易", "电子身份证", "家庭医生", "辅助生殖", "肝炎概念", "新型城镇", "粮食概念", "超临界发电", "虚拟电厂", "动力电池回收", "PCB概念", "先进封装", "热泵概念", "EDA概念", "光热发电", "供销社", "Web3概念", "DRG-DIP", "AIGC概念", "复合铜箔", "数据确权", "数据要素", "POE胶膜", "血氧仪", "旅游概念", "中特估", "ChatGPT概念", "CPO概念", "数字水印", "毫米波雷达", "工业软件", "6G概念", "时空大数据", "可控核聚变", "知识付费", "算力租赁", "光通信", "混合现实", "英伟达概念", "减速器", "减肥药", "合成生物", "星闪概念", "液冷服务器", "新型工业化", "短剧游戏", "多模态AI", "PEEK材料", "小米汽车概念", "飞行汽车", "Sora概念", "人形机器人", "AI手机PC", "低空经济", "铜缆高速连接", "军工信息化", "玻璃基板", "商业航天", "车联网", "财税数字化", "折叠屏", "AI眼镜", "智谱AI", "IP经济", "宠物经济", "小红书概念", "AI智能体", "DeepSeek概念", "AI医疗概念", "海洋经济", "外骨骼机器人", "军贸概念"]
+        # Query all industries from the database
+        industry_list = analyzer.get_industry_list()
+        industries = [item['name'] for item in industry_list]
+        logger.info(f"Fetched {len(industries)} industries from the database")
+        # Query all concept boards from the database
+        concept_list = analyzer.get_concept_list()
+        concepts = [item['name'] for item in concept_list]
+        logger.info(f"Fetched {len(concepts)} concept boards from the database")
         # Batch-calculate crowding for industries and concept boards
         analyzer.batch_calculate_industry_crowding(industries, concepts)
-        logger.info("Batch calculation of industry and concept board crowding metrics finished")
+        logger.info(f"Batch calculation finished: crowding metrics for {len(industries)} industries and {len(concepts)} concept boards")
     except Exception as e:
         logger.error(f"Batch calculation of industry crowding metrics failed: {str(e)}")
-    return jsonify({
-        "status": "success"
-    }), 200
+        import traceback
+        logger.error(traceback.format_exc())
+        return jsonify({
+            "status": "error",
+            "message": str(e)
+        }), 500
+    return jsonify({
+        "status": "success",
+        "industries_count": len(industries) if 'industries' in locals() else 0,
+        "concepts_count": len(concepts) if 'concepts' in locals() else 0
+    }), 200
 @app.route('/')
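With the new error handling, callers can tell success from failure by status code. A minimal client sketch, assuming a hypothetical POST route (the decorator and URL are not part of this diff):

import requests

# Hypothetical host and route; adjust to the app's actual registration.
resp = requests.post("http://localhost:5000/api/precalculate_industry_crowding_batch")
if resp.status_code == 200:
    body = resp.json()  # {"status": "success", "industries_count": ..., "concepts_count": ...}
    print(body["industries_count"], body["concepts_count"])
else:
    # On failure the endpoint now returns 500 with {"status": "error", "message": ...}
    print("precalculation failed:", resp.json().get("message"))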

View File

@@ -28,7 +28,7 @@ except ImportError:
 # Import the stock code formatting tool
 try:
-    from tools.stock_code_formatter import StockCodeFormatter
+    from src.tools.stock_code_formatter import StockCodeFormatter
 except ImportError:
     import importlib.util
     # project_root is already the project root, so join the tools directory directly

View File

@@ -9,7 +9,6 @@
 import sys
 import pymongo
-import datetime
 import logging
 from typing import Dict, List, Optional, Union
 from pathlib import Path
@@ -24,7 +23,7 @@ sys.path.append(str(project_root))
 from src.valuation_analysis.config import MONGO_CONFIG2, DB_URL
 # Import the stock code format conversion tool
-from tools.stock_code_formatter import StockCodeFormatter
+from src.tools.stock_code_formatter import StockCodeFormatter
 # Logging setup
 logging.basicConfig(
@@ -478,8 +477,6 @@ class FinancialIndicatorAnalyzer:
         Returns:
             List[str]: list of quarters, formatted as YYYY-MM-DD
         """
-        from datetime import datetime, timedelta
-        import calendar
         quarters = []
         year, month, day = map(int, start_date.split('-'))

View File

@@ -12,8 +12,7 @@ import logging
 from typing import Dict, List, Optional, Tuple
 from pathlib import Path
 from sqlalchemy import create_engine, text
-import pandas as pd
-from datetime import datetime, timedelta
+from datetime import datetime
 # Add the project root to the Python path
 project_root = Path(__file__).parent.parent.parent
@@ -33,7 +32,7 @@ except ImportError:
 # Import the stock code format conversion tool
 try:
-    from tools.stock_code_formatter import StockCodeFormatter
+    from src.tools.stock_code_formatter import StockCodeFormatter
 except ImportError:
     # If the import above fails, fall back to a direct import
     import importlib.util

View File

@@ -0,0 +1,243 @@
# coding:utf-8
"""
THS (Tonghuashun) industry board constituent collector

Purpose: fetch THS industry board constituents from Tushare and persist them to the database.
API docs: https://tushare.pro/document/2?doc_id=261
Note: every run performs a full overwrite, so no daily incremental update is needed.
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

# Add the project root to the path so the config module can be imported
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN


class THSIndustryMemberCollector:
    """Collector for THS industry board constituents"""

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
        """
        Args:
            db_url: database connection URL
            tushare_token: Tushare token
            table_name: target table name, defaults to ths_industry_member
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("THS industry board constituent collector")
        print(f"Target table: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """
        Convert a Tushare code (600000.SH) to the database code format (SH600000)
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        base, market = ts_code.split(".")
        return f"{market}{base}"
    def fetch_industry_index_list(self) -> pd.DataFrame:
        """
        Fetch the list of all THS industry boards.
        API: ths_index
        Docs: https://tushare.pro/document/2?doc_id=259
        Returns:
            DataFrame with the industry board list
        """
        try:
            print("Fetching the THS industry board list...")
            df = self.pro.ths_index(
                exchange='A',
                type='I',  # I = industry board
            )
            if df.empty:
                print("No industry board data returned")
                return pd.DataFrame()
            print(f"Fetched {len(df)} industry boards")
            return df
        except Exception as exc:
            print(f"Failed to fetch the industry board list: {exc}")
            return pd.DataFrame()

    def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
        """
        Fetch the constituents of a given industry board.
        API: ths_member
        Docs: https://tushare.pro/document/2?doc_id=261
        Args:
            ts_code: industry board code, e.g. '885800.TI'
        Returns:
            DataFrame with the constituent list
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
            return df
        except Exception as exc:
            print(f"Failed to fetch constituents of industry board {ts_code}: {exc}")
            return pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
        """
        Convert the data returned by Tushare into the database format.
        Args:
            member_df: constituent data containing ts_code, con_code and con_name
            industry_name: industry board name, taken from the board list
        Returns:
            the transformed DataFrame
        """
        if member_df.empty:
            return pd.DataFrame()
        # Check required fields
        if "ts_code" not in member_df.columns:
            print(f"Warning: industry board code field (ts_code) not found; available fields: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"Warning: stock code field (con_code) not found; available fields: {member_df.columns.tolist()}")
            return pd.DataFrame()
        result = pd.DataFrame()
        # ts_code is the industry board code
        result["industry_code"] = member_df["ts_code"]
        result["industry_name"] = industry_name
        # con_code is the stock code
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        # con_name is the stock name
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        # is_new flags whether the record is current
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """
        Write the data to the database.
        """
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
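        # Note: if_exists="append" only amounts to a full overwrite because
        # run_full_collection() truncates the table before the first write.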
    def run_full_collection(self) -> None:
        """
        Run a full-overwrite collection:
        - truncate the target table
        - fetch the list of all industry boards
        - fetch the constituents of each board
        - write everything to the database
        """
        print("=" * 60)
        print("Starting full-overwrite collection (THS industry board constituents)")
        print("=" * 60)
        try:
            # Truncate the table
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"Table {self.table_name} truncated")
            # Fetch the list of all industry boards
            industry_list_df = self.fetch_industry_index_list()
            if industry_list_df.empty:
                print("No industry board list available; aborting collection")
                return
            total_records = 0
            success_count = 0
            failed_count = 0
            # Iterate over the boards and fetch constituents
            for idx, row in industry_list_df.iterrows():
                industry_code = row["ts_code"]
                industry_name = row["name"]
                print(f"\n[{idx + 1}/{len(industry_list_df)}] Collecting: {industry_name} ({industry_code})")
                try:
                    # Fetch the constituents of this board
                    member_df = self.fetch_industry_members(industry_code)
                    if member_df.empty:
                        print(f"  Industry board {industry_name} has no constituent data")
                        failed_count += 1
                        continue
                    # Transform the data (passing the board name along)
                    result_df = self.transform_data(member_df, industry_name)
                    if result_df.empty:
                        print(f"  Data transformation failed for industry board {industry_name}")
                        failed_count += 1
                        continue
                    # Save the data
                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f"  Collected {len(result_df)} constituents")
                except Exception as exc:
                    print(f"  Failed to collect {industry_name} ({industry_code}): {exc}")
                    failed_count += 1
                    continue
            print("\n" + "=" * 60)
            print("Full-overwrite collection finished")
            print(f"Total industry boards: {len(industry_list_df)}")
            print(f"Collected successfully: {success_count}")
            print(f"Failed: {failed_count}")
            print(f"Total constituent records: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"Full collection failed: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()


def collect_ths_industry_member(
    db_url: str,
    tushare_token: str,
):
    """
    Collection entry point - full-overwrite collection
    """
    collector = THSIndustryMemberCollector(db_url, tushare_token)
    collector.run_full_collection()


if __name__ == "__main__":
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN
    # Run the full-overwrite collection
    collect_ths_industry_member(DB_URL, TOKEN)
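The target table's DDL is not part of this commit. A minimal schema sketch that is compatible with the columns written by transform_data() above; every type and index below is an assumption, not taken from this diff:

# Hypothetical DDL inferred from transform_data(); adjust types to taste.
from sqlalchemy import create_engine, text

THS_INDUSTRY_MEMBER_DDL = """
CREATE TABLE IF NOT EXISTS ths_industry_member (
    id             BIGINT AUTO_INCREMENT PRIMARY KEY,
    industry_code  VARCHAR(16) NOT NULL,  -- board code, e.g. 885800.TI
    industry_name  VARCHAR(64),
    stock_ts_code  VARCHAR(16) NOT NULL,  -- Tushare code, e.g. 600000.SH
    stock_symbol   VARCHAR(16) NOT NULL,  -- DB code, e.g. SH600000
    stock_name     VARCHAR(64),
    is_new         VARCHAR(4),            -- Tushare "is latest" flag
    created_at     DATETIME,
    updated_at     DATETIME,
    KEY idx_industry_code (industry_code),
    KEY idx_stock_symbol (stock_symbol)
)
"""

def ensure_table(db_url: str) -> None:
    """Create the target table if it does not exist yet."""
    engine = create_engine(db_url)
    with engine.begin() as conn:
        conn.execute(text(THS_INDUSTRY_MEMBER_DDL))
    engine.dispose()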

View File

@@ -0,0 +1,300 @@
# coding:utf-8
"""
THS board data transfer script

Purpose: transfer the data in the ths_industry_member and ths_concept_member tables
into gp_hybk and gp_gnbk. The source tables must first be populated by
ths_industry_member_collector.py and ths_concept_member_collector.py.
Note: the target tables are fully truncated before re-insertion, keeping the data
structures aligned.
"""
import os
import sys
import re
from typing import Optional
import pandas as pd
from sqlalchemy import create_engine, text

# Add the project root to the path
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)


class THSToGPBKTransfer:
    """THS board data transferer"""

    def __init__(self, db_url: str):
        """
        Args:
            db_url: database connection URL
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        print("=" * 60)
        print("THS board data transfer tool")
        print("=" * 60)

    @staticmethod
    def extract_bk_code(code: str) -> Optional[int]:
        """
        Extract the numeric part of a board code: 885800.TI -> 885800
        Args:
            code: board code, e.g. '885800.TI' or '885800'
        Returns:
            the extracted number, or None if none can be extracted
        """
        if not code:
            return None
        # Extract the numeric part
        match = re.search(r'(\d+)', str(code))
        if match:
            try:
                return int(match.group(1))
            except ValueError:
                return None
        return None
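    # Examples: extract_bk_code("885800.TI") -> 885800,
    # extract_bk_code("885800") -> 885800, extract_bk_code("ABC") -> None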
    def transfer_industry_data(self) -> dict:
        """
        Transfer the ths_industry_member table into gp_hybk.
        Returns:
            transfer statistics
        """
        print("\n" + "=" * 60)
        print("Starting industry board transfer (ths_industry_member -> gp_hybk)")
        print("=" * 60)
        result = {
            'success': False,
            'source_count': 0,
            'transferred_count': 0,
            'error': None
        }
        try:
            # 1. Read the source table
            print("Reading the ths_industry_member table...")
            source_query = text("""
                SELECT
                    industry_code,
                    industry_name,
                    stock_symbol,
                    stock_name
                FROM ths_industry_member
                WHERE industry_code IS NOT NULL
                  AND stock_symbol IS NOT NULL
            """)
            with self.engine.connect() as conn:
                source_df = pd.read_sql(source_query, conn)
            result['source_count'] = len(source_df)
            print(f"Source table holds {result['source_count']} records")
            if source_df.empty:
                print("Source table is empty; transfer aborted")
                result['success'] = True
                return result
            # 2. Transform the data
            print("Transforming the data...")
            target_df = pd.DataFrame()
            target_df['bk_code'] = source_df['industry_code'].apply(self.extract_bk_code)
            target_df['bk_name'] = source_df['industry_name']
            target_df['gp_code'] = source_df['stock_symbol']
            target_df['gp_name'] = source_df['stock_name']
            # Drop records whose bk_code is None
            target_df = target_df[target_df['bk_code'].notna()]
            print(f"{len(target_df)} valid records after transformation")
            # 3. Truncate the target table
            print("Truncating the gp_hybk table...")
            with self.engine.begin() as conn:
                conn.execute(text("TRUNCATE TABLE gp_hybk"))
            print("gp_hybk table truncated")
            # 4. Insert the data
            print("Inserting data into the gp_hybk table...")
            target_df.to_sql(
                'gp_hybk',
                self.engine,
                if_exists='append',
                index=False,
                method='multi',
                chunksize=1000
            )
            result['transferred_count'] = len(target_df)
            result['success'] = True
            print(f"Transferred {result['transferred_count']} records into gp_hybk")
        except Exception as e:
            result['error'] = str(e)
            print(f"Industry board transfer failed: {e}")
            import traceback
            traceback.print_exc()
        return result
    def transfer_concept_data(self) -> dict:
        """
        Transfer the ths_concept_member table into gp_gnbk.
        Returns:
            transfer statistics
        """
        print("\n" + "=" * 60)
        print("Starting concept board transfer (ths_concept_member -> gp_gnbk)")
        print("=" * 60)
        result = {
            'success': False,
            'source_count': 0,
            'transferred_count': 0,
            'error': None
        }
        try:
            # 1. Read the source table
            print("Reading the ths_concept_member table...")
            source_query = text("""
                SELECT
                    concept_code,
                    concept_name,
                    stock_symbol,
                    stock_name
                FROM ths_concept_member
                WHERE concept_code IS NOT NULL
                  AND stock_symbol IS NOT NULL
            """)
            with self.engine.connect() as conn:
                source_df = pd.read_sql(source_query, conn)
            result['source_count'] = len(source_df)
            print(f"Source table holds {result['source_count']} records")
            if source_df.empty:
                print("Source table is empty; transfer aborted")
                result['success'] = True
                return result
            # 2. Transform the data
            print("Transforming the data...")
            target_df = pd.DataFrame()
            target_df['bk_code'] = source_df['concept_code'].apply(self.extract_bk_code)
            target_df['bk_name'] = source_df['concept_name']
            target_df['gp_code'] = source_df['stock_symbol']
            target_df['gp_name'] = source_df['stock_name']
            # Drop records whose bk_code is None
            target_df = target_df[target_df['bk_code'].notna()]
            print(f"{len(target_df)} valid records after transformation")
            # 3. Truncate the target table
            print("Truncating the gp_gnbk table...")
            with self.engine.begin() as conn:
                conn.execute(text("TRUNCATE TABLE gp_gnbk"))
            print("gp_gnbk table truncated")
            # 4. Insert the data
            print("Inserting data into the gp_gnbk table...")
            target_df.to_sql(
                'gp_gnbk',
                self.engine,
                if_exists='append',
                index=False,
                method='multi',
                chunksize=1000
            )
            result['transferred_count'] = len(target_df)
            result['success'] = True
            print(f"Transferred {result['transferred_count']} records into gp_gnbk")
        except Exception as e:
            result['error'] = str(e)
            print(f"Concept board transfer failed: {e}")
            import traceback
            traceback.print_exc()
        return result

    def run_full_transfer(self) -> dict:
        """
        Run the full transfer (industries + concepts).
        Returns:
            a summary of the transfer results
        """
        print("\n" + "=" * 60)
        print("Starting full transfer")
        print("=" * 60)
        summary = {
            'industry': None,
            'concept': None,
            'all_success': False
        }
        # Transfer the industry data
        industry_result = self.transfer_industry_data()
        summary['industry'] = industry_result
        # Transfer the concept data
        concept_result = self.transfer_concept_data()
        summary['concept'] = concept_result
        # Aggregate the results
        summary['all_success'] = industry_result.get('success') and concept_result.get('success')
        # Print the summary
        print("\n" + "=" * 60)
        print("Transfer summary")
        print("=" * 60)
        print(f"Industry board transfer: {'OK' if industry_result.get('success') else 'FAILED'}")
        if industry_result.get('success'):
            print(f"  - source records: {industry_result.get('source_count', 0)}")
            print(f"  - transferred records: {industry_result.get('transferred_count', 0)}")
        if industry_result.get('error'):
            print(f"  - error: {industry_result.get('error')}")
        print(f"\nConcept board transfer: {'OK' if concept_result.get('success') else 'FAILED'}")
        if concept_result.get('success'):
            print(f"  - source records: {concept_result.get('source_count', 0)}")
            print(f"  - transferred records: {concept_result.get('transferred_count', 0)}")
        if concept_result.get('error'):
            print(f"  - error: {concept_result.get('error')}")
        print("=" * 60)
        return summary


def transfer_ths_to_gp_bk(db_url: str):
    """
    Transfer entry point.
    Args:
        db_url: database connection URL
    """
    transfer = THSToGPBKTransfer(db_url)
    transfer.run_full_transfer()


if __name__ == "__main__":
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    # Run the full transfer
    transfer_ths_to_gp_bk(DB_URL)
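The gp_hybk and gp_gnbk targets are assumed to already exist. A minimal sketch of compatible schemas; the column names come from the transfer code above, while every type and index is an assumption:

# Hypothetical DDL for the two target board tables.
from sqlalchemy import create_engine, text

GP_BK_DDL = """
CREATE TABLE IF NOT EXISTS {table} (
    bk_code  INT         NOT NULL,  -- numeric board code, e.g. 885800
    bk_name  VARCHAR(64),           -- board name
    gp_code  VARCHAR(16) NOT NULL,  -- stock code, e.g. SH600000
    gp_name  VARCHAR(64),           -- stock name
    KEY idx_bk_code (bk_code),
    KEY idx_gp_code (gp_code)
)
"""

def ensure_bk_tables(db_url: str) -> None:
    """Create both board tables if they are missing."""
    engine = create_engine(db_url)
    with engine.begin() as conn:
        for table in ("gp_hybk", "gp_gnbk"):
            conn.execute(text(GP_BK_DDL.format(table=table)))
    engine.dispose()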

View File

@@ -14,10 +14,21 @@ import datetime
 import logging
 import json
 import redis
-import time
-from typing import Tuple, Dict, List, Optional, Union
+from typing import Tuple, Dict, List
+import matplotlib.pyplot as plt
+import seaborn as sns
+from pathlib import Path
+import os
+import sys

-from .config import DB_URL, OUTPUT_DIR, LOG_FILE
+# Handle both relative and absolute imports
+try:
+    from .config import DB_URL, OUTPUT_DIR, LOG_FILE
+except ImportError:
+    # Fall back to absolute imports when the module is executed directly
+    current_dir = Path(__file__).parent
+    sys.path.insert(0, str(current_dir.parent.parent))
+    from src.valuation_analysis.config import DB_URL, OUTPUT_DIR, LOG_FILE

 # Logging configuration
 logging.basicConfig(
@@ -59,6 +70,132 @@ class IndustryAnalyzer:
        )
        logger.info("Industry valuation analyzer initialized")

    def _setup_chinese_font(self):
        """
        Configure a Chinese font for matplotlib so Chinese labels render correctly.
        Uses the setup_fonts helper to make sure a font file exists, preferring
        font files shipped inside the project.
        """
        try:
            from matplotlib import font_manager
            from matplotlib.font_manager import FontProperties
            # 1. Make sure the font file exists (via the setup_fonts helper)
            try:
                # Resolve project paths
                current_file = Path(__file__)
                project_root = current_file.parent.parent
                setup_fonts_path = project_root / "fundamentals_llm" / "setup_fonts.py"
                if setup_fonts_path.exists():
                    # Import the setup_fonts module dynamically
                    import importlib.util
                    spec = importlib.util.spec_from_file_location("setup_fonts", str(setup_fonts_path))
                    setup_fonts = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(setup_fonts)
                    # Make sure the font is installed into fundamentals_llm/fonts
                    font_dir = project_root / "fundamentals_llm" / "fonts"
                    font_file = setup_fonts.install_font(str(font_dir))
                    if font_file and os.path.exists(font_file):
                        logger.info(f"setup_fonts ensured the font exists: {font_file}")
                    else:
                        logger.warning("setup_fonts could not install a font; looking for an existing one")
                else:
                    logger.warning(f"setup_fonts.py not found: {setup_fonts_path}")
            except Exception as e:
                logger.warning(f"Calling setup_fonts failed: {e}; looking for an existing font")
            # 2. Look for font files inside the project (preferred)
            current_file = Path(__file__)
            project_root = current_file.parent.parent
            # Candidate font file locations
            font_paths = [
                project_root / "fundamentals_llm" / "fonts" / "simhei.ttf",  # existing font directory
                project_root / "valuation_analysis" / "fonts" / "simhei.ttf",  # this module's font directory
            ]
            font_file = None
            for font_path in font_paths:
                if font_path.exists():
                    font_file = str(font_path)
                    logger.info(f"Found project font file: {font_file}")
                    break
            # 3. If a project font file was found, use it directly
            if font_file and os.path.exists(font_file):
                try:
                    # Register the font by file path
                    font_prop = FontProperties(fname=font_file)
                    # Resolve the font name
                    font_name = font_prop.get_name()
                    # Tell matplotlib to use it
                    plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
                    logger.info(f"Using project font: {font_name} (from: {font_file})")
                except Exception as e:
                    logger.warning(f"Loading the project font file failed: {e}; falling back to system fonts")
                    font_file = None
            # 4. If no project font is usable, try system fonts
            if font_file is None:
                import platform
                system = platform.system()
                font_candidates = []
                if system == "Windows":
                    font_candidates = [
                        'Microsoft YaHei',
                        'SimHei',
                        'SimSun',
                        'KaiTi',
                        'FangSong'
                    ]
                elif system == "Darwin":  # macOS
                    font_candidates = [
                        'PingFang SC',
                        'STHeiti',
                        'Arial Unicode MS',
                        'Heiti SC'
                    ]
                else:  # Linux
                    font_candidates = [
                        'WenQuanYi Micro Hei',
                        'WenQuanYi Zen Hei',
                        'Noto Sans CJK SC',
                        'Droid Sans Fallback'
                    ]
                # Try each system font in turn
                for font_name in font_candidates:
                    try:
                        font_list = [f.name for f in font_manager.fontManager.ttflist]
                        if font_name in font_list:
                            plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
                            logger.info(f"Using system font: {font_name}")
                            break
                    except Exception:
                        continue
            # 5. If nothing was found, fall back to the defaults
            if 'font.sans-serif' not in plt.rcParams or len(plt.rcParams['font.sans-serif']) == 0:
                plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'DejaVu Sans']
                logger.warning("Falling back to the default Chinese font list")
            # Fix minus sign rendering
            plt.rcParams['axes.unicode_minus'] = False
            # Clear the matplotlib font cache to force a font reload
            try:
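                # NOTE: font_manager._rebuild() is a private matplotlib API and
                # has been removed in newer releases; when it is unavailable the
                # except branch below only logs a warning.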
                font_manager._rebuild()
                logger.info("Cleared the matplotlib font cache")
            except Exception as e:
                logger.warning(f"Clearing the font cache failed: {e}")
        except Exception as e:
            logger.warning(f"Setting up a Chinese font failed; using defaults: {e}")
            plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'DejaVu Sans']
            plt.rcParams['axes.unicode_minus'] = False
    def get_industry_list(self) -> List[Dict]:
        """
        Fetch the list of all industries
@@ -1234,3 +1371,559 @@ class IndustryAnalyzer:
        except Exception as e:
            logger.error(f"Error while filtering the crowding cache: {e}")
            return result
    def _get_all_industries_concepts(self, days: int = 120) -> Tuple[List[str], List[str]]:
        """
        Fetch all valid industries and concept boards from the Redis cache.
        Args:
            days: number of days to analyze, used to validate data freshness
        Returns:
            an (all_industries, all_concepts) tuple
        """
        try:
            # Compute the start and end dates
            end_date = datetime.datetime.now()
            start_date = end_date - datetime.timedelta(days=days)
            all_industries = []
            all_concepts = []
            # Fetch the cache keys of all industries and concept boards
            industry_keys = redis_client.keys('industry_crowding:*')
            concept_keys = redis_client.keys('concept_crowding:*')
            # Industries
            for key in industry_keys:
                try:
                    name = key.split(':', 1)[1]
                    cached_data = redis_client.get(key)
                    if not cached_data:
                        continue
                    df = pd.DataFrame(json.loads(cached_data))
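                    # The cached value is assumed to be a JSON array of rows like
                    # [{"trade_date": "2025-09-01", "percentile": 63.2}, ...],
                    # which is the shape the column accesses below rely on.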
                    if df.empty:
                        continue
                    # Make sure trade_date is a datetime
                    if 'trade_date' in df.columns:
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                    # Keep only the last `days` days
                    df = df[df['trade_date'] >= start_date]
                    df = df.sort_values('trade_date')
                    if len(df) < 10:  # too few data points, skip
                        continue
                    all_industries.append(name)
                except Exception as e:
                    logger.warning(f"Error while processing industry cache {key}: {e}")
                    continue
            # Concept boards
            for key in concept_keys:
                try:
                    name = key.split(':', 1)[1]
                    cached_data = redis_client.get(key)
                    if not cached_data:
                        continue
                    df = pd.DataFrame(json.loads(cached_data))
                    if df.empty:
                        continue
                    # Make sure trade_date is a datetime
                    if 'trade_date' in df.columns:
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                    # Keep only the last `days` days
                    df = df[df['trade_date'] >= start_date]
                    df = df.sort_values('trade_date')
                    if len(df) < 10:  # too few data points, skip
                        continue
                    all_concepts.append(name)
                except Exception as e:
                    logger.warning(f"Error while processing concept cache {key}: {e}")
                    continue
            logger.info(f"Found {len(all_industries)} industries and {len(all_concepts)} concept boards")
            return all_industries, all_concepts
        except Exception as e:
            logger.error(f"Failed to fetch the industry/concept board lists: {e}")
            return [], []
    def _select_top_industries_concepts(self, count: int = 40, days: int = 120) -> Tuple[List[str], List[str]]:
        """
        From the Redis cache, select the industries and concept boards whose
        crowding changed the most.
        Args:
            count: number of industries/concepts to select
            days: number of days to analyze
        Returns:
            a (selected_industries, selected_concepts) tuple
        """
        try:
            # Compute the start and end dates
            end_date = datetime.datetime.now()
            start_date = end_date - datetime.timedelta(days=days)
            industry_scores = []
            concept_scores = []
            # Fetch the cache keys of all industries and concept boards
            industry_keys = redis_client.keys('industry_crowding:*')
            concept_keys = redis_client.keys('concept_crowding:*')
            # Industries
            for key in industry_keys:
                try:
                    name = key.split(':', 1)[1]
                    cached_data = redis_client.get(key)
                    if not cached_data:
                        continue
                    df = pd.DataFrame(json.loads(cached_data))
                    if df.empty:
                        continue
                    # Make sure trade_date is a datetime
                    if 'trade_date' in df.columns:
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                    # Keep only the last `days` days
                    df = df[df['trade_date'] >= start_date]
                    df = df.sort_values('trade_date')
                    if len(df) < 10:  # too few data points, skip
                        continue
                    # Crowding change range (max - min)
                    percentile_values = df['percentile'].values
                    if len(percentile_values) > 0:
                        change_range = float(percentile_values.max() - percentile_values.min())
                        current_value = float(percentile_values[-1])
                        # Combined score: change range plus distance from the
                        # neutral 50th percentile (to highlight extremes)
                        score = change_range * 0.7 + abs(current_value - 50) * 0.3
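                        # e.g. change_range=60 and current_value=90 give
                        # score = 60*0.7 + |90-50|*0.3 = 42 + 12 = 54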
                        industry_scores.append({
                            'name': name,
                            'score': score,
                            'change_range': change_range,
                            'current': current_value
                        })
                except Exception as e:
                    logger.warning(f"Error while processing industry cache {key}: {e}")
                    continue
            # Concept boards
            for key in concept_keys:
                try:
                    name = key.split(':', 1)[1]
                    cached_data = redis_client.get(key)
                    if not cached_data:
                        continue
                    df = pd.DataFrame(json.loads(cached_data))
                    if df.empty:
                        continue
                    # Make sure trade_date is a datetime
                    if 'trade_date' in df.columns:
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                    # Keep only the last `days` days
                    df = df[df['trade_date'] >= start_date]
                    df = df.sort_values('trade_date')
                    if len(df) < 10:  # too few data points, skip
                        continue
                    # Crowding change range
                    percentile_values = df['percentile'].values
                    if len(percentile_values) > 0:
                        change_range = float(percentile_values.max() - percentile_values.min())
                        current_value = float(percentile_values[-1])
                        # Combined score
                        score = change_range * 0.7 + abs(current_value - 50) * 0.3
                        concept_scores.append({
                            'name': name,
                            'score': score,
                            'change_range': change_range,
                            'current': current_value
                        })
                except Exception as e:
                    logger.warning(f"Error while processing concept cache {key}: {e}")
                    continue
            # Sort by score and keep the top `count`
            industry_scores.sort(key=lambda x: x['score'], reverse=True)
            concept_scores.sort(key=lambda x: x['score'], reverse=True)
            selected_industries = [item['name'] for item in industry_scores[:count]]
            selected_concepts = [item['name'] for item in concept_scores[:count]]
            logger.info(f"Selected {len(selected_industries)} industries and {len(selected_concepts)} concept boards")
            return selected_industries, selected_concepts
        except Exception as e:
            logger.error(f"Error while selecting industries/concept boards: {e}")
            return [], []
    def _build_heatmap_data(self, names: List[str], is_concept: bool, days: int = 120) -> pd.DataFrame:
        """
        Build the heatmap data matrix.
        Args:
            names: list of industry/concept board names
            is_concept: whether the names are concept boards
            days: number of days to analyze
        Returns:
            a DataFrame with one row per industry/concept, one column per date,
            and crowding percentiles as values
        """
        try:
            # Compute the date range
            end_date = datetime.datetime.now()
            start_date = end_date - datetime.timedelta(days=days)
            # Collect all trading dates across the cached series
            all_dates = set()
            data_dict = {}
            prefix = 'concept_crowding:' if is_concept else 'industry_crowding:'
            for name in names:
                cache_key = f"{prefix}{name}"
                cached_data = redis_client.get(cache_key)
                if not cached_data:
                    continue
                try:
                    df = pd.DataFrame(json.loads(cached_data))
                    if df.empty:
                        continue
                    # Make sure trade_date is a datetime
                    if 'trade_date' in df.columns:
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                    # Keep only the requested date range
                    df = df[(df['trade_date'] >= start_date) & (df['trade_date'] <= end_date)]
                    df = df.sort_values('trade_date')
                    if df.empty:
                        continue
                    # Collect the dates
                    all_dates.update(df['trade_date'].tolist())
                    # Store the series
                    data_dict[name] = df.set_index('trade_date')['percentile'].to_dict()
                except Exception as e:
                    logger.warning(f"Error while processing data for {name}: {e}")
                    continue
            if not all_dates:
                logger.warning("No valid data found")
                return pd.DataFrame()
            # Build the full date sequence (trading days only)
            all_dates = sorted(list(all_dates))
            # Build the matrix
            matrix_data = []
            for name in names:
                row = []
                for date in all_dates:
                    if name in data_dict and date in data_dict[name]:
                        row.append(float(data_dict[name][date]))
                    else:
                        row.append(np.nan)
                matrix_data.append(row)
            # Create the DataFrame
            heatmap_df = pd.DataFrame(matrix_data, index=names, columns=all_dates)
            logger.info(f"Heatmap data built: {heatmap_df.shape}")
            return heatmap_df
        except Exception as e:
            logger.error(f"Failed to build heatmap data: {e}")
            return pd.DataFrame()
    def plot_crowding_heatmap(self, days: int = 120, items_per_page: int = 40, output_dir: str = None):
        """
        Plot crowding heatmaps for industries and concept boards,
        one figure per `items_per_page` industries/concepts.
        Args:
            days: number of days to analyze, 120 by default
            items_per_page: industries/concepts per figure, 40 by default
            output_dir: output directory, the src/data directory by default
        """
        try:
            # Resolve the output directory
            if output_dir is None:
                # Locate the project root
                current_file = Path(__file__)
                project_root = current_file.parent.parent.parent
                output_dir = project_root / "src" / "data"
            else:
                output_dir = Path(output_dir)
            os.makedirs(output_dir, exist_ok=True)
            # Fetch all industries and concepts
            all_industries, all_concepts = self._get_all_industries_concepts(days=days)
            if not all_industries and not all_concepts:
                logger.warning("No valid industry or concept board data found")
                return
            # Configure the Chinese font (via the setup_fonts helper)
            self._setup_chinese_font()
            # Configure the seaborn style
            sns.set_style("whitegrid")
            # Timestamp for the file names
            timestamp = datetime.datetime.now().strftime('%Y%m%d')
            # Plot industry heatmaps in groups of items_per_page
            if all_industries:
                total_industries = len(all_industries)
                total_pages = (total_industries + items_per_page - 1) // items_per_page  # ceiling division
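                # e.g. 127 industries at 40 per page -> (127 + 39) // 40 = 4 pages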
logger.info(f"开始绘制行业热力图,共 {total_industries} 个行业,将生成 {total_pages} 张图")
for page in range(total_pages):
start_idx = page * items_per_page
end_idx = min(start_idx + items_per_page, total_industries)
page_industries = all_industries[start_idx:end_idx]
logger.info(f"绘制行业热力图第 {page + 1}/{total_pages} 张(行业 {start_idx + 1}-{end_idx}")
industry_heatmap = self._build_heatmap_data(page_industries, is_concept=False, days=days)
if not industry_heatmap.empty:
self._plot_single_heatmap(
industry_heatmap,
title=f"行业拥挤度热力图(近{days}天,第{page + 1}/{total_pages}页,{len(page_industries)}个行业)",
output_path=output_dir / f"industry_crowding_heatmap_{timestamp}_page{page + 1:02d}.png",
is_concept=False
)
# 绘制概念板块热力图按每items_per_page个分组
if all_concepts:
total_concepts = len(all_concepts)
total_pages = (total_concepts + items_per_page - 1) // items_per_page # 向上取整
logger.info(f"开始绘制概念板块热力图,共 {total_concepts} 个概念,将生成 {total_pages} 张图")
for page in range(total_pages):
start_idx = page * items_per_page
end_idx = min(start_idx + items_per_page, total_concepts)
page_concepts = all_concepts[start_idx:end_idx]
logger.info(f"绘制概念板块热力图第 {page + 1}/{total_pages} 张(概念 {start_idx + 1}-{end_idx}")
concept_heatmap = self._build_heatmap_data(page_concepts, is_concept=True, days=days)
if not concept_heatmap.empty:
self._plot_single_heatmap(
concept_heatmap,
title=f"概念板块拥挤度热力图(近{days}天,第{page + 1}/{total_pages}页,{len(page_concepts)}个概念)",
output_path=output_dir / f"concept_crowding_heatmap_{timestamp}_page{page + 1:02d}.png",
is_concept=True
)
logger.info("热力图绘制完成")
except Exception as e:
logger.error(f"绘制热力图失败: {e}")
import traceback
logger.error(traceback.format_exc())
    def _plot_single_heatmap(self, heatmap_df: pd.DataFrame, title: str, output_path: Path, is_concept: bool):
        """
        Plot a single heatmap.
        Args:
            heatmap_df: heatmap data DataFrame
            title: chart title
            output_path: output path
            is_concept: whether the data is concept boards
        """
        try:
            # Figure size, scaled to the amount of data
            n_rows = len(heatmap_df)
            n_cols = len(heatmap_df.columns)
            fig_width = max(20, n_cols * 0.15)
            fig_height = max(12, n_rows * 0.4)
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))
            # Resolve the currently configured font
            from matplotlib.font_manager import FontProperties
            current_font = plt.rcParams['font.sans-serif'][0] if plt.rcParams['font.sans-serif'] else 'SimHei'
            font_prop = FontProperties(fname=str(Path(__file__).parent.parent / "fundamentals_llm" / "fonts" / "simhei.ttf")) if (Path(__file__).parent.parent / "fundamentals_llm" / "fonts" / "simhei.ttf").exists() else None
            # heatmap_df: rows are industries/concepts, columns are dates.
            # seaborn puts rows on the y axis and columns on the x axis,
            # so y = industry/concept and x = date, as required.
            plot_data = heatmap_df
            # Custom colormap: green (not crowded) -> yellow (neutral) -> red (crowded)
            colors = ['#2ecc71', '#f1c40f', '#e74c3c']  # green, yellow, red
            from matplotlib.colors import LinearSegmentedColormap
            n_bins = 100
            cmap = LinearSegmentedColormap.from_list('crowding', colors, N=n_bins)
            # Draw the heatmap
            heatmap = sns.heatmap(
                plot_data,
                cmap=cmap,
                vmin=0,
                vmax=100,
                annot=False,  # too many cells to annotate
                fmt='.1f',
                cbar_kws={
                    'label': 'Crowding percentile (%)',
                    'shrink': 0.8,
                    'format': '%.0f'
                },
                linewidths=0.05,
                linecolor='white',
                ax=ax,
                xticklabels=False,  # x tick labels are customized below
                yticklabels=True  # show y labels (industry/concept names)
            )
            # Set the colorbar label font
            try:
                cbar = heatmap.collections[0].colorbar
                if cbar is not None:
                    if font_prop:
                        cbar.set_label('Crowding percentile (%)', fontproperties=font_prop, fontsize=10)
                    else:
                        cbar.set_label('Crowding percentile (%)', fontfamily=current_font, fontsize=10)
                    # Set the colorbar tick label font
                    for label in cbar.ax.get_yticklabels():
                        if font_prop:
                            label.set_fontproperties(font_prop)
                        else:
                            label.set_fontfamily(current_font)
            except Exception as e:
                logger.warning(f"Failed to set the colorbar font: {e}")
            # Title and axis labels, with an explicit font
            if font_prop:
                ax.set_title(title, fontsize=16, fontweight='bold', pad=20, fontproperties=font_prop)
                ax.set_xlabel('Date', fontsize=12, fontproperties=font_prop)
                ax.set_ylabel('Industry / concept board', fontsize=12, fontproperties=font_prop)
            else:
                ax.set_title(title, fontsize=16, fontweight='bold', pad=20, fontfamily=current_font)
                ax.set_xlabel('Date', fontsize=12, fontfamily=current_font)
                ax.set_ylabel('Industry / concept board', fontsize=12, fontfamily=current_font)
            # y tick labels (industry/concept names): shown in full, truncated at 15 characters
            y_labels = plot_data.index.tolist()
            if len(y_labels) > 0:
                y_labels_display = [label[:15] + '...' if len(str(label)) > 15 else str(label) for label in y_labels]
                if font_prop:
                    ax.set_yticklabels(y_labels_display, fontsize=9, rotation=0, ha='right', fontproperties=font_prop)
                else:
                    ax.set_yticklabels(y_labels_display, fontsize=9, rotation=0, ha='right', fontfamily=current_font)
            # x tick labels (dates): show only a subset to avoid clutter
            x_labels = plot_data.columns.tolist()
            if len(x_labels) > 0:
                if len(x_labels) > 30:
                    step = max(1, len(x_labels) // 15)  # at most ~15 date labels
                    x_positions = list(range(0, len(x_labels), step))
                    x_labels_display = []
                    for pos in x_positions:
                        if pos < len(x_labels):
                            date_str = str(x_labels[pos])
                            # Show month-day only
                            if isinstance(x_labels[pos], pd.Timestamp):
                                x_labels_display.append(x_labels[pos].strftime('%m-%d'))
                            else:
                                x_labels_display.append(date_str[:10] if len(date_str) > 10 else date_str)
                    ax.set_xticks(x_positions)
                    if font_prop:
                        ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontproperties=font_prop)
                    else:
                        ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontfamily=current_font)
                else:
                    # Few dates: show them all
                    x_labels_display = []
                    for date in x_labels:
                        if isinstance(date, pd.Timestamp):
                            x_labels_display.append(date.strftime('%m-%d'))
                        else:
                            date_str = str(date)
                            x_labels_display.append(date_str[:10] if len(date_str) > 10 else date_str)
                    if font_prop:
                        ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontproperties=font_prop)
                    else:
                        ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontfamily=current_font)
            plt.tight_layout()
            # Save the figure
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close()
            logger.info(f"Heatmap saved to: {output_path}")
        except Exception as e:
            logger.error(f"Failed to plot the heatmap: {e}")
            import traceback
            logger.error(traceback.format_exc())
            if 'fig' in locals():
                plt.close(fig)
def main():
    """Entry point: generate the industry and concept board crowding heatmaps"""
    try:
        logger.info("=" * 60)
        logger.info("Generating industry and concept board crowding heatmaps")
        logger.info("=" * 60)
        # Create the analyzer
        analyzer = IndustryAnalyzer()
        # Plot the heatmaps
        # Parameters:
        #   days: number of days to analyze, 120 by default
        #   items_per_page: industries/concepts per figure, 40 by default
        #   output_dir: output directory, src/data by default
        analyzer.plot_crowding_heatmap(days=120, items_per_page=40)
        logger.info("=" * 60)
        logger.info("Heatmap generation finished!")
        logger.info("=" * 60)
    except Exception as e:
        logger.error(f"Main routine failed: {e}")
        import traceback
        logger.error(traceback.format_exc())


if __name__ == '__main__':
    main()
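For ad-hoc runs the analyzer can also be driven directly. A minimal usage sketch, assuming the project root is on PYTHONPATH and the Redis crowding cache is already populated (the parameter values below are illustrative, not defaults):

from src.valuation_analysis.industry_analysis import IndustryAnalyzer

analyzer = IndustryAnalyzer()
# A shorter window and smaller pages yield more, narrower figures;
# output_dir overrides the default src/data location.
analyzer.plot_crowding_heatmap(days=60, items_per_page=25, output_dir="./heatmaps")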