stock_fundamentals/src/tushare_scripts/moneyflow_ths_collector.py

344 lines
12 KiB
Python
Raw Normal View History

2025-11-18 16:50:31 +08:00
# coding:utf-8
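"""
Collector for per-stock money-flow data from THS (同花顺).

Purpose: fetch ``moneyflow_ths`` data from Tushare and persist it to the database.
API docs: https://tushare.pro/document/2?doc_id=348
Note: the endpoint is refreshed daily after market close and requires at least
5000 Tushare points.
"""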
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# Add the project root to sys.path so the configuration module can be imported.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class MoneyflowTHSCollector:
    """Collect per-stock THS money-flow data from Tushare and store it in a database.

    Rows come from the Tushare ``moneyflow_ths`` endpoint, are reshaped by
    :meth:`transform_data` and appended to ``table_name``. The ``run_*``
    methods implement the supported collection modes (single day, date
    range, single stock, full rebuild).
    """

    # Columns copied unchanged from the Tushare frame into the database frame.
    # A column missing from the API response becomes an all-NULL column.
    _PASSTHROUGH_COLUMNS = (
        "name",
        "pct_change",
        "latest",
        "net_amount",
        "net_d5_amount",
        "buy_lg_amount",
        "buy_lg_amount_rate",
        "buy_md_amount",
        "buy_md_amount_rate",
        "buy_sm_amount",
        "buy_sm_amount_rate",
    )

    # Date layouts accepted from callers of the ``run_*`` methods.
    _DATE_FORMATS = ("%Y-%m-%d", "%Y%m%d")

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "gp_moneyflow_ths"):
        """Create the database engine and the Tushare client.

        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, ``gp_moneyflow_ths`` by default.
                It is interpolated into SQL text, so it must come from
                trusted configuration, never from user input.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("个股资金流向THS采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def _parse_date(value: str) -> datetime:
        """Parse a ``YYYY-MM-DD`` or ``YYYYMMDD`` string into a ``datetime``.

        Raises:
            ValueError: If ``value`` matches neither layout.
        """
        for fmt in MoneyflowTHSCollector._DATE_FORMATS:
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        raise ValueError(f"无法解析日期 {value!r},需要 YYYY-MM-DD 或 YYYYMMDD 格式")

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        Values that are not strings or contain no dot are returned unchanged.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        # Split on the last dot only so an unexpected extra dot cannot raise.
        base, market = ts_code.rsplit(".", 1)
        return f"{market}{base}"

    @staticmethod
    def convert_db_code_to_tushare(symbol: str) -> str:
        """Convert a database code (``SH600000``) to the Tushare form (``600000.SH``).

        Only the ``SH``/``SZ``/``BJ`` prefixes are recognised; anything else
        (including empty values) is returned unchanged.
        """
        if not symbol:
            return symbol
        prefix = symbol[:2]
        code = symbol[2:]
        if prefix in {"SH", "SZ", "BJ"}:
            return f"{code}.{prefix}"
        return symbol

    def fetch_data(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Call the Tushare ``moneyflow_ths`` endpoint.

        Args:
            ts_code: Tushare stock code, e.g. ``600000.SH``.
            trade_date: Single trading day, forwarded as given
                (``run_daily_collection`` passes ``YYYYMMDD``).
            start_date: Range start, forwarded as given.
            end_date: Range end, forwarded as given.

        Returns:
            The frame returned by Tushare, or an empty frame when the call
            raised or returned ``None``. Failures are printed, not raised.
        """
        try:
            df = self.pro.moneyflow_ths(
                ts_code=ts_code,
                trade_date=trade_date,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as exc:
            # Deliberate best-effort: callers treat an empty frame as "no data".
            print(f"Tushare moneyflow_ths 调用失败: {exc}")
            return pd.DataFrame()
        return df if df is not None else pd.DataFrame()

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape a Tushare frame into the database layout.

        Adds ``symbol`` (database-style code), parses ``trade_date`` from
        ``YYYYMMDD`` and stamps ``created_at``/``updated_at``.

        Returns:
            The reshaped frame, or an empty frame when ``df`` is empty.
        """
        if df.empty:
            return pd.DataFrame()
        # One timestamp for both audit columns so they are exactly equal.
        now = datetime.now()
        result = pd.DataFrame()
        result["symbol"] = df["ts_code"].apply(self.convert_tushare_code_to_db)
        result["ts_code"] = df["ts_code"]
        result["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        for column in self._PASSTHROUGH_COLUMNS:
            result[column] = df.get(column)
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append ``df`` to the target table; do nothing for an empty frame."""
        if df.empty:
            print("无数据需要保存")
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
        print(f"成功写入 {len(df)} 条记录")

    def delete_by_date(self, date_str: str) -> None:
        """Delete every row of the target table whose ``trade_date`` equals ``date_str``."""
        with self.engine.begin() as conn:
            # Table names cannot be bound parameters; the date value is bound.
            delete_sql = text(f"DELETE FROM {self.table_name} WHERE trade_date = :trade_date")
            affected = conn.execute(delete_sql, {"trade_date": date_str})
        print(f"已删除 {date_str} 旧数据 {affected.rowcount}")

    def run_daily_collection(self, date: Optional[str] = None) -> None:
        """Collect whole-market data for one trading day (today by default).

        Args:
            date: ``YYYY-MM-DD`` or ``YYYYMMDD``; ``None`` means today.

        Raises:
            ValueError: If ``date`` cannot be parsed.
        """
        parsed = self._parse_date(date) if date else datetime.now()
        target_date = parsed.strftime("%Y-%m-%d")
        trade_date = parsed.strftime("%Y%m%d")
        print(f"开始采集 {target_date} ({trade_date}) 个股资金流向数据")
        df = self.fetch_data(trade_date=trade_date)
        if df.empty:
            print("接口未返回数据(可能为非交易日或权限不足)")
            return
        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return
        # Remove the old rows only once replacement data is in hand, so an
        # empty or failed API response never wipes data already stored.
        self.delete_by_date(target_date)
        self.save_dataframe(result_df)
        print("每日采集完成")

    def run_full_collection(self) -> None:
        """Rebuild the whole table from scratch.

        Reads every stock code from ``gp_code_all_copy``, truncates the target
        table, then fetches and stores the full history stock by stock. The
        engine's connection pool is disposed when the run ends.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集moneyflow_ths")
        print("=" * 60)
        try:
            # Read the stock list before truncating: if this query fails or
            # returns nothing, the existing data is left untouched.
            codes_df = pd.read_sql("SELECT gp_code FROM gp_code_all_copy", self.engine)
            # NULL codes would turn into an unfiltered API call and duplicate
            # codes would insert the same rows twice.
            codes = codes_df["gp_code"].dropna().drop_duplicates().tolist()
            print(f"共获取到 {len(codes)} 只股票")
            if not codes:
                print("股票代码列表为空,已跳过清空与采集")
                return
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")
            total_records = 0
            success_count = 0
            failed_count = 0
            for symbol in codes:
                try:
                    # Converted inside the try so one malformed code counts
                    # as a single failure instead of aborting the whole run.
                    ts_code = self.convert_db_code_to_tushare(symbol)
                    df = self.fetch_data(ts_code=ts_code)
                    if df.empty:
                        failed_count += 1
                        continue
                    result_df = self.transform_data(df)
                    if result_df.empty:
                        failed_count += 1
                        continue
                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                except Exception as exc:
                    print(f"\n采集 {symbol} 失败: {exc}")
                    failed_count += 1
                    continue
            print("=" * 60)
            print("全量覆盖采集完成")
            print(f"总股票数: {len(codes)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
        finally:
            self.engine.dispose()

    def run_range_collection(self, start_date: str, end_date: str) -> None:
        """Collect whole-market data for every calendar day in a closed range.

        Args:
            start_date: First day, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Last day (inclusive), same layouts.

        Raises:
            ValueError: If a date cannot be parsed or the range is reversed.
        """
        start = self._parse_date(start_date)
        end = self._parse_date(end_date)
        if start > end:
            raise ValueError("start_date 不能大于 end_date")
        current = start
        while current <= end:
            target = current.strftime("%Y-%m-%d")
            try:
                self.run_daily_collection(target)
            except Exception as exc:
                # One failing day must not stop the remaining days.
                print(f"{target} 数据采集失败: {exc}")
            current += pd.Timedelta(days=1)

    def run_stock_collection(
        self,
        symbol: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """Collect one stock's data for an optional date range.

        Args:
            symbol: Database-style code (``SH600000``) or Tushare ``ts_code``
                (``600000.SH``).
            start_date: Range start, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Range end, same layouts.

        Raises:
            ValueError: If a supplied date cannot be parsed.
        """
        ts_code = symbol if "." in symbol else self.convert_db_code_to_tushare(symbol)
        start_dt = self._parse_date(start_date) if start_date else None
        end_dt = self._parse_date(end_date) if end_date else None
        # Send the API the same YYYYMMDD layout run_daily_collection uses.
        df = self.fetch_data(
            ts_code=ts_code,
            start_date=start_dt.strftime("%Y%m%d") if start_dt is not None else None,
            end_date=end_dt.strftime("%Y%m%d") if end_dt is not None else None,
        )
        if df.empty:
            print("接口无返回数据")
            return
        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return
        # Always clear the range about to be written, falling back to the
        # fetched span when no bounds were given, so re-runs never duplicate.
        lower = start_dt if start_dt is not None else result_df["trade_date"].min()
        upper = end_dt if end_dt is not None else result_df["trade_date"].max()
        start = lower.strftime("%Y-%m-%d")
        end = upper.strftime("%Y-%m-%d")
        with self.engine.begin() as conn:
            delete_sql = text(
                f"""
                DELETE FROM {self.table_name}
                WHERE symbol = :symbol AND trade_date BETWEEN :start AND :end
                """
            )
            affected = conn.execute(
                delete_sql,
                {
                    "symbol": self.convert_tushare_code_to_db(ts_code),
                    "start": start,
                    "end": end,
                },
            )
        print(f"已删除 {symbol} {start}~{end} 旧数据 {affected.rowcount}")
        self.save_dataframe(result_df)
def collect_moneyflow_ths(
    db_url: str,
    tushare_token: str,
    mode: str = "daily",
    date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    symbol: Optional[str] = None,
):
    """Entry point: build a collector and run the requested collection mode.

    Args:
        db_url: Database connection URL handed to the collector.
        tushare_token: Tushare API token handed to the collector.
        mode: One of
            - ``"daily"``: whole-market data for ``date`` (today when omitted)
            - ``"range"``: whole-market data for each day from ``start_date``
              to ``end_date``; both are required
            - ``"stock"``: data for ``symbol``; ``start_date``/``end_date``
              are forwarded unchanged
            - ``"full"``: full rebuild via ``run_full_collection``
        date: Day used by ``"daily"`` mode.
        start_date: Range start for ``"range"`` and ``"stock"`` modes.
        end_date: Range end for ``"range"`` and ``"stock"`` modes.
        symbol: Stock code required by ``"stock"`` mode.

    Raises:
        ValueError: If ``mode`` is unknown or a required argument is missing.
    """
    # The collector is created before any argument check, as it always has been.
    worker = MoneyflowTHSCollector(db_url, tushare_token)

    if mode == "daily":
        worker.run_daily_collection(date)
        return

    if mode == "full":
        worker.run_full_collection()
        return

    if mode == "range":
        if not (start_date and end_date):
            raise ValueError("range 模式需要提供 start_date 和 end_date格式 YYYY-MM-DD")
        worker.run_range_collection(start_date, end_date)
        return

    if mode == "stock":
        if not symbol:
            raise ValueError("stock 模式需要提供 symbol 参数")
        worker.run_stock_collection(symbol, start_date, end_date)
        return

    raise ValueError(f"未知的采集模式: {mode}")
if __name__ == "__main__":
    # Prefer a connection URL supplied through the environment so database
    # credentials do not have to live in the source tree.
    # NOTE(review): the literal fallback embeds a real password and is kept
    # only so existing invocations keep working; rotate it and remove the
    # fallback once MONEYFLOW_THS_DB_URL is set wherever this script runs.
    DB_URL = os.environ.get(
        "MONEYFLOW_THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8pattern@192.168.16.150:3306/factordb_mysql",
    )
    TOKEN = TUSHARE_TOKEN
    # Usage examples:
    # 1. Collect today's whole-market data
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="daily")
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="full")
    # 2. Collect whole-market data for a specific day
    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-17")
    #
    # 3. Collect whole-market data for a date range (day by day)
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="range", start_date="2025-11-01", end_date="2025-11-07")
    #
    # 4. Collect a single stock's data for a date range
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="stock", symbol="600519.SH", start_date="2025-11-01", end_date="2025-11-07")
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-07")