244 lines
8.4 KiB
Python
244 lines
8.4 KiB
Python
# coding:utf-8
|
||
"""
|
||
同花顺行业板块成分股采集工具
|
||
功能:从 Tushare 获取同花顺行业板块成分股数据并落库
|
||
API 文档: https://tushare.pro/document/2?doc_id=261
|
||
说明:每次调用时全量覆盖,不需要每日更新
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
import pandas as pd
|
||
import tushare as ts
|
||
from sqlalchemy import create_engine, text
|
||
|
||
# Put the project root (two levels above this file's directory) on sys.path
# so that the configuration module imported below can be resolved.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN
|
||
|
||
|
||
class THSIndustryMemberCollector:
    """Collector for the constituent stocks of THS (Tonghuashun) industry boards.

    Reads the industry board list (``ths_index``) and each board's members
    (``ths_member``) through the Tushare client and appends the transformed
    rows to a single database table, after emptying that table.
    """

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
        """Create the database engine and the Tushare client.

        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, ``ths_industry_member`` by default.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        banner = "=" * 60
        print(banner)
        print("同花顺行业板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print(banner)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        The text after the last dot is treated as the market suffix and moved
        to the front. Values that are not strings (``None``, NaN, ...) or that
        contain no dot are returned unchanged instead of raising.

        Args:
            ts_code: Code in ``<symbol>.<market>`` form.

        Returns:
            ``<market><symbol>``, or the input itself when it cannot be split.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        # rpartition never fails on extra dots, unlike unpacking split(".").
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    def fetch_industry_index_list(self) -> pd.DataFrame:
        """Fetch the list of THS industry boards via the ``ths_index`` API.

        Returns:
            The board list as returned by the API, or an empty DataFrame when
            nothing came back or the call raised (the error is printed, not
            propagated).
        """
        try:
            print("正在获取同花顺行业板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='I',  # I = industry board
            )
            if df is None or df.empty:
                print("未获取到行业板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个行业板块")
            return df
        except Exception as exc:
            print(f"获取行业板块列表失败: {exc}")
            return pd.DataFrame()

    def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the constituent stocks of one board via the ``ths_member`` API.

        Args:
            ts_code: Industry board code, e.g. ``'885800.TI'``.

        Returns:
            The member list, or an empty DataFrame when the API returned
            nothing or the call raised (the error is printed, not propagated).
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
        except Exception as exc:
            print(f"获取行业板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()
        # Normalise a missing result so callers can always use `.empty`.
        return df if df is not None else pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
        """Reshape an API member frame into the column layout written to the table.

        Args:
            member_df: Member rows; must contain ``ts_code`` (board code) and
                ``con_code`` (stock code). ``con_name`` and ``is_new`` are
                copied when present.
            industry_name: Board name to stamp on every row.

        Returns:
            A DataFrame with columns ``industry_code``, ``industry_name``,
            ``stock_ts_code``, ``stock_symbol``, ``stock_name``, ``is_new``,
            ``created_at`` and ``updated_at``; an empty DataFrame when the
            input is empty or a required column is missing.
        """
        if member_df.empty:
            return pd.DataFrame()

        # Both code columns are mandatory; bail out with a warning otherwise.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()

        # One timestamp for the whole batch so created_at == updated_at.
        now = datetime.now()

        result = pd.DataFrame()
        result["industry_code"] = member_df["ts_code"]
        result["industry_name"] = industry_name
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append ``df`` to the target table; an empty frame is a no-op."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(self) -> None:
        """Run a full-overwrite collection.

        Steps:
            1. Fetch the list of industry boards.
            2. Empty the target table.
            3. For each board, fetch, transform and append its members.
            4. Print a summary.

        The board list is fetched before the table is emptied, so a failed or
        empty API response aborts the run without discarding existing rows.
        Errors are printed rather than raised, and the engine's connection
        pool is disposed at the end in every case.
        """
        banner = "=" * 60
        print(banner)
        print("开始执行全量覆盖采集(同花顺行业板块成分股)")
        print(banner)

        try:
            # Fetch first: never wipe the table when there is nothing to reload.
            industry_list_df = self.fetch_industry_index_list()
            if industry_list_df.empty:
                print("未获取到行业板块列表,采集终止")
                return

            # NOTE(review): the table name is interpolated into the SQL text;
            # it must come from trusted configuration, never from user input.
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")

            total_industries = len(industry_list_df)
            total_records = 0
            success_count = 0
            failed_count = 0

            # enumerate() gives a 1-based position regardless of the frame's index labels.
            for position, (_, row) in enumerate(industry_list_df.iterrows(), start=1):
                industry_code = row["ts_code"]
                industry_name = row["name"]

                print(f"\n[{position}/{total_industries}] 正在采集: {industry_name} ({industry_code})")

                try:
                    member_df = self.fetch_industry_members(industry_code)
                    if member_df.empty:
                        print(f" 行业板块 {industry_name} 无成分股数据")
                        failed_count += 1
                        continue

                    result_df = self.transform_data(member_df, industry_name)
                    if result_df.empty:
                        print(f" 行业板块 {industry_name} 数据转换失败")
                        failed_count += 1
                        continue

                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f" 成功采集 {len(result_df)} 只成分股")

                except Exception as exc:
                    # One failing board must not stop the remaining boards.
                    print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}")
                    failed_count += 1

            print("\n" + banner)
            print("全量覆盖采集完成")
            print(f"总行业板块数: {total_industries}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print(banner)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback

            traceback.print_exc()
        finally:
            self.engine.dispose()
|
||
|
||
|
||
def collect_ths_industry_member(
    db_url: str,
    tushare_token: str,
):
    """Entry point: run one full-overwrite collection.

    Builds a THSIndustryMemberCollector with the default target table and
    immediately runs its full collection.

    Args:
        db_url: Database connection URL passed to the collector.
        tushare_token: Tushare API token passed to the collector.
    """
    THSIndustryMemberCollector(db_url, tushare_token).run_full_collection()
|
||
|
||
|
||
if __name__ == "__main__":
    # SECURITY: a database credential is hard-coded as the fallback below.
    # Set the THS_DB_URL environment variable to override it; the literal is
    # kept only so that running the script directly behaves as before.
    # TODO(review): rotate this credential and drop the fallback from source.
    DB_URL = os.environ.get(
        "THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db",
    )
    TOKEN = TUSHARE_TOKEN

    # Run the full-overwrite collection.
    collect_ths_industry_member(DB_URL, TOKEN)
|
||
|