# coding:utf-8
"""
Tonghuashun (THS) industry-board constituent collector.

Fetches THS industry-board constituent stocks from Tushare and stores them in a
database table.

API documentation: https://tushare.pro/document/2?doc_id=261
Each run fully overwrites the target table, so it does not need to run daily.
"""
import os
import sys
import traceback
from datetime import datetime
from typing import List, Optional, Tuple

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, inspect, text

# Add the project root to sys.path so the shared config module can be imported.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN  # noqa: E402  (requires PROJECT_ROOT on sys.path)


class THSIndustryMemberCollector:
    """Collector for THS industry-board constituent stocks."""

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
        """
        Args:
            db_url: SQLAlchemy database URL.
            tushare_token: Tushare API token.
            table_name: Target table name, defaults to ``ths_industry_member``.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺行业板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: Optional[str]) -> Optional[str]:
        """
        Convert a Tushare code (``600000.SH``) to the DB code format (``SH600000``).

        Non-string values (``None``, ``NaN`` from pandas) and strings without a
        ``.`` are returned unchanged. Only the last ``.`` is used as the separator.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    def fetch_industry_index_list(self) -> pd.DataFrame:
        """
        Fetch the list of all THS industry boards.

        API: ths_index
        Docs: https://tushare.pro/document/2?doc_id=259

        Returns:
            DataFrame of industry boards; empty when the call fails or returns no data.
        """
        try:
            print("正在获取同花顺行业板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='I',  # I = industry boards
            )
            if df is None or df.empty:
                print("未获取到行业板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个行业板块")
            return df
        except Exception as exc:  # remote-API boundary: report and degrade to "no data"
            print(f"获取行业板块列表失败: {exc}")
            return pd.DataFrame()

    def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
        """
        Fetch the constituent stocks of one industry board.

        API: ths_member
        Docs: https://tushare.pro/document/2?doc_id=261

        Args:
            ts_code: Industry board code, e.g. ``'885800.TI'``.

        Returns:
            Constituent DataFrame; empty when the call fails or returns nothing.
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
        except Exception as exc:  # remote-API boundary: report and degrade to "no data"
            print(f"获取行业板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()
        return df if df is not None else pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
        """
        Convert a Tushare constituent frame into the row layout stored in the DB.

        Args:
            member_df: Constituent rows (expects ``ts_code`` and ``con_code``;
                ``con_name`` and ``is_new`` are used when present).
            industry_name: Board name taken from the board list.

        Returns:
            The converted DataFrame, or an empty one when input is empty or a
            required column is missing.
        """
        if member_df.empty:
            return pd.DataFrame()

        # Both code columns are required; report the available columns to aid debugging.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()

        members = member_df.reset_index(drop=True)
        # One timestamp so created_at and updated_at are identical within a batch.
        now = datetime.now()

        result = pd.DataFrame(index=members.index)
        # ts_code is the industry board code.
        result["industry_code"] = members["ts_code"]
        result["industry_name"] = industry_name
        # con_code is the constituent stock code.
        result["stock_ts_code"] = members["con_code"]
        result["stock_symbol"] = members["con_code"].apply(self.convert_tushare_code_to_db)
        # con_name is the stock name.
        result["stock_name"] = members["con_name"] if "con_name" in members.columns else ""
        # is_new: whether this constituent entry is the latest one.
        result["is_new"] = members["is_new"] if "is_new" in members.columns else None
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append ``df`` to the target table; does nothing for an empty frame."""
        if df.empty:
            return
        # Chunked inserts keep each statement bounded, since the whole
        # collection is written in a single call.
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False, chunksize=1000)

    def _collect_members(self, industry_list_df: pd.DataFrame) -> Tuple[List[pd.DataFrame], int, int]:
        """Fetch and transform constituents of every board; return (frames, success, failed)."""
        frames: List[pd.DataFrame] = []
        success_count = 0
        failed_count = 0
        total = len(industry_list_df)

        # enumerate() gives a 1-based position independent of the DataFrame's index labels.
        for position, (_, row) in enumerate(industry_list_df.iterrows(), start=1):
            industry_code = row["ts_code"]
            industry_name = row["name"]
            print(f"\n[{position}/{total}] 正在采集: {industry_name} ({industry_code})")

            try:
                member_df = self.fetch_industry_members(industry_code)
                if member_df.empty:
                    print(f" 行业板块 {industry_name} 无成分股数据")
                    failed_count += 1
                    continue

                result_df = self.transform_data(member_df, industry_name)
                if result_df.empty:
                    print(f" 行业板块 {industry_name} 数据转换失败")
                    failed_count += 1
                    continue
            except Exception as exc:  # one bad board must not stop the whole run
                print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}")
                failed_count += 1
                continue

            frames.append(result_df)
            success_count += 1
            print(f" 成功采集 {len(result_df)} 只成分股")

        return frames, success_count, failed_count

    def _replace_table(self, df: pd.DataFrame) -> None:
        """Empty the target table if it exists, then write ``df`` into it."""
        if inspect(self.engine).has_table(self.table_name):
            # Quote the identifier so TRUNCATE targets the same table to_sql writes to.
            quoted = self.engine.dialect.identifier_preparer.quote(self.table_name)
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {quoted}"))
            print(f"{self.table_name} 已清空")
        else:
            # First run: to_sql(if_exists="append") creates the table.
            print(f"{self.table_name} 不存在,将在写入时自动创建")
        self.save_dataframe(df)

    def run_full_collection(self) -> None:
        """
        Run a full-overwrite collection.

        Steps:
            - fetch the list of all industry boards;
            - fetch and transform the constituents of each board;
            - only if some data was collected, empty the target table and write
              everything. A failed fetch therefore leaves the previous data in place.

        Errors are printed (with traceback) rather than raised; the engine's
        connection pool is disposed on exit.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺行业板块成分股)")
        print("=" * 60)

        try:
            industry_list_df = self.fetch_industry_index_list()
            if industry_list_df.empty:
                print("未获取到行业板块列表,采集终止")
                return

            frames, success_count, failed_count = self._collect_members(industry_list_df)

            if frames:
                all_members = pd.concat(frames, ignore_index=True)
                self._replace_table(all_members)
                total_records = len(all_members)
            else:
                # Nothing usable was fetched: keep the existing table instead of wiping it.
                print("未采集到任何成分股数据,保留原表数据")
                total_records = 0

            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print(f"总行业板块数: {len(industry_list_df)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print("=" * 60)

        except Exception as exc:
            print(f"全量采集失败: {exc}")
            traceback.print_exc()
        finally:
            self.engine.dispose()


def collect_ths_industry_member(
    db_url: str,
    tushare_token: str,
    table_name: str = "ths_industry_member",
):
    """
    Entry point: run a full-overwrite collection.

    Args:
        db_url: SQLAlchemy database URL.
        tushare_token: Tushare API token.
        table_name: Target table name, defaults to ``ths_industry_member``.
    """
    collector = THSIndustryMemberCollector(db_url, tushare_token, table_name)
    collector.run_full_collection()


if __name__ == "__main__":
    # NOTE(review): database credentials are hard-coded as the fallback below;
    # prefer setting THS_DB_URL in the environment and removing the literal.
    DB_URL = os.environ.get(
        "THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db",
    )
    TOKEN = TUSHARE_TOKEN

    # Run the full-overwrite collection.
    collect_ths_industry_member(DB_URL, TOKEN)