stock_fundamentals/src/tushare_scripts/ths_industry_member_collect...

244 lines
8.4 KiB
Python
Raw Normal View History

2026-01-16 15:42:04 +08:00
# coding:utf-8
"""
同花顺行业板块成分股采集工具
功能：从 Tushare 获取同花顺行业板块成分股数据并落库
API 文档: https://tushare.pro/document/2?doc_id=261
说明：每次调用时全量覆盖，不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# Make the project root importable: it is the grandparent of the directory
# holding this file (the directory that contains the ``src`` package), and it
# has to be on ``sys.path`` before ``src.scripts.config`` is imported below.
PROJECT_ROOT = os.path.dirname(
    os.path.dirname(
        os.path.dirname(
            os.path.abspath(__file__)
        )
    )
)
sys.path.extend([PROJECT_ROOT])
from src.scripts.config import TUSHARE_TOKEN
class THSIndustryMemberCollector:
    """Collect THS (Tonghuashun) industry-board constituents from Tushare into a DB table.

    Workflow: list every THS industry board (``ths_index``), fetch each board's
    constituent stocks (``ths_member``), reshape them into the table layout and
    replace the table contents in full.
    """

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
        """
        Create the database engine and the Tushare Pro client, then print a banner.

        Args:
            db_url: SQLAlchemy database URL.
            tushare_token: Tushare Pro API token.
            table_name: Target table name; defaults to ``ths_industry_member``.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            # Recycle pooled connections after an hour so connections the
            # server may have closed while idle are not reused.
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺行业板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """
        Convert a Tushare code such as ``600000.SH`` into the DB form ``SH600000``.

        Non-string values (``None``, or ``NaN`` coming from an empty DataFrame
        cell) and strings without a ``.`` are returned unchanged instead of
        raising. The market suffix is whatever follows the last ``.``.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    def fetch_industry_index_list(self) -> pd.DataFrame:
        """
        Fetch the list of all THS industry boards.

        API: ths_index
        Docs: https://tushare.pro/document/2?doc_id=259

        Returns:
            DataFrame of industry boards; empty if the API returned nothing or
            the request failed (the error is printed, not raised).
        """
        try:
            print("正在获取同花顺行业板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='I',  # I = industry boards
            )
            if df is None or df.empty:
                print("未获取到行业板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个行业板块")
            return df
        except Exception as exc:
            print(f"获取行业板块列表失败: {exc}")
            return pd.DataFrame()

    def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
        """
        Fetch the constituent stocks of one industry board.

        API: ths_member
        Docs: https://tushare.pro/document/2?doc_id=261

        Args:
            ts_code: Board code, e.g. ``'885800.TI'``.

        Returns:
            Constituent DataFrame; empty if the API returned nothing or the
            request failed (the error is printed, not raised).
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
        except Exception as exc:
            print(f"获取行业板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()
        # Normalise a ``None`` response so callers can always use ``.empty``.
        return df if df is not None else pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
        """
        Reshape a Tushare ``ths_member`` response into the table's row layout.

        Args:
            member_df: Constituent rows with ``ts_code`` (board code) and
                ``con_code`` (stock code); ``con_name`` and ``is_new`` are
                copied when present.
            industry_name: Board name taken from the board list.

        Returns:
            DataFrame with columns industry_code, industry_name, stock_ts_code,
            stock_symbol, stock_name, is_new, created_at, updated_at; empty if
            ``member_df`` is empty or lacks a required column.
        """
        if member_df.empty:
            return pd.DataFrame()
        # Both code columns are required; report the columns actually returned.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()

        result = pd.DataFrame()
        result["industry_code"] = member_df["ts_code"]  # board code
        result["industry_name"] = industry_name
        result["stock_ts_code"] = member_df["con_code"]  # stock code, Tushare form
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        # One timestamp so created_at and updated_at are identical for new rows.
        now = datetime.now()
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """
        Append ``df`` to the target table; an empty DataFrame is a no-op.
        """
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(self) -> None:
        """
        Replace the target table with a fresh snapshot of all industry-board members.

        Steps:
        - fetch the list of THS industry boards;
        - fetch and transform every board's constituents, keeping them in memory;
        - only if at least one board produced rows: truncate the table and
          write all collected rows.

        All remote fetching happens before the table is touched. A failed or
        empty Tushare response (bad token, network error, rate limit) therefore
        leaves the existing data in place instead of an empty table. Per-board
        failures are counted and skipped. ``TRUNCATE TABLE`` is not
        transactional on MySQL, so if the final insert fails the table can
        still end up empty.

        Errors are printed (with traceback) rather than raised. The engine's
        connection pool is disposed before returning.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺行业板块成分股)")
        print("=" * 60)
        try:
            industry_list_df = self.fetch_industry_index_list()
            if industry_list_df.empty:
                print("未获取到行业板块列表,采集终止")
                return

            total_industries = len(industry_list_df)
            collected = []
            success_count = 0
            failed_count = 0

            # enumerate() gives a 1-based position independent of the DataFrame index.
            for position, (_, row) in enumerate(industry_list_df.iterrows(), start=1):
                industry_code = row["ts_code"]
                industry_name = row["name"]
                print(f"\n[{position}/{total_industries}] 正在采集: {industry_name} ({industry_code})")
                try:
                    member_df = self.fetch_industry_members(industry_code)
                    if member_df.empty:
                        print(f" 行业板块 {industry_name} 无成分股数据")
                        failed_count += 1
                        continue
                    result_df = self.transform_data(member_df, industry_name)
                    if result_df.empty:
                        print(f" 行业板块 {industry_name} 数据转换失败")
                        failed_count += 1
                        continue
                    collected.append(result_df)
                    success_count += 1
                    print(f" 成功采集 {len(result_df)} 只成分股")
                except Exception as exc:
                    print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}")
                    failed_count += 1

            if not collected:
                # Nothing usable came back: keep the current table contents.
                print(f"未采集到任何成分股数据(失败 {failed_count} 个行业板块),{self.table_name} 保持不变")
                return

            all_members = pd.concat(collected, ignore_index=True)
            # Overwrite only now that every remote call has finished.
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")
            self.save_dataframe(all_members)
            total_records = len(all_members)

            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print(f"总行业板块数: {total_industries}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()
def collect_ths_industry_member(
    db_url: str,
    tushare_token: str,
):
    """
    Collection entry point: run one full-overwrite collection of THS industry members.

    Args:
        db_url: SQLAlchemy database URL of the target database.
        tushare_token: Tushare Pro API token.
    """
    THSIndustryMemberCollector(db_url, tushare_token).run_full_collection()
if __name__ == "__main__":
    # Take the database URL from the environment when set, so credentials do
    # not have to live in source control.
    # NOTE(review): the fallback below still embeds a DB password in the source;
    # it is kept only so existing runs keep working. Move it to config/env and
    # remove it.
    DB_URL = os.environ.get(
        "THS_INDUSTRY_MEMBER_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db",
    )
    TOKEN = TUSHARE_TOKEN
    # Run the full-overwrite collection.
    collect_ths_industry_member(DB_URL, TOKEN)