commit;
This commit is contained in:
parent
cfc06ff720
commit
cf35164e99
|
|
@ -24,3 +24,4 @@ pyyaml==6.0.3
|
||||||
xtquant==250516.1.1
|
xtquant==250516.1.1
|
||||||
tushare>=1.4.24
|
tushare>=1.4.24
|
||||||
akshare==1.17.82
|
akshare==1.17.82
|
||||||
|
seaborn==0.13.2
|
||||||
|
|
|
||||||
26
src/app.py
26
src/app.py
|
|
@ -682,18 +682,34 @@ def precalculate_industry_crowding_batch():
|
||||||
from src.valuation_analysis.industry_analysis import IndustryAnalyzer
|
from src.valuation_analysis.industry_analysis import IndustryAnalyzer
|
||||||
|
|
||||||
analyzer = IndustryAnalyzer()
|
analyzer = IndustryAnalyzer()
|
||||||
# 固定行业和概念板块
|
|
||||||
industries = ["煤炭开采", "焦炭加工", "油气开采", "石油化工", "油服工程", "日用化工", "化纤", "化学原料", "化学制品", "塑料", "橡胶", "农用化工", "非金属材料", "冶钢原料", "普钢", "特钢", "工业金属", "贵金属", "能源金属", "稀有金属", "金属新材料", "水泥", "玻璃玻纤", "装饰建材", "种植业", "养殖业", "林业", "渔业", "饲料", "农产品加工", "动物保健", "酿酒", "饮料乳品", "调味品", "休闲食品", "食品加工", "纺织制造", "服装家纺", "饰品", "造纸", "包装印刷", "家居用品", "文娱用品", "白色家电", "黑色家电", "小家电", "厨卫电器", "家电零部件", "一般零售", "商业物业经营", "专业连锁", "贸易", "电子商务", "乘用车", "商用车", "汽车零部件", "汽车服务", "摩托车及其他", "化学制药", "生物制品", "中药", "医药商业", "医疗器械", "医疗服务", "医疗美容", "电机制造", "电池", "电网设备", "光伏设备", "风电设备", "其他发电设备", "地面兵装", "航空装备", "航天装备", "航海装备", "军工电子", "轨交设备", "通用设备", "专用设备", "工程机械", "自动化设备", "半导体", "消费电子", "光学光电", "元器件", "其他电子", "通信设备", "通信工程", "电信服务", "IT设备", "软件服务", "云服务", "产业互联网", "游戏", "广告营销", "影视院线", "数字媒体", "出版业", "广播电视", "全国性银行", "地方性银行", "证券", "保险", "多元金融", "房屋建设", "基础建设", "专业工程", "工程咨询服务", "装修装饰", "房地产开发", "房产服务", "体育", "教育培训", "酒店餐饮", "旅游", "专业服务", "公路铁路", "航空机场", "航运港口", "物流", "电力", "燃气", "水务", "环保设备", "环境治理", "环境监测", "综合类"]
|
# 从数据库查询所有行业
|
||||||
concepts = ["通达信88", "海峡西岸", "海南自贸", "一带一路", "上海自贸", "雄安新区", "粤港澳", "ST板块", "次新股", "含H股", "含B股", "含GDR", "含可转债", "国防军工", "军民融合", "大飞机", "稀缺资源", "5G概念", "碳中和", "黄金概念", "物联网", "创投概念", "航运概念", "铁路基建", "高端装备", "核电核能", "光伏", "风电", "锂电池概念", "燃料电池", "HJT电池", "固态电池", "钠电池", "钒电池", "TOPCon电池", "钙钛矿电池", "BC电池", "氢能源", "稀土永磁", "盐湖提锂", "锂矿", "水利建设", "卫星导航", "可燃冰", "页岩气", "生物疫苗", "基因概念", "维生素", "仿制药", "创新药", "免疫治疗", "CXO概念", "节能环保", "食品安全", "白酒概念", "代糖概念", "猪肉", "鸡肉", "水产品", "碳纤维", "石墨烯", "3D打印", "苹果概念", "阿里概念", "腾讯概念", "小米概念", "百度概念", "华为鸿蒙", "华为海思", "华为汽车", "华为算力", "特斯拉概念", "消费电子概念", "汽车电子", "无线耳机", "生物质能", "地热能", "充电桩", "新能源车", "换电概念", "高压快充", "草甘膦", "安防服务", "垃圾分类", "核污染防治", "风沙治理", "乡村振兴", "土地流转", "体育概念", "博彩概念", "赛马概念", "分散染料", "聚氨酯", "云计算", "边缘计算", "网络游戏", "信息安全", "国产软件", "大数据", "数据中心", "芯片", "MCU芯片", "汽车芯片", "存储芯片", "互联金融", "婴童概念", "养老概念", "网红经济", "民营医院", "特高压", "智能电网", "智能穿戴", "智能交通", "智能家居", "智能医疗", "智慧城市", "智慧政务", "机器人概念", "机器视觉", "超导概念", "职业教育", "物业管理概念", "虚拟现实", "数字孪生", "钛金属", "钴金属", "镍金属", "氟概念", "磷概念", "无人机", "PPP概念", "新零售", "跨境电商", "量子科技", "无人驾驶", "ETC概念", "胎压监测", "OLED概念", "MiniLED", "MicroLED", "超清视频", "区块链", "数字货币", "人工智能", "租购同权", "工业互联", "知识产权", "工业大麻", "工业气体", "人造肉", "预制菜", "种业", "化肥概念", "操作系统", "光刻机", "第三代半导体", "远程办公", "口罩防护", "医废处理", "虫害防治", "超级电容", "C2M概念", "地摊经济", "冷链物流", "抖音概念", "降解塑料", "医美概念", "人脑工程", "烟草概念", "新型烟草", "有机硅概念", "新冠检测", "BIPV概念", "地下管网", "储能", "新材料", "工业母机", "一体压铸", "汽车热管理", "汽车拆解", "NMN概念", "国资云", "元宇宙概念", "NFT概念", "云游戏", "天然气", "绿色电力", "培育钻石", "信创", "幽门螺杆菌", "电子纸", "新冠药概念", "免税概念", "PVDF概念", "装配式建筑", "绿色建筑", "东数西算", "跨境支付CIPS", "中俄贸易", "电子身份证", "家庭医生", "辅助生殖", "肝炎概念", "新型城镇", "粮食概念", "超临界发电", "虚拟电厂", "动力电池回收", "PCB概念", "先进封装", "热泵概念", "EDA概念", "光热发电", "供销社", "Web3概念", "DRG-DIP", "AIGC概念", "复合铜箔", "数据确权", "数据要素", "POE胶膜", "血氧仪", "旅游概念", "中特估", "ChatGPT概念", "CPO概念", "数字水印", "毫米波雷达", "工业软件", "6G概念", "时空大数据", "可控核聚变", "知识付费", "算力租赁", "光通信", "混合现实", "英伟达概念", "减速器", "减肥药", "合成生物", "星闪概念", "液冷服务器", "新型工业化", "短剧游戏", "多模态AI", "PEEK材料", "小米汽车概念", "飞行汽车", "Sora概念", "人形机器人", "AI手机PC", "低空经济", "铜缆高速连接", "军工信息化", "玻璃基板", "商业航天", "车联网", "财税数字化", "折叠屏", "AI眼镜", "智谱AI", "IP经济", "宠物经济", "小红书概念", "AI智能体", "DeepSeek概念", "AI医疗概念", "海洋经济", "外骨骼机器人", "军贸概念"]
|
industry_list = analyzer.get_industry_list()
|
||||||
|
industries = [item['name'] for item in industry_list]
|
||||||
|
logger.info(f"从数据库获取到 {len(industries)} 个行业")
|
||||||
|
|
||||||
|
# 从数据库查询所有概念板块
|
||||||
|
concept_list = analyzer.get_concept_list()
|
||||||
|
concepts = [item['name'] for item in concept_list]
|
||||||
|
logger.info(f"从数据库获取到 {len(concepts)} 个概念板块")
|
||||||
|
|
||||||
# 批量计算行业和概念板块拥挤度
|
# 批量计算行业和概念板块拥挤度
|
||||||
analyzer.batch_calculate_industry_crowding(industries, concepts)
|
analyzer.batch_calculate_industry_crowding(industries, concepts)
|
||||||
|
|
||||||
logger.info("批量计算行业和概念板块的拥挤度指标完成")
|
logger.info(f"批量计算完成:{len(industries)} 个行业和 {len(concepts)} 个概念板块的拥挤度指标")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批量计算行业拥挤度指标失败: {str(e)}")
|
logger.error(f"批量计算行业拥挤度指标失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"status": "success"
|
"status": "error",
|
||||||
|
"message": str(e)
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"status": "success",
|
||||||
|
"industries_count": len(industries) if 'industries' in locals() else 0,
|
||||||
|
"concepts_count": len(concepts) if 'concepts' in locals() else 0
|
||||||
}), 200
|
}), 200
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ except ImportError:
|
||||||
|
|
||||||
# 导入股票代码格式化工具
|
# 导入股票代码格式化工具
|
||||||
try:
|
try:
|
||||||
from tools.stock_code_formatter import StockCodeFormatter
|
from src.tools.stock_code_formatter import StockCodeFormatter
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import importlib.util
|
import importlib.util
|
||||||
# project_root 已经是项目根目录,直接拼接 tools 目录
|
# project_root 已经是项目根目录,直接拼接 tools 目录
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import pymongo
|
import pymongo
|
||||||
import datetime
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -24,7 +23,7 @@ sys.path.append(str(project_root))
|
||||||
from src.valuation_analysis.config import MONGO_CONFIG2, DB_URL
|
from src.valuation_analysis.config import MONGO_CONFIG2, DB_URL
|
||||||
|
|
||||||
# 导入股票代码格式转换工具
|
# 导入股票代码格式转换工具
|
||||||
from tools.stock_code_formatter import StockCodeFormatter
|
from src.tools.stock_code_formatter import StockCodeFormatter
|
||||||
|
|
||||||
# 设置日志
|
# 设置日志
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|
@ -478,8 +477,6 @@ class FinancialIndicatorAnalyzer:
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: 季度列表,格式为 YYYY-MM-DD
|
List[str]: 季度列表,格式为 YYYY-MM-DD
|
||||||
"""
|
"""
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import calendar
|
|
||||||
|
|
||||||
quarters = []
|
quarters = []
|
||||||
year, month, day = map(int, start_date.split('-'))
|
year, month, day = map(int, start_date.split('-'))
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,7 @@ import logging
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine, text
|
||||||
import pandas as pd
|
from datetime import datetime
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
# 添加项目根路径到Python路径
|
# 添加项目根路径到Python路径
|
||||||
project_root = Path(__file__).parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
@ -33,7 +32,7 @@ except ImportError:
|
||||||
|
|
||||||
# 导入股票代码格式转换工具
|
# 导入股票代码格式转换工具
|
||||||
try:
|
try:
|
||||||
from tools.stock_code_formatter import StockCodeFormatter
|
from src.tools.stock_code_formatter import StockCodeFormatter
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# 如果上面的导入失败,尝试直接导入
|
# 如果上面的导入失败,尝试直接导入
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,243 @@
|
||||||
|
# coding:utf-8
|
||||||
|
"""
|
||||||
|
同花顺行业板块成分股采集工具
|
||||||
|
功能:从 Tushare 获取同花顺行业板块成分股数据并落库
|
||||||
|
API 文档: https://tushare.pro/document/2?doc_id=261
|
||||||
|
说明:每次调用时全量覆盖,不需要每日更新
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import tushare as ts
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
# 添加项目根目录到路径,确保能够读取配置
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
sys.path.append(PROJECT_ROOT)
|
||||||
|
|
||||||
|
from src.scripts.config import TUSHARE_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
class THSIndustryMemberCollector:
|
||||||
|
"""同花顺行业板块成分股采集器"""
|
||||||
|
|
||||||
|
def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
db_url: 数据库连接 URL
|
||||||
|
tushare_token: Tushare Token
|
||||||
|
table_name: 目标表名,默认 ths_industry_member
|
||||||
|
"""
|
||||||
|
self.engine = create_engine(
|
||||||
|
db_url,
|
||||||
|
pool_size=5,
|
||||||
|
max_overflow=10,
|
||||||
|
pool_recycle=3600,
|
||||||
|
)
|
||||||
|
self.table_name = table_name
|
||||||
|
|
||||||
|
ts.set_token(tushare_token)
|
||||||
|
self.pro = ts.pro_api()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("同花顺行业板块成分股采集工具")
|
||||||
|
print(f"目标数据表: {self.table_name}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_tushare_code_to_db(ts_code: str) -> str:
|
||||||
|
"""
|
||||||
|
将 Tushare 代码(600000.SH)转换为数据库代码(SH600000)
|
||||||
|
"""
|
||||||
|
if not ts_code or "." not in ts_code:
|
||||||
|
return ts_code
|
||||||
|
base, market = ts_code.split(".")
|
||||||
|
return f"{market}{base}"
|
||||||
|
|
||||||
|
def fetch_industry_index_list(self) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
获取所有同花顺行业板块列表
|
||||||
|
API: ths_index
|
||||||
|
文档: https://tushare.pro/document/2?doc_id=259
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
行业板块列表 DataFrame
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
print("正在获取同花顺行业板块列表...")
|
||||||
|
df = self.pro.ths_index(
|
||||||
|
exchange='A',
|
||||||
|
type='I', # I=行业板块
|
||||||
|
)
|
||||||
|
if df.empty:
|
||||||
|
print("未获取到行业板块数据")
|
||||||
|
return pd.DataFrame()
|
||||||
|
print(f"成功获取 {len(df)} 个行业板块")
|
||||||
|
return df
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"获取行业板块列表失败: {exc}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
获取指定行业板块的成分股
|
||||||
|
API: ths_member
|
||||||
|
文档: https://tushare.pro/document/2?doc_id=261
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ts_code: 行业板块代码,如 '885800.TI'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
成分股列表 DataFrame
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
df = self.pro.ths_member(ts_code=ts_code)
|
||||||
|
return df
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"获取行业板块 {ts_code} 成分股失败: {exc}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
将 Tushare 返回的数据转换为数据库入库格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_df: 成分股信息(包含 ts_code, con_code, con_name 等)
|
||||||
|
industry_name: 行业板块名称(从行业板块列表中获取)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转换后的 DataFrame
|
||||||
|
"""
|
||||||
|
if member_df.empty:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# 检查必要字段
|
||||||
|
if "ts_code" not in member_df.columns:
|
||||||
|
print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
if "con_code" not in member_df.columns:
|
||||||
|
print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
result = pd.DataFrame()
|
||||||
|
# ts_code 是行业板块代码
|
||||||
|
result["industry_code"] = member_df["ts_code"]
|
||||||
|
result["industry_name"] = industry_name
|
||||||
|
# con_code 是股票代码
|
||||||
|
result["stock_ts_code"] = member_df["con_code"]
|
||||||
|
result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
|
||||||
|
# con_name 是股票名称
|
||||||
|
result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
|
||||||
|
# is_new 是否最新
|
||||||
|
result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
|
||||||
|
result["created_at"] = datetime.now()
|
||||||
|
result["updated_at"] = datetime.now()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def save_dataframe(self, df: pd.DataFrame) -> None:
|
||||||
|
"""
|
||||||
|
将数据写入数据库
|
||||||
|
"""
|
||||||
|
if df.empty:
|
||||||
|
return
|
||||||
|
df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
|
||||||
|
|
||||||
|
def run_full_collection(self) -> None:
|
||||||
|
"""
|
||||||
|
执行全量覆盖采集:
|
||||||
|
- 清空目标表
|
||||||
|
- 获取所有行业板块列表
|
||||||
|
- 遍历每个行业板块,获取成分股数据
|
||||||
|
- 全量写入数据库
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("开始执行全量覆盖采集(同花顺行业板块成分股)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 清空表
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
|
||||||
|
print(f"{self.table_name} 已清空")
|
||||||
|
|
||||||
|
# 获取所有行业板块列表
|
||||||
|
industry_list_df = self.fetch_industry_index_list()
|
||||||
|
if industry_list_df.empty:
|
||||||
|
print("未获取到行业板块列表,采集终止")
|
||||||
|
return
|
||||||
|
|
||||||
|
total_records = 0
|
||||||
|
success_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
|
||||||
|
# 遍历每个行业板块,获取成分股
|
||||||
|
for idx, row in industry_list_df.iterrows():
|
||||||
|
industry_code = row["ts_code"]
|
||||||
|
industry_name = row["name"]
|
||||||
|
|
||||||
|
print(f"\n[{idx + 1}/{len(industry_list_df)}] 正在采集: {industry_name} ({industry_code})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取该行业板块的成分股
|
||||||
|
member_df = self.fetch_industry_members(industry_code)
|
||||||
|
if member_df.empty:
|
||||||
|
print(f" 行业板块 {industry_name} 无成分股数据")
|
||||||
|
failed_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 转换数据格式(传入行业板块名称)
|
||||||
|
result_df = self.transform_data(member_df, industry_name)
|
||||||
|
if result_df.empty:
|
||||||
|
print(f" 行业板块 {industry_name} 数据转换失败")
|
||||||
|
failed_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 保存数据
|
||||||
|
self.save_dataframe(result_df)
|
||||||
|
total_records += len(result_df)
|
||||||
|
success_count += 1
|
||||||
|
print(f" 成功采集 {len(result_df)} 只成分股")
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}")
|
||||||
|
failed_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("全量覆盖采集完成")
|
||||||
|
print(f"总行业板块数: {len(industry_list_df)}")
|
||||||
|
print(f"成功采集: {success_count}")
|
||||||
|
print(f"失败: {failed_count}")
|
||||||
|
print(f"累计成分股记录: {total_records}")
|
||||||
|
print("=" * 60)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"全量采集失败: {exc}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
self.engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def collect_ths_industry_member(
|
||||||
|
db_url: str,
|
||||||
|
tushare_token: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
采集入口 - 全量覆盖采集
|
||||||
|
"""
|
||||||
|
collector = THSIndustryMemberCollector(db_url, tushare_token)
|
||||||
|
collector.run_full_collection()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
|
||||||
|
DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
|
||||||
|
TOKEN = TUSHARE_TOKEN
|
||||||
|
|
||||||
|
# 执行全量覆盖采集
|
||||||
|
collect_ths_industry_member(DB_URL, TOKEN)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,300 @@
|
||||||
|
# coding:utf-8
|
||||||
|
"""
|
||||||
|
同花顺板块数据转存脚本
|
||||||
|
功能:将 ths_industry_member 和 ths_concept_member 表的数据转存到 gp_hybk 和 gp_gnbk 表。但是提前需要处理ths_concept_member_collector.py和ths_industry_member_collector.py里面的数据。
|
||||||
|
说明:全量清理目标表后重新插入,保持数据结构对齐
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
sys.path.append(PROJECT_ROOT)
|
||||||
|
|
||||||
|
|
||||||
|
class THSToGPBKTransfer:
|
||||||
|
"""同花顺板块数据转存器"""
|
||||||
|
|
||||||
|
def __init__(self, db_url: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
db_url: 数据库连接 URL
|
||||||
|
"""
|
||||||
|
self.engine = create_engine(
|
||||||
|
db_url,
|
||||||
|
pool_size=5,
|
||||||
|
max_overflow=10,
|
||||||
|
pool_recycle=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("同花顺板块数据转存工具")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_bk_code(code: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
从板块代码中提取数字部分(如:885800.TI -> 885800)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 板块代码,如 '885800.TI' 或 '885800'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取的数字,如果无法提取则返回 None
|
||||||
|
"""
|
||||||
|
if not code:
|
||||||
|
return None
|
||||||
|
# 提取数字部分
|
||||||
|
match = re.search(r'(\d+)', str(code))
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return int(match.group(1))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def transfer_industry_data(self) -> dict:
|
||||||
|
"""
|
||||||
|
将 ths_industry_member 表数据转存到 gp_hybk 表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转存结果统计
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("开始转存行业板块数据 (ths_industry_member -> gp_hybk)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
'success': False,
|
||||||
|
'source_count': 0,
|
||||||
|
'transferred_count': 0,
|
||||||
|
'error': None
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 读取源表数据
|
||||||
|
print("正在读取 ths_industry_member 表数据...")
|
||||||
|
source_query = text("""
|
||||||
|
SELECT
|
||||||
|
industry_code,
|
||||||
|
industry_name,
|
||||||
|
stock_symbol,
|
||||||
|
stock_name
|
||||||
|
FROM ths_industry_member
|
||||||
|
WHERE industry_code IS NOT NULL
|
||||||
|
AND stock_symbol IS NOT NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
with self.engine.connect() as conn:
|
||||||
|
source_df = pd.read_sql(source_query, conn)
|
||||||
|
|
||||||
|
result['source_count'] = len(source_df)
|
||||||
|
print(f"源表共 {result['source_count']} 条记录")
|
||||||
|
|
||||||
|
if source_df.empty:
|
||||||
|
print("源表无数据,转存终止")
|
||||||
|
result['success'] = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 2. 数据转换
|
||||||
|
print("正在转换数据格式...")
|
||||||
|
target_df = pd.DataFrame()
|
||||||
|
target_df['bk_code'] = source_df['industry_code'].apply(self.extract_bk_code)
|
||||||
|
target_df['bk_name'] = source_df['industry_name']
|
||||||
|
target_df['gp_code'] = source_df['stock_symbol']
|
||||||
|
target_df['gp_name'] = source_df['stock_name']
|
||||||
|
|
||||||
|
# 过滤掉 bk_code 为 None 的记录
|
||||||
|
target_df = target_df[target_df['bk_code'].notna()]
|
||||||
|
|
||||||
|
print(f"转换后共 {len(target_df)} 条有效记录")
|
||||||
|
|
||||||
|
# 3. 清空目标表
|
||||||
|
print("正在清空 gp_hybk 表...")
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(text("TRUNCATE TABLE gp_hybk"))
|
||||||
|
print("gp_hybk 表已清空")
|
||||||
|
|
||||||
|
# 4. 插入数据
|
||||||
|
print("正在插入数据到 gp_hybk 表...")
|
||||||
|
target_df.to_sql(
|
||||||
|
'gp_hybk',
|
||||||
|
self.engine,
|
||||||
|
if_exists='append',
|
||||||
|
index=False,
|
||||||
|
method='multi',
|
||||||
|
chunksize=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
result['transferred_count'] = len(target_df)
|
||||||
|
result['success'] = True
|
||||||
|
print(f"成功转存 {result['transferred_count']} 条记录到 gp_hybk 表")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result['error'] = str(e)
|
||||||
|
print(f"转存行业板块数据失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def transfer_concept_data(self) -> dict:
|
||||||
|
"""
|
||||||
|
将 ths_concept_member 表数据转存到 gp_gnbk 表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转存结果统计
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("开始转存概念板块数据 (ths_concept_member -> gp_gnbk)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
'success': False,
|
||||||
|
'source_count': 0,
|
||||||
|
'transferred_count': 0,
|
||||||
|
'error': None
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 读取源表数据
|
||||||
|
print("正在读取 ths_concept_member 表数据...")
|
||||||
|
source_query = text("""
|
||||||
|
SELECT
|
||||||
|
concept_code,
|
||||||
|
concept_name,
|
||||||
|
stock_symbol,
|
||||||
|
stock_name
|
||||||
|
FROM ths_concept_member
|
||||||
|
WHERE concept_code IS NOT NULL
|
||||||
|
AND stock_symbol IS NOT NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
with self.engine.connect() as conn:
|
||||||
|
source_df = pd.read_sql(source_query, conn)
|
||||||
|
|
||||||
|
result['source_count'] = len(source_df)
|
||||||
|
print(f"源表共 {result['source_count']} 条记录")
|
||||||
|
|
||||||
|
if source_df.empty:
|
||||||
|
print("源表无数据,转存终止")
|
||||||
|
result['success'] = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 2. 数据转换
|
||||||
|
print("正在转换数据格式...")
|
||||||
|
target_df = pd.DataFrame()
|
||||||
|
target_df['bk_code'] = source_df['concept_code'].apply(self.extract_bk_code)
|
||||||
|
target_df['bk_name'] = source_df['concept_name']
|
||||||
|
target_df['gp_code'] = source_df['stock_symbol']
|
||||||
|
target_df['gp_name'] = source_df['stock_name']
|
||||||
|
|
||||||
|
# 过滤掉 bk_code 为 None 的记录
|
||||||
|
target_df = target_df[target_df['bk_code'].notna()]
|
||||||
|
|
||||||
|
print(f"转换后共 {len(target_df)} 条有效记录")
|
||||||
|
|
||||||
|
# 3. 清空目标表
|
||||||
|
print("正在清空 gp_gnbk 表...")
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(text("TRUNCATE TABLE gp_gnbk"))
|
||||||
|
print("gp_gnbk 表已清空")
|
||||||
|
|
||||||
|
# 4. 插入数据
|
||||||
|
print("正在插入数据到 gp_gnbk 表...")
|
||||||
|
target_df.to_sql(
|
||||||
|
'gp_gnbk',
|
||||||
|
self.engine,
|
||||||
|
if_exists='append',
|
||||||
|
index=False,
|
||||||
|
method='multi',
|
||||||
|
chunksize=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
result['transferred_count'] = len(target_df)
|
||||||
|
result['success'] = True
|
||||||
|
print(f"成功转存 {result['transferred_count']} 条记录到 gp_gnbk 表")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result['error'] = str(e)
|
||||||
|
print(f"转存概念板块数据失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def run_full_transfer(self) -> dict:
|
||||||
|
"""
|
||||||
|
执行全量转存:行业 + 概念
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转存结果汇总
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("开始执行全量转存")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
'industry': None,
|
||||||
|
'concept': None,
|
||||||
|
'all_success': False
|
||||||
|
}
|
||||||
|
|
||||||
|
# 转存行业数据
|
||||||
|
industry_result = self.transfer_industry_data()
|
||||||
|
summary['industry'] = industry_result
|
||||||
|
|
||||||
|
# 转存概念数据
|
||||||
|
concept_result = self.transfer_concept_data()
|
||||||
|
summary['concept'] = concept_result
|
||||||
|
|
||||||
|
# 汇总结果
|
||||||
|
summary['all_success'] = industry_result.get('success') and concept_result.get('success')
|
||||||
|
|
||||||
|
# 打印汇总
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("转存完成汇总")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"行业板块转存: {'成功' if industry_result.get('success') else '失败'}")
|
||||||
|
if industry_result.get('success'):
|
||||||
|
print(f" - 源表记录数: {industry_result.get('source_count', 0)}")
|
||||||
|
print(f" - 转存记录数: {industry_result.get('transferred_count', 0)}")
|
||||||
|
if industry_result.get('error'):
|
||||||
|
print(f" - 错误: {industry_result.get('error')}")
|
||||||
|
|
||||||
|
print(f"\n概念板块转存: {'成功' if concept_result.get('success') else '失败'}")
|
||||||
|
if concept_result.get('success'):
|
||||||
|
print(f" - 源表记录数: {concept_result.get('source_count', 0)}")
|
||||||
|
print(f" - 转存记录数: {concept_result.get('transferred_count', 0)}")
|
||||||
|
if concept_result.get('error'):
|
||||||
|
print(f" - 错误: {concept_result.get('error')}")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def transfer_ths_to_gp_bk(db_url: str):
|
||||||
|
"""
|
||||||
|
转存入口函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_url: 数据库连接 URL
|
||||||
|
"""
|
||||||
|
transfer = THSToGPBKTransfer(db_url)
|
||||||
|
transfer.run_full_transfer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
|
||||||
|
DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
|
||||||
|
|
||||||
|
# 执行全量转存
|
||||||
|
transfer_ths_to_gp_bk(DB_URL)
|
||||||
|
|
||||||
|
|
@ -14,10 +14,21 @@ import datetime
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import redis
|
import redis
|
||||||
import time
|
from typing import Tuple, Dict, List
|
||||||
from typing import Tuple, Dict, List, Optional, Union
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from .config import DB_URL, OUTPUT_DIR, LOG_FILE
|
# 处理相对导入和绝对导入
|
||||||
|
try:
|
||||||
|
from .config import DB_URL, OUTPUT_DIR, LOG_FILE
|
||||||
|
except ImportError:
|
||||||
|
# 直接执行时使用绝对导入
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(current_dir.parent.parent))
|
||||||
|
from src.valuation_analysis.config import DB_URL, OUTPUT_DIR, LOG_FILE
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|
@ -59,6 +70,132 @@ class IndustryAnalyzer:
|
||||||
)
|
)
|
||||||
logger.info("行业估值分析器初始化完成")
|
logger.info("行业估值分析器初始化完成")
|
||||||
|
|
||||||
|
def _setup_chinese_font(self):
|
||||||
|
"""
|
||||||
|
设置matplotlib中文字体,解决中文显示乱码问题
|
||||||
|
使用setup_fonts工具确保字体存在,并优先使用项目目录中的字体文件
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from matplotlib import font_manager
|
||||||
|
from matplotlib.font_manager import FontProperties
|
||||||
|
|
||||||
|
# 1. 确保字体文件存在(使用setup_fonts工具)
|
||||||
|
try:
|
||||||
|
# 获取项目路径
|
||||||
|
current_file = Path(__file__)
|
||||||
|
project_root = current_file.parent.parent
|
||||||
|
setup_fonts_path = project_root / "fundamentals_llm" / "setup_fonts.py"
|
||||||
|
|
||||||
|
if setup_fonts_path.exists():
|
||||||
|
# 动态导入setup_fonts模块
|
||||||
|
import importlib.util
|
||||||
|
spec = importlib.util.spec_from_file_location("setup_fonts", str(setup_fonts_path))
|
||||||
|
setup_fonts = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(setup_fonts)
|
||||||
|
|
||||||
|
# 确保字体已安装到fundamentals_llm/fonts目录
|
||||||
|
font_dir = project_root / "fundamentals_llm" / "fonts"
|
||||||
|
font_file = setup_fonts.install_font(str(font_dir))
|
||||||
|
|
||||||
|
if font_file and os.path.exists(font_file):
|
||||||
|
logger.info(f"setup_fonts已确保字体存在: {font_file}")
|
||||||
|
else:
|
||||||
|
logger.warning("setup_fonts未能安装字体,尝试查找现有字体")
|
||||||
|
else:
|
||||||
|
logger.warning(f"未找到setup_fonts.py: {setup_fonts_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"调用setup_fonts失败: {e},尝试查找现有字体")
|
||||||
|
|
||||||
|
# 2. 查找项目中的字体文件(优先使用)
|
||||||
|
current_file = Path(__file__)
|
||||||
|
project_root = current_file.parent.parent
|
||||||
|
|
||||||
|
# 可能的字体文件位置
|
||||||
|
font_paths = [
|
||||||
|
project_root / "fundamentals_llm" / "fonts" / "simhei.ttf", # 现有字体目录
|
||||||
|
project_root / "valuation_analysis" / "fonts" / "simhei.ttf", # 本模块字体目录
|
||||||
|
]
|
||||||
|
|
||||||
|
font_file = None
|
||||||
|
for font_path in font_paths:
|
||||||
|
if font_path.exists():
|
||||||
|
font_file = str(font_path)
|
||||||
|
logger.info(f"找到项目字体文件: {font_file}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 3. 如果找到项目字体文件,直接使用
|
||||||
|
if font_file and os.path.exists(font_file):
|
||||||
|
try:
|
||||||
|
# 使用字体文件路径直接注册
|
||||||
|
font_prop = FontProperties(fname=font_file)
|
||||||
|
# 获取字体名称
|
||||||
|
font_name = font_prop.get_name()
|
||||||
|
# 设置matplotlib使用该字体
|
||||||
|
plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
|
||||||
|
logger.info(f"成功使用项目字体: {font_name} (来自: {font_file})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"加载项目字体文件失败: {e},尝试使用系统字体")
|
||||||
|
font_file = None
|
||||||
|
|
||||||
|
# 4. 如果项目字体不可用,尝试使用系统字体
|
||||||
|
if font_file is None:
|
||||||
|
import platform
|
||||||
|
system = platform.system()
|
||||||
|
font_candidates = []
|
||||||
|
if system == "Windows":
|
||||||
|
font_candidates = [
|
||||||
|
'Microsoft YaHei',
|
||||||
|
'SimHei',
|
||||||
|
'SimSun',
|
||||||
|
'KaiTi',
|
||||||
|
'FangSong'
|
||||||
|
]
|
||||||
|
elif system == "Darwin": # macOS
|
||||||
|
font_candidates = [
|
||||||
|
'PingFang SC',
|
||||||
|
'STHeiti',
|
||||||
|
'Arial Unicode MS',
|
||||||
|
'Heiti SC'
|
||||||
|
]
|
||||||
|
else: # Linux
|
||||||
|
font_candidates = [
|
||||||
|
'WenQuanYi Micro Hei',
|
||||||
|
'WenQuanYi Zen Hei',
|
||||||
|
'Noto Sans CJK SC',
|
||||||
|
'Droid Sans Fallback'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 尝试设置系统字体
|
||||||
|
for font_name in font_candidates:
|
||||||
|
try:
|
||||||
|
font_list = [f.name for f in font_manager.fontManager.ttflist]
|
||||||
|
if font_name in font_list:
|
||||||
|
plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
|
||||||
|
logger.info(f"使用系统字体: {font_name}")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 5. 如果都没找到,使用默认设置
|
||||||
|
if 'font.sans-serif' not in plt.rcParams or len(plt.rcParams['font.sans-serif']) == 0:
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'DejaVu Sans']
|
||||||
|
logger.warning("使用默认中文字体设置")
|
||||||
|
|
||||||
|
# 解决负号显示问题
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
|
# 清除matplotlib字体缓存,强制重新加载字体
|
||||||
|
try:
|
||||||
|
font_manager._rebuild()
|
||||||
|
logger.info("已清除matplotlib字体缓存")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清除字体缓存失败: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"设置中文字体失败,使用默认设置: {e}")
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'DejaVu Sans']
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
def get_industry_list(self) -> List[Dict]:
|
def get_industry_list(self) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
获取所有行业列表
|
获取所有行业列表
|
||||||
|
|
@ -1234,3 +1371,559 @@ class IndustryAnalyzer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"筛选拥挤度缓存时出错: {e}")
|
logger.error(f"筛选拥挤度缓存时出错: {e}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _get_all_industries_concepts(self, days: int = 120) -> Tuple[List[str], List[str]]:
|
||||||
|
"""
|
||||||
|
从Redis缓存中获取所有有效的行业和概念板块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 需要分析的天数,用于验证数据有效性
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(all_industries, all_concepts) 元组
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 计算结束日期和开始日期
|
||||||
|
end_date = datetime.datetime.now()
|
||||||
|
start_date = end_date - datetime.timedelta(days=days)
|
||||||
|
|
||||||
|
all_industries = []
|
||||||
|
all_concepts = []
|
||||||
|
|
||||||
|
# 获取所有行业和概念板块的缓存key
|
||||||
|
industry_keys = redis_client.keys('industry_crowding:*')
|
||||||
|
concept_keys = redis_client.keys('concept_crowding:*')
|
||||||
|
|
||||||
|
# 处理行业
|
||||||
|
for key in industry_keys:
|
||||||
|
try:
|
||||||
|
name = key.split(':', 1)[1]
|
||||||
|
cached_data = redis_client.get(key)
|
||||||
|
if not cached_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
df = pd.DataFrame(json.loads(cached_data))
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确保trade_date是日期类型
|
||||||
|
if 'trade_date' in df.columns:
|
||||||
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
|
|
||||||
|
# 筛选近120天的数据
|
||||||
|
df = df[df['trade_date'] >= start_date]
|
||||||
|
df = df.sort_values('trade_date')
|
||||||
|
|
||||||
|
if len(df) < 10: # 数据点太少,跳过
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_industries.append(name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理行业缓存 {key} 时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理概念板块
|
||||||
|
for key in concept_keys:
|
||||||
|
try:
|
||||||
|
name = key.split(':', 1)[1]
|
||||||
|
cached_data = redis_client.get(key)
|
||||||
|
if not cached_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
df = pd.DataFrame(json.loads(cached_data))
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确保trade_date是日期类型
|
||||||
|
if 'trade_date' in df.columns:
|
||||||
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
|
|
||||||
|
# 筛选近120天的数据
|
||||||
|
df = df[df['trade_date'] >= start_date]
|
||||||
|
df = df.sort_values('trade_date')
|
||||||
|
|
||||||
|
if len(df) < 10: # 数据点太少,跳过
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_concepts.append(name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理概念缓存 {key} 时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"获取到 {len(all_industries)} 个行业和 {len(all_concepts)} 个概念板块")
|
||||||
|
return all_industries, all_concepts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取行业/概念板块列表失败: {e}")
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
def _select_top_industries_concepts(self, count: int = 40, days: int = 120) -> Tuple[List[str], List[str]]:
|
||||||
|
"""
|
||||||
|
从Redis缓存中选择拥挤度变化幅度最大的行业和概念板块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
count: 选择的行业/概念数量
|
||||||
|
days: 需要分析的天数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(selected_industries, selected_concepts) 元组
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 计算结束日期和开始日期
|
||||||
|
end_date = datetime.datetime.now()
|
||||||
|
start_date = end_date - datetime.timedelta(days=days)
|
||||||
|
|
||||||
|
industry_scores = []
|
||||||
|
concept_scores = []
|
||||||
|
|
||||||
|
# 获取所有行业和概念板块的缓存key
|
||||||
|
industry_keys = redis_client.keys('industry_crowding:*')
|
||||||
|
concept_keys = redis_client.keys('concept_crowding:*')
|
||||||
|
|
||||||
|
# 处理行业
|
||||||
|
for key in industry_keys:
|
||||||
|
try:
|
||||||
|
name = key.split(':', 1)[1]
|
||||||
|
cached_data = redis_client.get(key)
|
||||||
|
if not cached_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
df = pd.DataFrame(json.loads(cached_data))
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确保trade_date是日期类型
|
||||||
|
if 'trade_date' in df.columns:
|
||||||
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
|
|
||||||
|
# 筛选近120天的数据
|
||||||
|
df = df[df['trade_date'] >= start_date]
|
||||||
|
df = df.sort_values('trade_date')
|
||||||
|
|
||||||
|
if len(df) < 10: # 数据点太少,跳过
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算拥挤度变化幅度(最大值-最小值)
|
||||||
|
percentile_values = df['percentile'].values
|
||||||
|
if len(percentile_values) > 0:
|
||||||
|
change_range = float(percentile_values.max() - percentile_values.min())
|
||||||
|
current_value = float(percentile_values[-1])
|
||||||
|
|
||||||
|
# 综合评分:变化幅度 + 当前值的绝对值(突出极端情况)
|
||||||
|
score = change_range * 0.7 + abs(current_value - 50) * 0.3
|
||||||
|
|
||||||
|
industry_scores.append({
|
||||||
|
'name': name,
|
||||||
|
'score': score,
|
||||||
|
'change_range': change_range,
|
||||||
|
'current': current_value
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理行业缓存 {key} 时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理概念板块
|
||||||
|
for key in concept_keys:
|
||||||
|
try:
|
||||||
|
name = key.split(':', 1)[1]
|
||||||
|
cached_data = redis_client.get(key)
|
||||||
|
if not cached_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
df = pd.DataFrame(json.loads(cached_data))
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确保trade_date是日期类型
|
||||||
|
if 'trade_date' in df.columns:
|
||||||
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
|
|
||||||
|
# 筛选近120天的数据
|
||||||
|
df = df[df['trade_date'] >= start_date]
|
||||||
|
df = df.sort_values('trade_date')
|
||||||
|
|
||||||
|
if len(df) < 10: # 数据点太少,跳过
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算拥挤度变化幅度
|
||||||
|
percentile_values = df['percentile'].values
|
||||||
|
if len(percentile_values) > 0:
|
||||||
|
change_range = float(percentile_values.max() - percentile_values.min())
|
||||||
|
current_value = float(percentile_values[-1])
|
||||||
|
|
||||||
|
# 综合评分
|
||||||
|
score = change_range * 0.7 + abs(current_value - 50) * 0.3
|
||||||
|
|
||||||
|
concept_scores.append({
|
||||||
|
'name': name,
|
||||||
|
'score': score,
|
||||||
|
'change_range': change_range,
|
||||||
|
'current': current_value
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理概念缓存 {key} 时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 按评分排序,选择前count个
|
||||||
|
industry_scores.sort(key=lambda x: x['score'], reverse=True)
|
||||||
|
concept_scores.sort(key=lambda x: x['score'], reverse=True)
|
||||||
|
|
||||||
|
selected_industries = [item['name'] for item in industry_scores[:count]]
|
||||||
|
selected_concepts = [item['name'] for item in concept_scores[:count]]
|
||||||
|
|
||||||
|
logger.info(f"选择了 {len(selected_industries)} 个行业和 {len(selected_concepts)} 个概念板块")
|
||||||
|
|
||||||
|
return selected_industries, selected_concepts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"选择行业/概念板块时出错: {e}")
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
def _build_heatmap_data(self, names: List[str], is_concept: bool, days: int = 120) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
构建热力图数据矩阵
|
||||||
|
|
||||||
|
Args:
|
||||||
|
names: 行业/概念板块名称列表
|
||||||
|
is_concept: 是否为概念板块
|
||||||
|
days: 需要分析的天数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含热力图数据的DataFrame,行为行业/概念,列为日期,值为拥挤度百分位
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 计算日期范围
|
||||||
|
end_date = datetime.datetime.now()
|
||||||
|
start_date = end_date - datetime.timedelta(days=days)
|
||||||
|
|
||||||
|
# 获取所有交易日期(从第一个有效数据中获取)
|
||||||
|
all_dates = set()
|
||||||
|
data_dict = {}
|
||||||
|
|
||||||
|
prefix = 'concept_crowding:' if is_concept else 'industry_crowding:'
|
||||||
|
|
||||||
|
for name in names:
|
||||||
|
cache_key = f"{prefix}{name}"
|
||||||
|
cached_data = redis_client.get(cache_key)
|
||||||
|
|
||||||
|
if not cached_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
df = pd.DataFrame(json.loads(cached_data))
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确保trade_date是日期类型
|
||||||
|
if 'trade_date' in df.columns:
|
||||||
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
|
|
||||||
|
# 筛选日期范围
|
||||||
|
df = df[(df['trade_date'] >= start_date) & (df['trade_date'] <= end_date)]
|
||||||
|
df = df.sort_values('trade_date')
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 收集所有日期
|
||||||
|
all_dates.update(df['trade_date'].tolist())
|
||||||
|
|
||||||
|
# 存储数据
|
||||||
|
data_dict[name] = df.set_index('trade_date')['percentile'].to_dict()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理 {name} 数据时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_dates:
|
||||||
|
logger.warning("没有找到有效数据")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# 创建完整的日期序列(只包含交易日)
|
||||||
|
all_dates = sorted(list(all_dates))
|
||||||
|
|
||||||
|
# 构建矩阵
|
||||||
|
matrix_data = []
|
||||||
|
for name in names:
|
||||||
|
row = []
|
||||||
|
for date in all_dates:
|
||||||
|
if name in data_dict and date in data_dict[name]:
|
||||||
|
row.append(float(data_dict[name][date]))
|
||||||
|
else:
|
||||||
|
row.append(np.nan)
|
||||||
|
matrix_data.append(row)
|
||||||
|
|
||||||
|
# 创建DataFrame
|
||||||
|
heatmap_df = pd.DataFrame(matrix_data, index=names, columns=all_dates)
|
||||||
|
|
||||||
|
logger.info(f"构建热力图数据完成: {heatmap_df.shape}")
|
||||||
|
return heatmap_df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"构建热力图数据失败: {e}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def plot_crowding_heatmap(self, days: int = 120, items_per_page: int = 40, output_dir: str = None):
|
||||||
|
"""
|
||||||
|
绘制行业和概念板块拥挤度热力图
|
||||||
|
每40个行业/概念生成一张图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 分析的天数,默认120天
|
||||||
|
items_per_page: 每张图显示的行业/概念数量,默认40个
|
||||||
|
output_dir: 输出目录,默认为src/data目录
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 设置输出目录
|
||||||
|
if output_dir is None:
|
||||||
|
# 获取项目根目录
|
||||||
|
current_file = Path(__file__)
|
||||||
|
project_root = current_file.parent.parent.parent
|
||||||
|
output_dir = project_root / "src" / "data"
|
||||||
|
else:
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 获取所有行业和概念
|
||||||
|
all_industries, all_concepts = self._get_all_industries_concepts(days=days)
|
||||||
|
|
||||||
|
if not all_industries and not all_concepts:
|
||||||
|
logger.warning("没有找到有效的行业或概念板块数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 设置中文字体(使用setup_fonts工具)
|
||||||
|
self._setup_chinese_font()
|
||||||
|
|
||||||
|
# 设置seaborn样式
|
||||||
|
sns.set_style("whitegrid")
|
||||||
|
|
||||||
|
# 生成时间戳用于文件名
|
||||||
|
timestamp = datetime.datetime.now().strftime('%Y%m%d')
|
||||||
|
|
||||||
|
# 绘制行业热力图(按每items_per_page个分组)
|
||||||
|
if all_industries:
|
||||||
|
total_industries = len(all_industries)
|
||||||
|
total_pages = (total_industries + items_per_page - 1) // items_per_page # 向上取整
|
||||||
|
logger.info(f"开始绘制行业热力图,共 {total_industries} 个行业,将生成 {total_pages} 张图")
|
||||||
|
|
||||||
|
for page in range(total_pages):
|
||||||
|
start_idx = page * items_per_page
|
||||||
|
end_idx = min(start_idx + items_per_page, total_industries)
|
||||||
|
page_industries = all_industries[start_idx:end_idx]
|
||||||
|
|
||||||
|
logger.info(f"绘制行业热力图第 {page + 1}/{total_pages} 张(行业 {start_idx + 1}-{end_idx})")
|
||||||
|
industry_heatmap = self._build_heatmap_data(page_industries, is_concept=False, days=days)
|
||||||
|
|
||||||
|
if not industry_heatmap.empty:
|
||||||
|
self._plot_single_heatmap(
|
||||||
|
industry_heatmap,
|
||||||
|
title=f"行业拥挤度热力图(近{days}天,第{page + 1}/{total_pages}页,{len(page_industries)}个行业)",
|
||||||
|
output_path=output_dir / f"industry_crowding_heatmap_{timestamp}_page{page + 1:02d}.png",
|
||||||
|
is_concept=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制概念板块热力图(按每items_per_page个分组)
|
||||||
|
if all_concepts:
|
||||||
|
total_concepts = len(all_concepts)
|
||||||
|
total_pages = (total_concepts + items_per_page - 1) // items_per_page # 向上取整
|
||||||
|
logger.info(f"开始绘制概念板块热力图,共 {total_concepts} 个概念,将生成 {total_pages} 张图")
|
||||||
|
|
||||||
|
for page in range(total_pages):
|
||||||
|
start_idx = page * items_per_page
|
||||||
|
end_idx = min(start_idx + items_per_page, total_concepts)
|
||||||
|
page_concepts = all_concepts[start_idx:end_idx]
|
||||||
|
|
||||||
|
logger.info(f"绘制概念板块热力图第 {page + 1}/{total_pages} 张(概念 {start_idx + 1}-{end_idx})")
|
||||||
|
concept_heatmap = self._build_heatmap_data(page_concepts, is_concept=True, days=days)
|
||||||
|
|
||||||
|
if not concept_heatmap.empty:
|
||||||
|
self._plot_single_heatmap(
|
||||||
|
concept_heatmap,
|
||||||
|
title=f"概念板块拥挤度热力图(近{days}天,第{page + 1}/{total_pages}页,{len(page_concepts)}个概念)",
|
||||||
|
output_path=output_dir / f"concept_crowding_heatmap_{timestamp}_page{page + 1:02d}.png",
|
||||||
|
is_concept=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("热力图绘制完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"绘制热力图失败: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
def _plot_single_heatmap(self, heatmap_df: pd.DataFrame, title: str, output_path: Path, is_concept: bool):
|
||||||
|
"""
|
||||||
|
绘制单个热力图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
heatmap_df: 热力图数据DataFrame
|
||||||
|
title: 图表标题
|
||||||
|
output_path: 输出路径
|
||||||
|
is_concept: 是否为概念板块
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 计算图表尺寸(根据数据量动态调整)
|
||||||
|
n_rows = len(heatmap_df)
|
||||||
|
n_cols = len(heatmap_df.columns)
|
||||||
|
|
||||||
|
# 基础尺寸,根据数据量调整
|
||||||
|
fig_width = max(20, n_cols * 0.15)
|
||||||
|
fig_height = max(12, n_rows * 0.4)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
|
||||||
|
|
||||||
|
# 获取当前设置的字体
|
||||||
|
from matplotlib.font_manager import FontProperties
|
||||||
|
current_font = plt.rcParams['font.sans-serif'][0] if plt.rcParams['font.sans-serif'] else 'SimHei'
|
||||||
|
font_prop = FontProperties(fname=str(Path(__file__).parent.parent / "fundamentals_llm" / "fonts" / "simhei.ttf")) if (Path(__file__).parent.parent / "fundamentals_llm" / "fonts" / "simhei.ttf").exists() else None
|
||||||
|
|
||||||
|
# heatmap_df: 行为行业/概念,列为日期
|
||||||
|
# seaborn heatmap: 行在y轴,列在x轴
|
||||||
|
# 所以:y轴是行业/概念,x轴是日期(符合要求)
|
||||||
|
plot_data = heatmap_df
|
||||||
|
|
||||||
|
# 使用自定义颜色映射:绿色(不拥挤)-> 黄色(中性)-> 红色(拥挤)
|
||||||
|
# 创建从绿色到红色渐变的colormap
|
||||||
|
colors = ['#2ecc71', '#f1c40f', '#e74c3c'] # 绿色、黄色、红色
|
||||||
|
from matplotlib.colors import LinearSegmentedColormap
|
||||||
|
n_bins = 100
|
||||||
|
cmap = LinearSegmentedColormap.from_list('crowding', colors, N=n_bins)
|
||||||
|
|
||||||
|
# 绘制热力图
|
||||||
|
heatmap = sns.heatmap(
|
||||||
|
plot_data,
|
||||||
|
cmap=cmap,
|
||||||
|
vmin=0,
|
||||||
|
vmax=100,
|
||||||
|
annot=False, # 不显示数值(数据太多)
|
||||||
|
fmt='.1f',
|
||||||
|
cbar_kws={
|
||||||
|
'label': '拥挤度百分位 (%)',
|
||||||
|
'shrink': 0.8,
|
||||||
|
'format': '%.0f'
|
||||||
|
},
|
||||||
|
linewidths=0.05,
|
||||||
|
linecolor='white',
|
||||||
|
ax=ax,
|
||||||
|
xticklabels=False, # 先不显示x轴标签,后面自定义
|
||||||
|
yticklabels=True # 显示y轴标签(行业/概念名称)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置colorbar标签字体
|
||||||
|
try:
|
||||||
|
cbar = heatmap.collections[0].colorbar
|
||||||
|
if cbar is not None:
|
||||||
|
if font_prop:
|
||||||
|
cbar.set_label('拥挤度百分位 (%)', fontproperties=font_prop, fontsize=10)
|
||||||
|
else:
|
||||||
|
cbar.set_label('拥挤度百分位 (%)', fontfamily=current_font, fontsize=10)
|
||||||
|
# 设置colorbar刻度标签字体
|
||||||
|
for label in cbar.ax.get_yticklabels():
|
||||||
|
if font_prop:
|
||||||
|
label.set_fontproperties(font_prop)
|
||||||
|
else:
|
||||||
|
label.set_fontfamily(current_font)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"设置colorbar字体失败: {e}")
|
||||||
|
|
||||||
|
# 设置标题和标签,显式指定字体
|
||||||
|
if font_prop:
|
||||||
|
ax.set_title(title, fontsize=16, fontweight='bold', pad=20, fontproperties=font_prop)
|
||||||
|
ax.set_xlabel('日期(时间)', fontsize=12, fontproperties=font_prop)
|
||||||
|
ax.set_ylabel('行业/概念板块', fontsize=12, fontproperties=font_prop)
|
||||||
|
else:
|
||||||
|
ax.set_title(title, fontsize=16, fontweight='bold', pad=20, fontfamily=current_font)
|
||||||
|
ax.set_xlabel('日期(时间)', fontsize=12, fontfamily=current_font)
|
||||||
|
ax.set_ylabel('行业/概念板块', fontsize=12, fontfamily=current_font)
|
||||||
|
|
||||||
|
# 调整y轴标签(行业/概念名称)- 显示完整名称,但可以旋转
|
||||||
|
y_labels = plot_data.index.tolist()
|
||||||
|
if len(y_labels) > 0:
|
||||||
|
# 如果名称超过15个字符,截断
|
||||||
|
y_labels_display = [label[:15] + '...' if len(str(label)) > 15 else str(label) for label in y_labels]
|
||||||
|
if font_prop:
|
||||||
|
ax.set_yticklabels(y_labels_display, fontsize=9, rotation=0, ha='right', fontproperties=font_prop)
|
||||||
|
else:
|
||||||
|
ax.set_yticklabels(y_labels_display, fontsize=9, rotation=0, ha='right', fontfamily=current_font)
|
||||||
|
|
||||||
|
# 调整x轴标签(日期)- 只显示部分日期,避免过于密集
|
||||||
|
x_labels = plot_data.columns.tolist()
|
||||||
|
if len(x_labels) > 0:
|
||||||
|
# 如果日期太多,只显示部分
|
||||||
|
if len(x_labels) > 30:
|
||||||
|
step = max(1, len(x_labels) // 15) # 最多显示15个日期标签
|
||||||
|
x_positions = list(range(0, len(x_labels), step))
|
||||||
|
x_labels_display = []
|
||||||
|
for pos in x_positions:
|
||||||
|
if pos < len(x_labels):
|
||||||
|
date_str = str(x_labels[pos])
|
||||||
|
# 只显示月-日
|
||||||
|
if isinstance(x_labels[pos], pd.Timestamp):
|
||||||
|
x_labels_display.append(x_labels[pos].strftime('%m-%d'))
|
||||||
|
else:
|
||||||
|
x_labels_display.append(date_str[:10] if len(date_str) > 10 else date_str)
|
||||||
|
ax.set_xticks(x_positions)
|
||||||
|
if font_prop:
|
||||||
|
ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontproperties=font_prop)
|
||||||
|
else:
|
||||||
|
ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontfamily=current_font)
|
||||||
|
else:
|
||||||
|
# 日期不多,全部显示
|
||||||
|
x_labels_display = []
|
||||||
|
for date in x_labels:
|
||||||
|
if isinstance(date, pd.Timestamp):
|
||||||
|
x_labels_display.append(date.strftime('%m-%d'))
|
||||||
|
else:
|
||||||
|
date_str = str(date)
|
||||||
|
x_labels_display.append(date_str[:10] if len(date_str) > 10 else date_str)
|
||||||
|
if font_prop:
|
||||||
|
ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontproperties=font_prop)
|
||||||
|
else:
|
||||||
|
ax.set_xticklabels(x_labels_display, fontsize=8, rotation=45, ha='right', fontfamily=current_font)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 保存图片
|
||||||
|
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
logger.info(f"热力图已保存到: {output_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"绘制热力图失败: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
if 'fig' in locals():
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数:生成行业和概念板块拥挤度热力图"""
|
||||||
|
try:
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("开始生成行业和概念板块拥挤度热力图")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# 创建分析器实例
|
||||||
|
analyzer = IndustryAnalyzer()
|
||||||
|
|
||||||
|
# 绘制热力图
|
||||||
|
# 参数说明:
|
||||||
|
# days: 分析的天数,默认120天
|
||||||
|
# items_per_page: 每张图显示的行业/概念数量,默认40个
|
||||||
|
# output_dir: 输出目录,默认为src/data目录
|
||||||
|
analyzer.plot_crowding_heatmap(days=120, items_per_page=40)
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("热力图生成完成!")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"执行主函数失败: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue