diff --git a/requirements.txt b/requirements.txt index a3df8b3..91b053e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,4 +23,5 @@ dbutils==3.1.2 pyyaml==6.0.3 xtquant==250516.1.1 tushare>=1.4.24 -akshare==1.17.82 \ No newline at end of file +akshare==1.17.82 +seaborn==0.13.2 diff --git a/src/app.py b/src/app.py index 805d53d..37b490f 100644 --- a/src/app.py +++ b/src/app.py @@ -682,18 +682,34 @@ def precalculate_industry_crowding_batch(): from src.valuation_analysis.industry_analysis import IndustryAnalyzer analyzer = IndustryAnalyzer() - # 固定行业和概念板块 - industries = ["煤炭开采", "焦炭加工", "油气开采", "石油化工", "油服工程", "日用化工", "化纤", "化学原料", "化学制品", "塑料", "橡胶", "农用化工", "非金属材料", "冶钢原料", "普钢", "特钢", "工业金属", "贵金属", "能源金属", "稀有金属", "金属新材料", "水泥", "玻璃玻纤", "装饰建材", "种植业", "养殖业", "林业", "渔业", "饲料", "农产品加工", "动物保健", "酿酒", "饮料乳品", "调味品", "休闲食品", "食品加工", "纺织制造", "服装家纺", "饰品", "造纸", "包装印刷", "家居用品", "文娱用品", "白色家电", "黑色家电", "小家电", "厨卫电器", "家电零部件", "一般零售", "商业物业经营", "专业连锁", "贸易", "电子商务", "乘用车", "商用车", "汽车零部件", "汽车服务", "摩托车及其他", "化学制药", "生物制品", "中药", "医药商业", "医疗器械", "医疗服务", "医疗美容", "电机制造", "电池", "电网设备", "光伏设备", "风电设备", "其他发电设备", "地面兵装", "航空装备", "航天装备", "航海装备", "军工电子", "轨交设备", "通用设备", "专用设备", "工程机械", "自动化设备", "半导体", "消费电子", "光学光电", "元器件", "其他电子", "通信设备", "通信工程", "电信服务", "IT设备", "软件服务", "云服务", "产业互联网", "游戏", "广告营销", "影视院线", "数字媒体", "出版业", "广播电视", "全国性银行", "地方性银行", "证券", "保险", "多元金融", "房屋建设", "基础建设", "专业工程", "工程咨询服务", "装修装饰", "房地产开发", "房产服务", "体育", "教育培训", "酒店餐饮", "旅游", "专业服务", "公路铁路", "航空机场", "航运港口", "物流", "电力", "燃气", "水务", "环保设备", "环境治理", "环境监测", "综合类"] - concepts = ["通达信88", "海峡西岸", "海南自贸", "一带一路", "上海自贸", "雄安新区", "粤港澳", "ST板块", "次新股", "含H股", "含B股", "含GDR", "含可转债", "国防军工", "军民融合", "大飞机", "稀缺资源", "5G概念", "碳中和", "黄金概念", "物联网", "创投概念", "航运概念", "铁路基建", "高端装备", "核电核能", "光伏", "风电", "锂电池概念", "燃料电池", "HJT电池", "固态电池", "钠电池", "钒电池", "TOPCon电池", "钙钛矿电池", "BC电池", "氢能源", "稀土永磁", "盐湖提锂", "锂矿", "水利建设", "卫星导航", "可燃冰", "页岩气", "生物疫苗", "基因概念", "维生素", "仿制药", "创新药", "免疫治疗", "CXO概念", "节能环保", "食品安全", 
"白酒概念", "代糖概念", "猪肉", "鸡肉", "水产品", "碳纤维", "石墨烯", "3D打印", "苹果概念", "阿里概念", "腾讯概念", "小米概念", "百度概念", "华为鸿蒙", "华为海思", "华为汽车", "华为算力", "特斯拉概念", "消费电子概念", "汽车电子", "无线耳机", "生物质能", "地热能", "充电桩", "新能源车", "换电概念", "高压快充", "草甘膦", "安防服务", "垃圾分类", "核污染防治", "风沙治理", "乡村振兴", "土地流转", "体育概念", "博彩概念", "赛马概念", "分散染料", "聚氨酯", "云计算", "边缘计算", "网络游戏", "信息安全", "国产软件", "大数据", "数据中心", "芯片", "MCU芯片", "汽车芯片", "存储芯片", "互联金融", "婴童概念", "养老概念", "网红经济", "民营医院", "特高压", "智能电网", "智能穿戴", "智能交通", "智能家居", "智能医疗", "智慧城市", "智慧政务", "机器人概念", "机器视觉", "超导概念", "职业教育", "物业管理概念", "虚拟现实", "数字孪生", "钛金属", "钴金属", "镍金属", "氟概念", "磷概念", "无人机", "PPP概念", "新零售", "跨境电商", "量子科技", "无人驾驶", "ETC概念", "胎压监测", "OLED概念", "MiniLED", "MicroLED", "超清视频", "区块链", "数字货币", "人工智能", "租购同权", "工业互联", "知识产权", "工业大麻", "工业气体", "人造肉", "预制菜", "种业", "化肥概念", "操作系统", "光刻机", "第三代半导体", "远程办公", "口罩防护", "医废处理", "虫害防治", "超级电容", "C2M概念", "地摊经济", "冷链物流", "抖音概念", "降解塑料", "医美概念", "人脑工程", "烟草概念", "新型烟草", "有机硅概念", "新冠检测", "BIPV概念", "地下管网", "储能", "新材料", "工业母机", "一体压铸", "汽车热管理", "汽车拆解", "NMN概念", "国资云", "元宇宙概念", "NFT概念", "云游戏", "天然气", "绿色电力", "培育钻石", "信创", "幽门螺杆菌", "电子纸", "新冠药概念", "免税概念", "PVDF概念", "装配式建筑", "绿色建筑", "东数西算", "跨境支付CIPS", "中俄贸易", "电子身份证", "家庭医生", "辅助生殖", "肝炎概念", "新型城镇", "粮食概念", "超临界发电", "虚拟电厂", "动力电池回收", "PCB概念", "先进封装", "热泵概念", "EDA概念", "光热发电", "供销社", "Web3概念", "DRG-DIP", "AIGC概念", "复合铜箔", "数据确权", "数据要素", "POE胶膜", "血氧仪", "旅游概念", "中特估", "ChatGPT概念", "CPO概念", "数字水印", "毫米波雷达", "工业软件", "6G概念", "时空大数据", "可控核聚变", "知识付费", "算力租赁", "光通信", "混合现实", "英伟达概念", "减速器", "减肥药", "合成生物", "星闪概念", "液冷服务器", "新型工业化", "短剧游戏", "多模态AI", "PEEK材料", "小米汽车概念", "飞行汽车", "Sora概念", "人形机器人", "AI手机PC", "低空经济", "铜缆高速连接", "军工信息化", "玻璃基板", "商业航天", "车联网", "财税数字化", "折叠屏", "AI眼镜", "智谱AI", "IP经济", "宠物经济", "小红书概念", "AI智能体", "DeepSeek概念", "AI医疗概念", "海洋经济", "外骨骼机器人", "军贸概念"] + + # 从数据库查询所有行业 + industry_list = analyzer.get_industry_list() + industries = [item['name'] for item in industry_list] + logger.info(f"从数据库获取到 {len(industries)} 个行业") + + # 从数据库查询所有概念板块 + concept_list = 
analyzer.get_concept_list() + concepts = [item['name'] for item in concept_list] + logger.info(f"从数据库获取到 {len(concepts)} 个概念板块") # 批量计算行业和概念板块拥挤度 analyzer.batch_calculate_industry_crowding(industries, concepts) - logger.info("批量计算行业和概念板块的拥挤度指标完成") + logger.info(f"批量计算完成:{len(industries)} 个行业和 {len(concepts)} 个概念板块的拥挤度指标") except Exception as e: logger.error(f"批量计算行业拥挤度指标失败: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return jsonify({ + "status": "error", + "message": str(e) + }), 500 + return jsonify({ - "status": "success" + "status": "success", + "industries_count": len(industries) if 'industries' in locals() else 0, + "concepts_count": len(concepts) if 'concepts' in locals() else 0 }), 200 @app.route('/') diff --git a/src/quantitative_analysis/company_lifecycle_factor.py b/src/quantitative_analysis/company_lifecycle_factor.py index b5f7082..e19aff8 100644 --- a/src/quantitative_analysis/company_lifecycle_factor.py +++ b/src/quantitative_analysis/company_lifecycle_factor.py @@ -28,7 +28,7 @@ except ImportError: # 导入股票代码格式化工具 try: - from tools.stock_code_formatter import StockCodeFormatter + from src.tools.stock_code_formatter import StockCodeFormatter except ImportError: import importlib.util # project_root 已经是项目根目录,直接拼接 tools 目录 diff --git a/src/quantitative_analysis/financial_indicator_analyzer.py b/src/quantitative_analysis/financial_indicator_analyzer.py index d53c5f9..05eaae3 100644 --- a/src/quantitative_analysis/financial_indicator_analyzer.py +++ b/src/quantitative_analysis/financial_indicator_analyzer.py @@ -9,7 +9,6 @@ import sys import pymongo -import datetime import logging from typing import Dict, List, Optional, Union from pathlib import Path @@ -24,7 +23,7 @@ sys.path.append(str(project_root)) from src.valuation_analysis.config import MONGO_CONFIG2, DB_URL # 导入股票代码格式转换工具 -from tools.stock_code_formatter import StockCodeFormatter +from src.tools.stock_code_formatter import StockCodeFormatter # 设置日志 logging.basicConfig( @@ 
-478,9 +477,7 @@ class FinancialIndicatorAnalyzer: Returns: List[str]: 季度列表,格式为 YYYY-MM-DD """ - from datetime import datetime, timedelta - import calendar - + quarters = [] year, month, day = map(int, start_date.split('-')) diff --git a/src/quantitative_analysis/overlap_analyzer.py b/src/quantitative_analysis/overlap_analyzer.py index eea511a..53f6ccd 100644 --- a/src/quantitative_analysis/overlap_analyzer.py +++ b/src/quantitative_analysis/overlap_analyzer.py @@ -12,8 +12,7 @@ import logging from typing import Dict, List, Optional, Tuple from pathlib import Path from sqlalchemy import create_engine, text -import pandas as pd -from datetime import datetime, timedelta +from datetime import datetime # 添加项目根路径到Python路径 project_root = Path(__file__).parent.parent.parent @@ -33,7 +32,7 @@ except ImportError: # 导入股票代码格式转换工具 try: - from tools.stock_code_formatter import StockCodeFormatter + from src.tools.stock_code_formatter import StockCodeFormatter except ImportError: # 如果上面的导入失败,尝试直接导入 import importlib.util diff --git a/tools/stock_code_formatter.py b/src/tools/stock_code_formatter.py similarity index 100% rename from tools/stock_code_formatter.py rename to src/tools/stock_code_formatter.py diff --git a/tools/trigger_batch_collection.py b/src/tools/trigger_batch_collection.py similarity index 100% rename from tools/trigger_batch_collection.py rename to src/tools/trigger_batch_collection.py diff --git a/src/tushare_scripts/ths_industry_member_collector.py b/src/tushare_scripts/ths_industry_member_collector.py new file mode 100644 index 0000000..5564c3a --- /dev/null +++ b/src/tushare_scripts/ths_industry_member_collector.py @@ -0,0 +1,243 @@ +# coding:utf-8 +""" +同花顺行业板块成分股采集工具 +功能:从 Tushare 获取同花顺行业板块成分股数据并落库 +API 文档: https://tushare.pro/document/2?doc_id=261 +说明:每次调用时全量覆盖,不需要每日更新 +""" + +import os +import sys +from datetime import datetime +from typing import Optional + +import pandas as pd +import tushare as ts +from sqlalchemy import create_engine, text + +# 
添加项目根目录到路径,确保能够读取配置 +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(PROJECT_ROOT) + +from src.scripts.config import TUSHARE_TOKEN + + +class THSIndustryMemberCollector: + """同花顺行业板块成分股采集器""" + + def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"): + """ + Args: + db_url: 数据库连接 URL + tushare_token: Tushare Token + table_name: 目标表名,默认 ths_industry_member + """ + self.engine = create_engine( + db_url, + pool_size=5, + max_overflow=10, + pool_recycle=3600, + ) + self.table_name = table_name + + ts.set_token(tushare_token) + self.pro = ts.pro_api() + + print("=" * 60) + print("同花顺行业板块成分股采集工具") + print(f"目标数据表: {self.table_name}") + print("=" * 60) + + @staticmethod + def convert_tushare_code_to_db(ts_code: str) -> str: + """ + 将 Tushare 代码(600000.SH)转换为数据库代码(SH600000) + """ + if not ts_code or "." not in ts_code: + return ts_code + base, market = ts_code.split(".") + return f"{market}{base}" + + def fetch_industry_index_list(self) -> pd.DataFrame: + """ + 获取所有同花顺行业板块列表 + API: ths_index + 文档: https://tushare.pro/document/2?doc_id=259 + + Returns: + 行业板块列表 DataFrame + """ + try: + print("正在获取同花顺行业板块列表...") + df = self.pro.ths_index( + exchange='A', + type='I', # I=行业板块 + ) + if df.empty: + print("未获取到行业板块数据") + return pd.DataFrame() + print(f"成功获取 {len(df)} 个行业板块") + return df + except Exception as exc: + print(f"获取行业板块列表失败: {exc}") + return pd.DataFrame() + + def fetch_industry_members(self, ts_code: str) -> pd.DataFrame: + """ + 获取指定行业板块的成分股 + API: ths_member + 文档: https://tushare.pro/document/2?doc_id=261 + + Args: + ts_code: 行业板块代码,如 '885800.TI' + + Returns: + 成分股列表 DataFrame + """ + try: + df = self.pro.ths_member(ts_code=ts_code) + return df + except Exception as exc: + print(f"获取行业板块 {ts_code} 成分股失败: {exc}") + return pd.DataFrame() + + def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame: + """ + 将 Tushare 返回的数据转换为数据库入库格式 + + 
Args: + member_df: 成分股信息(包含 ts_code, con_code, con_name 等) + industry_name: 行业板块名称(从行业板块列表中获取) + + Returns: + 转换后的 DataFrame + """ + if member_df.empty: + return pd.DataFrame() + + # 检查必要字段 + if "ts_code" not in member_df.columns: + print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}") + return pd.DataFrame() + if "con_code" not in member_df.columns: + print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}") + return pd.DataFrame() + + result = pd.DataFrame() + # ts_code 是行业板块代码 + result["industry_code"] = member_df["ts_code"] + result["industry_name"] = industry_name + # con_code 是股票代码 + result["stock_ts_code"] = member_df["con_code"] + result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db) + # con_name 是股票名称 + result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else "" + # is_new 是否最新 + result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None + result["created_at"] = datetime.now() + result["updated_at"] = datetime.now() + return result + + def save_dataframe(self, df: pd.DataFrame) -> None: + """ + 将数据写入数据库 + """ + if df.empty: + return + df.to_sql(self.table_name, self.engine, if_exists="append", index=False) + + def run_full_collection(self) -> None: + """ + 执行全量覆盖采集: + - 清空目标表 + - 获取所有行业板块列表 + - 遍历每个行业板块,获取成分股数据 + - 全量写入数据库 + """ + print("=" * 60) + print("开始执行全量覆盖采集(同花顺行业板块成分股)") + print("=" * 60) + + try: + # 清空表 + with self.engine.begin() as conn: + conn.execute(text(f"TRUNCATE TABLE {self.table_name}")) + print(f"{self.table_name} 已清空") + + # 获取所有行业板块列表 + industry_list_df = self.fetch_industry_index_list() + if industry_list_df.empty: + print("未获取到行业板块列表,采集终止") + return + + total_records = 0 + success_count = 0 + failed_count = 0 + + # 遍历每个行业板块,获取成分股 + for idx, row in industry_list_df.iterrows(): + industry_code = row["ts_code"] + industry_name = row["name"] + + print(f"\n[{idx + 1}/{len(industry_list_df)}] 正在采集: {industry_name} 
({industry_code})") + + try: + # 获取该行业板块的成分股 + member_df = self.fetch_industry_members(industry_code) + if member_df.empty: + print(f" 行业板块 {industry_name} 无成分股数据") + failed_count += 1 + continue + + # 转换数据格式(传入行业板块名称) + result_df = self.transform_data(member_df, industry_name) + if result_df.empty: + print(f" 行业板块 {industry_name} 数据转换失败") + failed_count += 1 + continue + + # 保存数据 + self.save_dataframe(result_df) + total_records += len(result_df) + success_count += 1 + print(f" 成功采集 {len(result_df)} 只成分股") + + except Exception as exc: + print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}") + failed_count += 1 + continue + + print("\n" + "=" * 60) + print("全量覆盖采集完成") + print(f"总行业板块数: {len(industry_list_df)}") + print(f"成功采集: {success_count}") + print(f"失败: {failed_count}") + print(f"累计成分股记录: {total_records}") + print("=" * 60) + except Exception as exc: + print(f"全量采集失败: {exc}") + import traceback + traceback.print_exc() + finally: + self.engine.dispose() + + +def collect_ths_industry_member( + db_url: str, + tushare_token: str, +): + """ + 采集入口 - 全量覆盖采集 + """ + collector = THSIndustryMemberCollector(db_url, tushare_token) + collector.run_full_collection() + + +if __name__ == "__main__": + # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj" + DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db" + TOKEN = TUSHARE_TOKEN + + # 执行全量覆盖采集 + collect_ths_industry_member(DB_URL, TOKEN) + diff --git a/src/tushare_scripts/transfer_ths_to_gp_bk.py b/src/tushare_scripts/transfer_ths_to_gp_bk.py new file mode 100644 index 0000000..7b84c3b --- /dev/null +++ b/src/tushare_scripts/transfer_ths_to_gp_bk.py @@ -0,0 +1,300 @@ +# coding:utf-8 +""" +同花顺板块数据转存脚本 +功能:将 ths_industry_member 和 ths_concept_member 表的数据转存到 gp_hybk 和 gp_gnbk 表。但是提前需要处理ths_concept_member_collector.py和ths_industry_member_collector.py里面的数据。 +说明:全量清理目标表后重新插入,保持数据结构对齐 +""" + +import os +import sys +import re +from typing import Optional + +import pandas as pd 
+from sqlalchemy import create_engine, text + +# 添加项目根目录到路径 +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(PROJECT_ROOT) + + +class THSToGPBKTransfer: + """同花顺板块数据转存器""" + + def __init__(self, db_url: str): + """ + Args: + db_url: 数据库连接 URL + """ + self.engine = create_engine( + db_url, + pool_size=5, + max_overflow=10, + pool_recycle=3600, + ) + + print("=" * 60) + print("同花顺板块数据转存工具") + print("=" * 60) + + @staticmethod + def extract_bk_code(code: str) -> Optional[int]: + """ + 从板块代码中提取数字部分(如:885800.TI -> 885800) + + Args: + code: 板块代码,如 '885800.TI' 或 '885800' + + Returns: + 提取的数字,如果无法提取则返回 None + """ + if not code: + return None + # 提取数字部分 + match = re.search(r'(\d+)', str(code)) + if match: + try: + return int(match.group(1)) + except ValueError: + return None + return None + + def transfer_industry_data(self) -> dict: + """ + 将 ths_industry_member 表数据转存到 gp_hybk 表 + + Returns: + 转存结果统计 + """ + print("\n" + "=" * 60) + print("开始转存行业板块数据 (ths_industry_member -> gp_hybk)") + print("=" * 60) + + result = { + 'success': False, + 'source_count': 0, + 'transferred_count': 0, + 'error': None + } + + try: + # 1. 读取源表数据 + print("正在读取 ths_industry_member 表数据...") + source_query = text(""" + SELECT + industry_code, + industry_name, + stock_symbol, + stock_name + FROM ths_industry_member + WHERE industry_code IS NOT NULL + AND stock_symbol IS NOT NULL + """) + + with self.engine.connect() as conn: + source_df = pd.read_sql(source_query, conn) + + result['source_count'] = len(source_df) + print(f"源表共 {result['source_count']} 条记录") + + if source_df.empty: + print("源表无数据,转存终止") + result['success'] = True + return result + + # 2. 
数据转换 + print("正在转换数据格式...") + target_df = pd.DataFrame() + target_df['bk_code'] = source_df['industry_code'].apply(self.extract_bk_code) + target_df['bk_name'] = source_df['industry_name'] + target_df['gp_code'] = source_df['stock_symbol'] + target_df['gp_name'] = source_df['stock_name'] + + # 过滤掉 bk_code 为 None 的记录 + target_df = target_df[target_df['bk_code'].notna()] + + print(f"转换后共 {len(target_df)} 条有效记录") + + # 3. 清空目标表 + print("正在清空 gp_hybk 表...") + with self.engine.begin() as conn: + conn.execute(text("TRUNCATE TABLE gp_hybk")) + print("gp_hybk 表已清空") + + # 4. 插入数据 + print("正在插入数据到 gp_hybk 表...") + target_df.to_sql( + 'gp_hybk', + self.engine, + if_exists='append', + index=False, + method='multi', + chunksize=1000 + ) + + result['transferred_count'] = len(target_df) + result['success'] = True + print(f"成功转存 {result['transferred_count']} 条记录到 gp_hybk 表") + + except Exception as e: + result['error'] = str(e) + print(f"转存行业板块数据失败: {e}") + import traceback + traceback.print_exc() + + return result + + def transfer_concept_data(self) -> dict: + """ + 将 ths_concept_member 表数据转存到 gp_gnbk 表 + + Returns: + 转存结果统计 + """ + print("\n" + "=" * 60) + print("开始转存概念板块数据 (ths_concept_member -> gp_gnbk)") + print("=" * 60) + + result = { + 'success': False, + 'source_count': 0, + 'transferred_count': 0, + 'error': None + } + + try: + # 1. 读取源表数据 + print("正在读取 ths_concept_member 表数据...") + source_query = text(""" + SELECT + concept_code, + concept_name, + stock_symbol, + stock_name + FROM ths_concept_member + WHERE concept_code IS NOT NULL + AND stock_symbol IS NOT NULL + """) + + with self.engine.connect() as conn: + source_df = pd.read_sql(source_query, conn) + + result['source_count'] = len(source_df) + print(f"源表共 {result['source_count']} 条记录") + + if source_df.empty: + print("源表无数据,转存终止") + result['success'] = True + return result + + # 2. 
数据转换 + print("正在转换数据格式...") + target_df = pd.DataFrame() + target_df['bk_code'] = source_df['concept_code'].apply(self.extract_bk_code) + target_df['bk_name'] = source_df['concept_name'] + target_df['gp_code'] = source_df['stock_symbol'] + target_df['gp_name'] = source_df['stock_name'] + + # 过滤掉 bk_code 为 None 的记录 + target_df = target_df[target_df['bk_code'].notna()] + + print(f"转换后共 {len(target_df)} 条有效记录") + + # 3. 清空目标表 + print("正在清空 gp_gnbk 表...") + with self.engine.begin() as conn: + conn.execute(text("TRUNCATE TABLE gp_gnbk")) + print("gp_gnbk 表已清空") + + # 4. 插入数据 + print("正在插入数据到 gp_gnbk 表...") + target_df.to_sql( + 'gp_gnbk', + self.engine, + if_exists='append', + index=False, + method='multi', + chunksize=1000 + ) + + result['transferred_count'] = len(target_df) + result['success'] = True + print(f"成功转存 {result['transferred_count']} 条记录到 gp_gnbk 表") + + except Exception as e: + result['error'] = str(e) + print(f"转存概念板块数据失败: {e}") + import traceback + traceback.print_exc() + + return result + + def run_full_transfer(self) -> dict: + """ + 执行全量转存:行业 + 概念 + + Returns: + 转存结果汇总 + """ + print("\n" + "=" * 60) + print("开始执行全量转存") + print("=" * 60) + + summary = { + 'industry': None, + 'concept': None, + 'all_success': False + } + + # 转存行业数据 + industry_result = self.transfer_industry_data() + summary['industry'] = industry_result + + # 转存概念数据 + concept_result = self.transfer_concept_data() + summary['concept'] = concept_result + + # 汇总结果 + summary['all_success'] = industry_result.get('success') and concept_result.get('success') + + # 打印汇总 + print("\n" + "=" * 60) + print("转存完成汇总") + print("=" * 60) + print(f"行业板块转存: {'成功' if industry_result.get('success') else '失败'}") + if industry_result.get('success'): + print(f" - 源表记录数: {industry_result.get('source_count', 0)}") + print(f" - 转存记录数: {industry_result.get('transferred_count', 0)}") + if industry_result.get('error'): + print(f" - 错误: {industry_result.get('error')}") + + print(f"\n概念板块转存: {'成功' if 
concept_result.get('success') else '失败'}") + if concept_result.get('success'): + print(f" - 源表记录数: {concept_result.get('source_count', 0)}") + print(f" - 转存记录数: {concept_result.get('transferred_count', 0)}") + if concept_result.get('error'): + print(f" - 错误: {concept_result.get('error')}") + + print("=" * 60) + + return summary + + +def transfer_ths_to_gp_bk(db_url: str): + """ + 转存入口函数 + + Args: + db_url: 数据库连接 URL + """ + transfer = THSToGPBKTransfer(db_url) + transfer.run_full_transfer() + + +if __name__ == "__main__": + # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj" + DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db" + + # 执行全量转存 + transfer_ths_to_gp_bk(DB_URL) + diff --git a/src/valuation_analysis/industry_analysis.py b/src/valuation_analysis/industry_analysis.py index 2177d48..a7df201 100644 --- a/src/valuation_analysis/industry_analysis.py +++ b/src/valuation_analysis/industry_analysis.py @@ -14,10 +14,21 @@ import datetime import logging import json import redis -import time -from typing import Tuple, Dict, List, Optional, Union +from typing import Tuple, Dict, List +import matplotlib.pyplot as plt +import seaborn as sns +from pathlib import Path +import os +import sys -from .config import DB_URL, OUTPUT_DIR, LOG_FILE +# 处理相对导入和绝对导入 +try: + from .config import DB_URL, OUTPUT_DIR, LOG_FILE +except ImportError: + # 直接执行时使用绝对导入 + current_dir = Path(__file__).parent + sys.path.insert(0, str(current_dir.parent.parent)) + from src.valuation_analysis.config import DB_URL, OUTPUT_DIR, LOG_FILE # 配置日志 logging.basicConfig( @@ -59,6 +70,132 @@ class IndustryAnalyzer: ) logger.info("行业估值分析器初始化完成") + def _setup_chinese_font(self): + """ + 设置matplotlib中文字体,解决中文显示乱码问题 + 使用setup_fonts工具确保字体存在,并优先使用项目目录中的字体文件 + """ + try: + from matplotlib import font_manager + from matplotlib.font_manager import FontProperties + + # 1. 
确保字体文件存在(使用setup_fonts工具) + try: + # 获取项目路径 + current_file = Path(__file__) + project_root = current_file.parent.parent + setup_fonts_path = project_root / "fundamentals_llm" / "setup_fonts.py" + + if setup_fonts_path.exists(): + # 动态导入setup_fonts模块 + import importlib.util + spec = importlib.util.spec_from_file_location("setup_fonts", str(setup_fonts_path)) + setup_fonts = importlib.util.module_from_spec(spec) + spec.loader.exec_module(setup_fonts) + + # 确保字体已安装到fundamentals_llm/fonts目录 + font_dir = project_root / "fundamentals_llm" / "fonts" + font_file = setup_fonts.install_font(str(font_dir)) + + if font_file and os.path.exists(font_file): + logger.info(f"setup_fonts已确保字体存在: {font_file}") + else: + logger.warning("setup_fonts未能安装字体,尝试查找现有字体") + else: + logger.warning(f"未找到setup_fonts.py: {setup_fonts_path}") + except Exception as e: + logger.warning(f"调用setup_fonts失败: {e},尝试查找现有字体") + + # 2. 查找项目中的字体文件(优先使用) + current_file = Path(__file__) + project_root = current_file.parent.parent + + # 可能的字体文件位置 + font_paths = [ + project_root / "fundamentals_llm" / "fonts" / "simhei.ttf", # 现有字体目录 + project_root / "valuation_analysis" / "fonts" / "simhei.ttf", # 本模块字体目录 + ] + + font_file = None + for font_path in font_paths: + if font_path.exists(): + font_file = str(font_path) + logger.info(f"找到项目字体文件: {font_file}") + break + + # 3. 如果找到项目字体文件,直接使用 + if font_file and os.path.exists(font_file): + try: + # 使用字体文件路径直接注册 + font_prop = FontProperties(fname=font_file) + # 获取字体名称 + font_name = font_prop.get_name() + # 设置matplotlib使用该字体 + plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif'] + logger.info(f"成功使用项目字体: {font_name} (来自: {font_file})") + except Exception as e: + logger.warning(f"加载项目字体文件失败: {e},尝试使用系统字体") + font_file = None + + # 4. 
如果项目字体不可用,尝试使用系统字体 + if font_file is None: + import platform + system = platform.system() + font_candidates = [] + if system == "Windows": + font_candidates = [ + 'Microsoft YaHei', + 'SimHei', + 'SimSun', + 'KaiTi', + 'FangSong' + ] + elif system == "Darwin": # macOS + font_candidates = [ + 'PingFang SC', + 'STHeiti', + 'Arial Unicode MS', + 'Heiti SC' + ] + else: # Linux + font_candidates = [ + 'WenQuanYi Micro Hei', + 'WenQuanYi Zen Hei', + 'Noto Sans CJK SC', + 'Droid Sans Fallback' + ] + + # 尝试设置系统字体 + for font_name in font_candidates: + try: + font_list = [f.name for f in font_manager.fontManager.ttflist] + if font_name in font_list: + plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif'] + logger.info(f"使用系统字体: {font_name}") + break + except Exception as e: + continue + + # 5. 如果都没找到,使用默认设置 + if 'font.sans-serif' not in plt.rcParams or len(plt.rcParams['font.sans-serif']) == 0: + plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'DejaVu Sans'] + logger.warning("使用默认中文字体设置") + + # 解决负号显示问题 + plt.rcParams['axes.unicode_minus'] = False + + # 清除matplotlib字体缓存,强制重新加载字体 + try: + font_manager._rebuild() + logger.info("已清除matplotlib字体缓存") + except Exception as e: + logger.warning(f"清除字体缓存失败: {e}") + + except Exception as e: + logger.warning(f"设置中文字体失败,使用默认设置: {e}") + plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'DejaVu Sans'] + plt.rcParams['axes.unicode_minus'] = False + def get_industry_list(self) -> List[Dict]: """ 获取所有行业列表 @@ -1233,4 +1370,560 @@ class IndustryAnalyzer: continue except Exception as e: logger.error(f"筛选拥挤度缓存时出错: {e}") - return result \ No newline at end of file + return result + + def _get_all_industries_concepts(self, days: int = 120) -> Tuple[List[str], List[str]]: + """ + 从Redis缓存中获取所有有效的行业和概念板块 + + Args: + days: 需要分析的天数,用于验证数据有效性 + + Returns: + (all_industries, all_concepts) 元组 + """ + try: + # 计算结束日期和开始日期 + end_date = datetime.datetime.now() + 
start_date = end_date - datetime.timedelta(days=days) + + all_industries = [] + all_concepts = [] + + # 获取所有行业和概念板块的缓存key + industry_keys = redis_client.keys('industry_crowding:*') + concept_keys = redis_client.keys('concept_crowding:*') + + # 处理行业 + for key in industry_keys: + try: + name = key.split(':', 1)[1] + cached_data = redis_client.get(key) + if not cached_data: + continue + + df = pd.DataFrame(json.loads(cached_data)) + if df.empty: + continue + + # 确保trade_date是日期类型 + if 'trade_date' in df.columns: + df['trade_date'] = pd.to_datetime(df['trade_date']) + + # 筛选近120天的数据 + df = df[df['trade_date'] >= start_date] + df = df.sort_values('trade_date') + + if len(df) < 10: # 数据点太少,跳过 + continue + + all_industries.append(name) + except Exception as e: + logger.warning(f"处理行业缓存 {key} 时出错: {e}") + continue + + # 处理概念板块 + for key in concept_keys: + try: + name = key.split(':', 1)[1] + cached_data = redis_client.get(key) + if not cached_data: + continue + + df = pd.DataFrame(json.loads(cached_data)) + if df.empty: + continue + + # 确保trade_date是日期类型 + if 'trade_date' in df.columns: + df['trade_date'] = pd.to_datetime(df['trade_date']) + + # 筛选近120天的数据 + df = df[df['trade_date'] >= start_date] + df = df.sort_values('trade_date') + + if len(df) < 10: # 数据点太少,跳过 + continue + + all_concepts.append(name) + except Exception as e: + logger.warning(f"处理概念缓存 {key} 时出错: {e}") + continue + + logger.info(f"获取到 {len(all_industries)} 个行业和 {len(all_concepts)} 个概念板块") + return all_industries, all_concepts + + except Exception as e: + logger.error(f"获取行业/概念板块列表失败: {e}") + return [], [] + + def _select_top_industries_concepts(self, count: int = 40, days: int = 120) -> Tuple[List[str], List[str]]: + """ + 从Redis缓存中选择拥挤度变化幅度最大的行业和概念板块 + + Args: + count: 选择的行业/概念数量 + days: 需要分析的天数 + + Returns: + (selected_industries, selected_concepts) 元组 + """ + try: + # 计算结束日期和开始日期 + end_date = datetime.datetime.now() + start_date = end_date - datetime.timedelta(days=days) + + industry_scores = [] + 
concept_scores = [] + + # 获取所有行业和概念板块的缓存key + industry_keys = redis_client.keys('industry_crowding:*') + concept_keys = redis_client.keys('concept_crowding:*') + + # 处理行业 + for key in industry_keys: + try: + name = key.split(':', 1)[1] + cached_data = redis_client.get(key) + if not cached_data: + continue + + df = pd.DataFrame(json.loads(cached_data)) + if df.empty: + continue + + # 确保trade_date是日期类型 + if 'trade_date' in df.columns: + df['trade_date'] = pd.to_datetime(df['trade_date']) + + # 筛选近120天的数据 + df = df[df['trade_date'] >= start_date] + df = df.sort_values('trade_date') + + if len(df) < 10: # 数据点太少,跳过 + continue + + # 计算拥挤度变化幅度(最大值-最小值) + percentile_values = df['percentile'].values + if len(percentile_values) > 0: + change_range = float(percentile_values.max() - percentile_values.min()) + current_value = float(percentile_values[-1]) + + # 综合评分:变化幅度 + 当前值的绝对值(突出极端情况) + score = change_range * 0.7 + abs(current_value - 50) * 0.3 + + industry_scores.append({ + 'name': name, + 'score': score, + 'change_range': change_range, + 'current': current_value + }) + except Exception as e: + logger.warning(f"处理行业缓存 {key} 时出错: {e}") + continue + + # 处理概念板块 + for key in concept_keys: + try: + name = key.split(':', 1)[1] + cached_data = redis_client.get(key) + if not cached_data: + continue + + df = pd.DataFrame(json.loads(cached_data)) + if df.empty: + continue + + # 确保trade_date是日期类型 + if 'trade_date' in df.columns: + df['trade_date'] = pd.to_datetime(df['trade_date']) + + # 筛选近120天的数据 + df = df[df['trade_date'] >= start_date] + df = df.sort_values('trade_date') + + if len(df) < 10: # 数据点太少,跳过 + continue + + # 计算拥挤度变化幅度 + percentile_values = df['percentile'].values + if len(percentile_values) > 0: + change_range = float(percentile_values.max() - percentile_values.min()) + current_value = float(percentile_values[-1]) + + # 综合评分 + score = change_range * 0.7 + abs(current_value - 50) * 0.3 + + concept_scores.append({ + 'name': name, + 'score': score, + 'change_range': 
change_range, + 'current': current_value + }) + except Exception as e: + logger.warning(f"处理概念缓存 {key} 时出错: {e}") + continue + + # 按评分排序,选择前count个 + industry_scores.sort(key=lambda x: x['score'], reverse=True) + concept_scores.sort(key=lambda x: x['score'], reverse=True) + + selected_industries = [item['name'] for item in industry_scores[:count]] + selected_concepts = [item['name'] for item in concept_scores[:count]] + + logger.info(f"选择了 {len(selected_industries)} 个行业和 {len(selected_concepts)} 个概念板块") + + return selected_industries, selected_concepts + + except Exception as e: + logger.error(f"选择行业/概念板块时出错: {e}") + return [], [] + + def _build_heatmap_data(self, names: List[str], is_concept: bool, days: int = 120) -> pd.DataFrame: + """ + 构建热力图数据矩阵 + + Args: + names: 行业/概念板块名称列表 + is_concept: 是否为概念板块 + days: 需要分析的天数 + + Returns: + 包含热力图数据的DataFrame,行为行业/概念,列为日期,值为拥挤度百分位 + """ + try: + # 计算日期范围 + end_date = datetime.datetime.now() + start_date = end_date - datetime.timedelta(days=days) + + # 获取所有交易日期(从第一个有效数据中获取) + all_dates = set() + data_dict = {} + + prefix = 'concept_crowding:' if is_concept else 'industry_crowding:' + + for name in names: + cache_key = f"{prefix}{name}" + cached_data = redis_client.get(cache_key) + + if not cached_data: + continue + + try: + df = pd.DataFrame(json.loads(cached_data)) + if df.empty: + continue + + # 确保trade_date是日期类型 + if 'trade_date' in df.columns: + df['trade_date'] = pd.to_datetime(df['trade_date']) + + # 筛选日期范围 + df = df[(df['trade_date'] >= start_date) & (df['trade_date'] <= end_date)] + df = df.sort_values('trade_date') + + if df.empty: + continue + + # 收集所有日期 + all_dates.update(df['trade_date'].tolist()) + + # 存储数据 + data_dict[name] = df.set_index('trade_date')['percentile'].to_dict() + + except Exception as e: + logger.warning(f"处理 {name} 数据时出错: {e}") + continue + + if not all_dates: + logger.warning("没有找到有效数据") + return pd.DataFrame() + + # 创建完整的日期序列(只包含交易日) + all_dates = sorted(list(all_dates)) + + # 构建矩阵 + 
matrix_data = [] + for name in names: + row = [] + for date in all_dates: + if name in data_dict and date in data_dict[name]: + row.append(float(data_dict[name][date])) + else: + row.append(np.nan) + matrix_data.append(row) + + # 创建DataFrame + heatmap_df = pd.DataFrame(matrix_data, index=names, columns=all_dates) + + logger.info(f"构建热力图数据完成: {heatmap_df.shape}") + return heatmap_df + + except Exception as e: + logger.error(f"构建热力图数据失败: {e}") + return pd.DataFrame() + + def plot_crowding_heatmap(self, days: int = 120, items_per_page: int = 40, output_dir: str = None): + """ + 绘制行业和概念板块拥挤度热力图 + 每40个行业/概念生成一张图 + + Args: + days: 分析的天数,默认120天 + items_per_page: 每张图显示的行业/概念数量,默认40个 + output_dir: 输出目录,默认为src/data目录 + """ + try: + # 设置输出目录 + if output_dir is None: + # 获取项目根目录 + current_file = Path(__file__) + project_root = current_file.parent.parent.parent + output_dir = project_root / "src" / "data" + else: + output_dir = Path(output_dir) + + os.makedirs(output_dir, exist_ok=True) + + # 获取所有行业和概念 + all_industries, all_concepts = self._get_all_industries_concepts(days=days) + + if not all_industries and not all_concepts: + logger.warning("没有找到有效的行业或概念板块数据") + return + + # 设置中文字体(使用setup_fonts工具) + self._setup_chinese_font() + + # 设置seaborn样式 + sns.set_style("whitegrid") + + # 生成时间戳用于文件名 + timestamp = datetime.datetime.now().strftime('%Y%m%d') + + # 绘制行业热力图(按每items_per_page个分组) + if all_industries: + total_industries = len(all_industries) + total_pages = (total_industries + items_per_page - 1) // items_per_page # 向上取整 + logger.info(f"开始绘制行业热力图,共 {total_industries} 个行业,将生成 {total_pages} 张图") + + for page in range(total_pages): + start_idx = page * items_per_page + end_idx = min(start_idx + items_per_page, total_industries) + page_industries = all_industries[start_idx:end_idx] + + logger.info(f"绘制行业热力图第 {page + 1}/{total_pages} 张(行业 {start_idx + 1}-{end_idx})") + industry_heatmap = self._build_heatmap_data(page_industries, is_concept=False, days=days) + + if not 
industry_heatmap.empty: + self._plot_single_heatmap( + industry_heatmap, + title=f"行业拥挤度热力图(近{days}天,第{page + 1}/{total_pages}页,{len(page_industries)}个行业)", + output_path=output_dir / f"industry_crowding_heatmap_{timestamp}_page{page + 1:02d}.png", + is_concept=False + ) + + # 绘制概念板块热力图(按每items_per_page个分组) + if all_concepts: + total_concepts = len(all_concepts) + total_pages = (total_concepts + items_per_page - 1) // items_per_page # 向上取整 + logger.info(f"开始绘制概念板块热力图,共 {total_concepts} 个概念,将生成 {total_pages} 张图") + + for page in range(total_pages): + start_idx = page * items_per_page + end_idx = min(start_idx + items_per_page, total_concepts) + page_concepts = all_concepts[start_idx:end_idx] + + logger.info(f"绘制概念板块热力图第 {page + 1}/{total_pages} 张(概念 {start_idx + 1}-{end_idx})") + concept_heatmap = self._build_heatmap_data(page_concepts, is_concept=True, days=days) + + if not concept_heatmap.empty: + self._plot_single_heatmap( + concept_heatmap, + title=f"概念板块拥挤度热力图(近{days}天,第{page + 1}/{total_pages}页,{len(page_concepts)}个概念)", + output_path=output_dir / f"concept_crowding_heatmap_{timestamp}_page{page + 1:02d}.png", + is_concept=True + ) + + logger.info("热力图绘制完成") + + except Exception as e: + logger.error(f"绘制热力图失败: {e}") + import traceback + logger.error(traceback.format_exc()) + + def _plot_single_heatmap(self, heatmap_df: pd.DataFrame, title: str, output_path: Path, is_concept: bool): + """ + 绘制单个热力图 + + Args: + heatmap_df: 热力图数据DataFrame + title: 图表标题 + output_path: 输出路径 + is_concept: 是否为概念板块 + """ + try: + # 计算图表尺寸(根据数据量动态调整) + n_rows = len(heatmap_df) + n_cols = len(heatmap_df.columns) + + # 基础尺寸,根据数据量调整 + fig_width = max(20, n_cols * 0.15) + fig_height = max(12, n_rows * 0.4) + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + # 获取当前设置的字体 + from matplotlib.font_manager import FontProperties + current_font = plt.rcParams['font.sans-serif'][0] if plt.rcParams['font.sans-serif'] else 'SimHei' + font_prop = 
FontProperties(fname=str(Path(__file__).parent.parent / "fundamentals_llm" / "fonts" / "simhei.ttf")) if (Path(__file__).parent.parent / "fundamentals_llm" / "fonts" / "simhei.ttf").exists() else None + + # heatmap_df: 行为行业/概念,列为日期 + # seaborn heatmap: 行在y轴,列在x轴 + # 所以:y轴是行业/概念,x轴是日期(符合要求) + plot_data = heatmap_df + + # 使用自定义颜色映射:绿色(不拥挤)-> 黄色(中性)-> 红色(拥挤) + # 创建从绿色到红色渐变的colormap + colors = ['#2ecc71', '#f1c40f', '#e74c3c'] # 绿色、黄色、红色 + from matplotlib.colors import LinearSegmentedColormap + n_bins = 100 + cmap = LinearSegmentedColormap.from_list('crowding', colors, N=n_bins) + + # 绘制热力图 + heatmap = sns.heatmap( + plot_data, + cmap=cmap, + vmin=0, + vmax=100, + annot=False, # 不显示数值(数据太多) + fmt='.1f', + cbar_kws={ + 'label': '拥挤度百分位 (%)', + 'shrink': 0.8, + 'format': '%.0f' + }, + linewidths=0.05, + linecolor='white', + ax=ax, + xticklabels=False, # 先不显示x轴标签,后面自定义 + yticklabels=True # 显示y轴标签(行业/概念名称) + ) + + # 设置colorbar标签字体 + try: + cbar = heatmap.collections[0].colorbar + if cbar is not None: + if font_prop: + cbar.set_label('拥挤度百分位 (%)', fontproperties=font_prop, fontsize=10) + else: + cbar.set_label('拥挤度百分位 (%)', fontfamily=current_font, fontsize=10) + # 设置colorbar刻度标签字体 + for label in cbar.ax.get_yticklabels(): + if font_prop: + label.set_fontproperties(font_prop) + else: + label.set_fontfamily(current_font) + except Exception as e: + logger.warning(f"设置colorbar字体失败: {e}") + + # 设置标题和标签,显式指定字体 + if font_prop: + ax.set_title(title, fontsize=16, fontweight='bold', pad=20, fontproperties=font_prop) + ax.set_xlabel('日期(时间)', fontsize=12, fontproperties=font_prop) + ax.set_ylabel('行业/概念板块', fontsize=12, fontproperties=font_prop) + else: + ax.set_title(title, fontsize=16, fontweight='bold', pad=20, fontfamily=current_font) + ax.set_xlabel('日期(时间)', fontsize=12, fontfamily=current_font) + ax.set_ylabel('行业/概念板块', fontsize=12, fontfamily=current_font) + + # 调整y轴标签(行业/概念名称)- 显示完整名称,但可以旋转 + y_labels = plot_data.index.tolist() + if len(y_labels) > 0: + # 如果名称超过15个字符,截断 + 
def main():
    """Entry point: generate industry and concept-board crowding heatmaps."""
    try:
        banner = "=" * 60
        logger.info(banner)
        logger.info("开始生成行业和概念板块拥挤度热力图")
        logger.info(banner)

        # days: trailing window to analyse (120 days)
        # items_per_page: names per image (40)
        # output_dir: left as default -> src/data inside the project
        IndustryAnalyzer().plot_crowding_heatmap(days=120, items_per_page=40)

        logger.info(banner)
        logger.info("热力图生成完成!")
        logger.info(banner)

    except Exception as e:
        logger.error(f"执行主函数失败: {e}")
        import traceback
        logger.error(traceback.format_exc())


if __name__ == '__main__':
    main()