244 lines
8.4 KiB
Python
244 lines
8.4 KiB
Python
|
|
# coding:utf-8
|
|||
|
|
"""
|
|||
|
|
同花顺行业板块成分股采集工具
|
|||
|
|
功能:从 Tushare 获取同花顺行业板块成分股数据并落库
|
|||
|
|
API 文档: https://tushare.pro/document/2?doc_id=261
|
|||
|
|
说明:每次调用时全量覆盖,不需要每日更新
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
import pandas as pd
|
|||
|
|
import tushare as ts
|
|||
|
|
from sqlalchemy import create_engine, text
|
|||
|
|
|
|||
|
|
# Put the project root (three directory levels above this file) on sys.path so
# that project-local packages can be imported when this file is run directly.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Guard against adding a duplicate entry if the root is already importable.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
|
|||
|
|
|
|||
|
|
from src.scripts.config import TUSHARE_TOKEN
|
|||
|
|
|
|||
|
|
|
|||
|
|
class THSIndustryMemberCollector:
    """Collector for Tonghuashun (THS) industry-board constituent stocks.

    Fetches the THS industry board list and each board's member stocks from
    Tushare, then writes the members into a database table. Every run is a
    full overwrite: the target table is emptied and repopulated.
    """

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_industry_member"):
        """Create the database engine and the Tushare client.

        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, defaults to ``ths_industry_member``.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("同花顺行业板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        Values that are not strings, are empty, or contain no dot are returned
        unchanged, so missing values (``None`` / NaN) in a column do not raise.
        The code is split on its last dot, so a value with more than one dot
        no longer raises either.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    def fetch_industry_index_list(self) -> pd.DataFrame:
        """Fetch the list of all THS industry boards.

        Calls the Tushare ``ths_index`` API
        (https://tushare.pro/document/2?doc_id=259) with ``exchange='A'`` and
        ``type='I'``.

        Returns:
            The industry board DataFrame, or an empty DataFrame when the API
            returns nothing or raises.
        """
        try:
            print("正在获取同花顺行业板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='I',  # I = industry board
            )
            if df is None or df.empty:
                print("未获取到行业板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个行业板块")
            return df
        except Exception as exc:
            # Best effort: report the failure and let the caller see "no data".
            print(f"获取行业板块列表失败: {exc}")
            return pd.DataFrame()

    def fetch_industry_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the member stocks of one industry board.

        Calls the Tushare ``ths_member`` API
        (https://tushare.pro/document/2?doc_id=261).

        Args:
            ts_code: Industry board code, e.g. ``'885800.TI'``.

        Returns:
            The member DataFrame, or an empty DataFrame when the API returns
            nothing or raises.
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
        except Exception as exc:
            print(f"获取行业板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()
        # Normalise a missing result so callers can always test ``.empty``.
        return pd.DataFrame() if df is None else df

    def transform_data(self, member_df: pd.DataFrame, industry_name: str = "") -> pd.DataFrame:
        """Convert a Tushare member DataFrame into the database row layout.

        Args:
            member_df: Member data; must contain ``ts_code`` (board code) and
                ``con_code`` (stock code). ``con_name`` and ``is_new`` are
                copied when present.
            industry_name: Board name taken from the industry board list.

        Returns:
            A DataFrame with columns ``industry_code``, ``industry_name``,
            ``stock_ts_code``, ``stock_symbol``, ``stock_name``, ``is_new``,
            ``created_at`` and ``updated_at``; empty when the input is empty
            or a required column is missing.
        """
        if member_df.empty:
            return pd.DataFrame()

        # Both code columns are required to build a row.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到行业板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()

        # One timestamp for the whole batch so created_at == updated_at.
        now = datetime.now()

        result = pd.DataFrame(index=member_df.index)
        # ts_code is the industry board code.
        result["industry_code"] = member_df["ts_code"]
        result["industry_name"] = industry_name
        # con_code is the stock code in Tushare form.
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        # con_name is the stock name (optional column).
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        # is_new is copied through unchanged (optional column).
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append the rows of ``df`` to the target table; no-op when empty."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(self) -> None:
        """Run a full-overwrite collection.

        Steps:
            - fetch the list of all industry boards
            - empty the target table
            - fetch, transform and save the members of every board
            - print a summary

        The board list is fetched before the table is emptied, so a Tushare
        outage does not wipe previously collected data. Failures for a single
        board are counted and skipped. The engine is disposed at the end, so
        the collector is single-use.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺行业板块成分股)")
        print("=" * 60)

        try:
            # Fetch first: never truncate when there is nothing to reload.
            industry_list_df = self.fetch_industry_index_list()
            if industry_list_df.empty:
                print("未获取到行业板块列表,采集终止")
                return

            # Read the needed columns before truncating so a malformed
            # response (missing column) also leaves the table untouched.
            boards = list(zip(industry_list_df["ts_code"], industry_list_df["name"]))
            total_industries = len(boards)

            # NOTE(review): the table name is interpolated directly into SQL;
            # it must come from trusted configuration.
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")

            total_records = 0
            success_count = 0
            failed_count = 0

            # Positional counter: the DataFrame index label is not guaranteed
            # to be a 0-based integer.
            for position, (industry_code, industry_name) in enumerate(boards, start=1):
                print(f"\n[{position}/{total_industries}] 正在采集: {industry_name} ({industry_code})")

                try:
                    member_df = self.fetch_industry_members(industry_code)
                    if member_df.empty:
                        print(f" 行业板块 {industry_name} 无成分股数据")
                        failed_count += 1
                        continue

                    result_df = self.transform_data(member_df, industry_name)
                    if result_df.empty:
                        print(f" 行业板块 {industry_name} 数据转换失败")
                        failed_count += 1
                        continue

                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f" 成功采集 {len(result_df)} 只成分股")

                except Exception as exc:
                    # One bad board must not abort the whole run.
                    print(f" 采集 {industry_name} ({industry_code}) 失败: {exc}")
                    failed_count += 1
                    continue

            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print(f"总行业板块数: {total_industries}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print("=" * 60)

        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def collect_ths_industry_member(
    db_url: str,
    tushare_token: str,
    table_name: str = "ths_industry_member",
) -> None:
    """Collection entry point: run one full-overwrite collection.

    Args:
        db_url: SQLAlchemy database connection URL.
        tushare_token: Tushare API token.
        table_name: Target table name; defaults to the collector's own
            default, ``ths_industry_member``.
    """
    collector = THSIndustryMemberCollector(db_url, tushare_token, table_name=table_name)
    collector.run_full_collection()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # The connection URL can be supplied through the THS_DB_URL environment
    # variable; the literal below is only the fallback.
    # NOTE(review): the fallback embeds database credentials in source control.
    # Move them to configuration and rotate the password.
    DB_URL = os.environ.get(
        "THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db",
    )
    TOKEN = TUSHARE_TOKEN

    # Run the full-overwrite collection.
    collect_ths_industry_member(DB_URL, TOKEN)
|
|||
|
|
|