stock_fundamentals/src/tushare_scripts/ths_industry_member_collect...

244 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding:utf-8
"""
同花顺行业板块成分股采集工具
功能:从 Tushare 获取同花顺行业板块成分股数据并落库
API 文档: https://tushare.pro/document/2?doc_id=261
说明:每次调用时全量覆盖,不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# Put the project root (three directory levels above this file) on sys.path
# so the `src.scripts.config` import below can be resolved when this file is
# run directly as a script. The path must be appended BEFORE the import.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class THSIndustryMemberCollector:
    """Collector for Tonghuashun (THS) industry-board constituent stocks.

    Reads the industry board list (Tushare ``ths_index``) and each board's
    members (Tushare ``ths_member``), reshapes the rows and appends them to a
    database table after clearing that table (full overwrite).
    """

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
        """Create the DB engine and the Tushare client, then print a banner.

        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, defaults to ``ths_industry_member``.
                It is interpolated verbatim into a TRUNCATE statement, so it
                must come from trusted configuration, never from user input.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺行业板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the DB form (``SH600000``).

        Values that are not strings (``None``, NaN, ...) or that contain no
        dot are returned unchanged instead of raising. Only the LAST dot is
        treated as the market separator, so a code with extra dots no longer
        fails to unpack.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    def fetch_industry_index_list(self) -> pd.DataFrame:
        """Fetch the list of all THS industry boards.

        API: ths_index
        Docs: https://tushare.pro/document/2?doc_id=259

        Returns:
            The board list as a DataFrame; an empty DataFrame when the API
            returns nothing or raises (the error is printed, not re-raised).
        """
        try:
            print("正在获取同花顺行业板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='I',  # I = industry board
            )
            # Guard against a None result as well as an empty frame.
            if df is None or df.empty:
                print("未获取到行业板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个行业板块")
            return df
        except Exception as exc:
            print(f"获取行业板块列表失败: {exc}")
            return pd.DataFrame()

    def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the constituent stocks of one industry board.

        API: ths_member
        Docs: https://tushare.pro/document/2?doc_id=261

        Args:
            ts_code: Industry board code, e.g. ``'885800.TI'``.

        Returns:
            The member list as a DataFrame; an empty DataFrame when the API
            returns ``None`` or raises (the error is printed, not re-raised).
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
            return pd.DataFrame() if df is None else df
        except Exception as exc:
            print(f"获取行业板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
        """Reshape a Tushare member frame into the database row layout.

        Args:
            member_df: Member rows; must contain ``ts_code`` (board code) and
                ``con_code`` (stock code). ``con_name`` and ``is_new`` are
                copied when present.
            industry_name: Board name taken from the board list.

        Returns:
            A DataFrame with columns ``industry_code``, ``industry_name``,
            ``stock_ts_code``, ``stock_symbol``, ``stock_name``, ``is_new``,
            ``created_at``, ``updated_at``; an empty DataFrame when the input
            is empty/None or a required column is missing.
        """
        if member_df is None or member_df.empty:
            return pd.DataFrame()

        # Required columns.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()

        # One timestamp for the whole batch so created_at == updated_at.
        now = datetime.now()

        result = pd.DataFrame()
        # ts_code is the industry board code.
        result["industry_code"] = member_df["ts_code"]
        result["industry_name"] = industry_name
        # con_code is the stock code.
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        # con_name is the stock name.
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        # is_new flags whether the membership is current.
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append ``df`` to the target table; a no-op for an empty frame."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(self) -> None:
        """Run a full-overwrite collection.

        Steps:
            1. Fetch the industry board list. If it is empty, stop WITHOUT
               touching the table, so an API outage cannot wipe existing data.
            2. Truncate the target table.
            3. For each board, fetch its members, transform and append them.
            4. Print a summary.

        Per-board failures are printed and counted but do not abort the run.
        The engine's connection pool is disposed when the run ends, whether
        it succeeded or failed.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺行业板块成分股)")
        print("=" * 60)
        try:
            # Fetch first: only clear the table once we know there is data
            # to replace it with.
            industry_list_df = self.fetch_industry_index_list()
            if industry_list_df.empty:
                print("未获取到行业板块列表,采集终止")
                return

            # Clear the table. NOTE: table_name is interpolated as-is; it must
            # be a trusted identifier.
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")

            total_boards = len(industry_list_df)
            total_records = 0
            success_count = 0
            failed_count = 0

            # Use a running counter for progress output; the frame's index
            # labels are not guaranteed to be 0..n-1.
            for position, (_, row) in enumerate(industry_list_df.iterrows(), start=1):
                industry_code = row["ts_code"]
                industry_name = row["name"]
                print(f"\n[{position}/{total_boards}] 正在采集: {industry_name} ({industry_code})")
                try:
                    member_df = self.fetch_industry_members(industry_code)
                    if member_df.empty:
                        print(f" 行业板块 {industry_name} 无成分股数据")
                        failed_count += 1
                        continue

                    # Transform, passing the board name along.
                    result_df = self.transform_data(member_df, industry_name)
                    if result_df.empty:
                        print(f" 行业板块 {industry_name} 数据转换失败")
                        failed_count += 1
                        continue

                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f" 成功采集 {len(result_df)} 只成分股")
                except Exception as exc:
                    # Best-effort: one bad board must not abort the whole run.
                    print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}")
                    failed_count += 1
                    continue

            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print(f"总行业板块数: {total_boards}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()
def collect_ths_industry_member(db_url: str, tushare_token: str):
    """Entry point: build a collector and run one full-overwrite collection.

    Args:
        db_url: Database connection URL handed to the collector.
        tushare_token: Tushare API token handed to the collector.
    """
    THSIndustryMemberCollector(db_url, tushare_token).run_full_collection()
if __name__ == "__main__":
    # SECURITY NOTE(review): the fallback URL below embeds database
    # credentials in source control. Supply THS_DB_URL through the
    # environment instead, and rotate/remove the committed password.
    DB_URL = os.environ.get(
        "THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db",
    )
    TOKEN = TUSHARE_TOKEN
    # Run the full-overwrite collection.
    collect_ths_industry_member(DB_URL, TOKEN)