# Source path: stock_fundamentals/src/tushare_scripts/moneyflow_ths_collector.py
#
# NOTE: the repository viewer flagged this file as containing Unicode
# characters that might be confused with other characters. If that is
# intentional, the warning can be safely ignored.

# coding:utf-8
"""
个股资金流向(同花顺)采集工具
功能:从 Tushare 获取 moneyflow_ths 数据并落库
API 文档: https://tushare.pro/document/2?doc_id=348
说明:接口每日盘后更新,需至少 5000 积分
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class MoneyflowTHSCollector:
    """Collector that loads Tushare ``moneyflow_ths`` rows into a database table.

    The collector supports four workflows: one trading day for the whole
    market, a day-by-day date range, a single stock over a date range, and a
    full rebuild that truncates the table and reloads every known stock.

    Dates supplied by callers may be written either as ``YYYY-MM-DD`` or as
    ``YYYYMMDD``; they are converted to ``YYYYMMDD`` before being sent to the
    Tushare API and to ``YYYY-MM-DD`` before being used in SQL.
    """

    # Columns copied from the API frame into the output frame without change.
    _VALUE_COLUMNS = (
        "name",
        "pct_change",
        "latest",
        "net_amount",
        "net_d5_amount",
        "buy_lg_amount",
        "buy_lg_amount_rate",
        "buy_md_amount",
        "buy_md_amount_rate",
        "buy_sm_amount",
        "buy_sm_amount_rate",
    )

    # Exchange prefixes recognised in database-style symbols such as SH600000.
    _DB_PREFIXES = frozenset({"SH", "SZ", "BJ"})

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "gp_moneyflow_ths"):
        """Create the database engine and the Tushare client.

        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name; defaults to ``gp_moneyflow_ths``.
                The name is interpolated into SQL text, so it must come from
                trusted configuration, never from user input.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("个股资金流向THS采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def _parse_date(value: str) -> datetime:
        """Parse a ``YYYY-MM-DD`` or ``YYYYMMDD`` string into a ``datetime``.

        Raises:
            ValueError: If ``value`` matches neither format.
        """
        for fmt in ("%Y-%m-%d", "%Y%m%d"):
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        raise ValueError(f"无法解析日期 {value!r},应为 YYYY-MM-DD 或 YYYYMMDD")

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        Values that are empty or contain no dot are returned unchanged.
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        # Split on the LAST dot only, so an unexpected extra dot cannot raise
        # and abort a whole batch when this is applied column-wise.
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    @staticmethod
    def convert_db_code_to_tushare(symbol: str) -> str:
        """Convert a database code (``SH600000``) to the Tushare form (``600000.SH``).

        Values that are empty or do not start with a known exchange prefix
        (``SH``, ``SZ``, ``BJ``) are returned unchanged.
        """
        if not symbol:
            return symbol
        prefix, code = symbol[:2], symbol[2:]
        if prefix in MoneyflowTHSCollector._DB_PREFIXES:
            return f"{code}.{prefix}"
        return symbol

    def fetch_data(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Call the Tushare ``moneyflow_ths`` endpoint.

        Args:
            ts_code: Tushare stock code, e.g. ``600000.SH``.
            trade_date: Single trading day in ``YYYYMMDD`` form.
            start_date: Range start in ``YYYYMMDD`` form.
            end_date: Range end in ``YYYYMMDD`` form.

        Returns:
            The frame returned by the API. Any failure is reported on stdout
            and an empty frame is returned instead (best-effort semantics).
        """
        try:
            df = self.pro.moneyflow_ths(
                ts_code=ts_code,
                trade_date=trade_date,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as exc:
            print(f"Tushare moneyflow_ths 调用失败: {exc}")
            return pd.DataFrame()
        # Callers test ``.empty``; normalise a missing result to an empty frame.
        return pd.DataFrame() if df is None else df

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape an API frame into the layout written to the database.

        Adds the database-style ``symbol``, parses ``trade_date`` into
        datetimes and stamps ``created_at`` / ``updated_at``. Columns that the
        API frame lacks are filled with ``None``.

        Returns:
            The reshaped frame, or an empty frame when the input is empty.
        """
        if df is None or df.empty:
            return pd.DataFrame()

        # One timestamp for both audit columns so they agree on every row.
        stamp = datetime.now()

        result = pd.DataFrame()
        result["symbol"] = df["ts_code"].apply(self.convert_tushare_code_to_db)
        result["ts_code"] = df["ts_code"]
        result["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        for column in self._VALUE_COLUMNS:
            result[column] = df.get(column)
        result["created_at"] = stamp
        result["updated_at"] = stamp
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append the frame to the target table; do nothing for an empty frame."""
        if df.empty:
            print("无数据需要保存")
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
        print(f"成功写入 {len(df)} 条记录")

    def delete_by_date(self, date_str: str) -> None:
        """Delete every row of the target table whose ``trade_date`` equals ``date_str``."""
        with self.engine.begin() as conn:
            delete_sql = text(f"DELETE FROM {self.table_name} WHERE trade_date = :trade_date")
            affected = conn.execute(delete_sql, {"trade_date": date_str})
            print(f"已删除 {date_str} 旧数据 {affected.rowcount}")

    def run_daily_collection(self, date: Optional[str] = None) -> None:
        """Collect whole-market data for one trading day (today by default).

        Existing rows for the day are replaced, but only after fresh data has
        been fetched and transformed successfully, so a failed API call never
        wipes data that is already stored.

        Args:
            date: Day to collect, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If ``date`` cannot be parsed.
        """
        day = self._parse_date(date) if date else datetime.now()
        target_date = day.strftime("%Y-%m-%d")
        trade_date = day.strftime("%Y%m%d")
        print(f"开始采集 {target_date} ({trade_date}) 个股资金流向数据")

        df = self.fetch_data(trade_date=trade_date)
        if df.empty:
            print("接口未返回数据(可能为非交易日或权限不足)")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        # Replacement data is in hand: now it is safe to drop the old rows.
        self.delete_by_date(target_date)
        self.save_dataframe(result_df)
        print("每日采集完成")

    def run_full_collection(self) -> None:
        """Rebuild the whole table from scratch.

        Reads every stock code from ``gp_code_all_copy``, truncates the target
        table, then pulls and stores the history of each stock in turn. The
        code list is read BEFORE truncating so that a failed or empty lookup
        leaves the existing data untouched. The engine is disposed when the
        run ends, whether it succeeded or not.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集moneyflow_ths")
        print("=" * 60)
        try:
            # Load the stock universe first; abort before any destructive step.
            codes_df = pd.read_sql("SELECT gp_code FROM gp_code_all_copy", self.engine)
            codes = codes_df["gp_code"].tolist()
            if not codes:
                print("未获取到任何股票代码,已取消全量采集")
                return

            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")
            print(f"共获取到 {len(codes)} 只股票")

            total_records = 0
            success_count = 0
            failed_count = 0
            for symbol in codes:
                ts_code = self.convert_db_code_to_tushare(symbol)
                try:
                    df = self.fetch_data(ts_code=ts_code)
                    if df.empty:
                        failed_count += 1
                        continue
                    result_df = self.transform_data(df)
                    if result_df.empty:
                        failed_count += 1
                        continue
                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                except Exception as exc:
                    # One bad stock must not stop the rebuild of the others.
                    print(f"\n采集 {symbol} 失败: {exc}")
                    failed_count += 1
                    continue

            print("=" * 60)
            print("全量覆盖采集完成")
            print(f"总股票数: {len(codes)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
        finally:
            self.engine.dispose()

    def run_range_collection(self, start_date: str, end_date: str) -> None:
        """Collect whole-market data for every calendar day in a closed range.

        Each day is processed through ``run_daily_collection``; a failure on
        one day is reported and the loop moves on to the next day.

        Args:
            start_date: First day, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Last day (inclusive), ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If a date cannot be parsed or the start is after the end.
        """
        start = self._parse_date(start_date)
        end = self._parse_date(end_date)
        if start > end:
            raise ValueError("start_date 不能大于 end_date")

        for day in pd.date_range(start, end, freq="D"):
            target = day.strftime("%Y-%m-%d")
            try:
                self.run_daily_collection(target)
            except Exception as exc:
                print(f"{target} 数据采集失败: {exc}")

    def run_stock_collection(
        self,
        symbol: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """Collect one stock's data for a date range and replace stored rows.

        Rows already stored for the stock inside the affected range are
        deleted before the new rows are written. When a bound is omitted, the
        earliest / latest ``trade_date`` of the fetched data is used instead,
        so repeated runs never create duplicates.

        Args:
            symbol: Database form (``SH600000``, ``SZ000001`` ...) or a
                Tushare ``ts_code`` (``600000.SH``).
            start_date: Optional range start, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Optional range end, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If a supplied date cannot be parsed.
        """
        ts_code = symbol if "." in symbol else self.convert_db_code_to_tushare(symbol)
        start_dt = self._parse_date(start_date) if start_date else None
        end_dt = self._parse_date(end_date) if end_date else None

        # The API expects YYYYMMDD, whatever format the caller used.
        df = self.fetch_data(
            ts_code=ts_code,
            start_date=start_dt.strftime("%Y%m%d") if start_dt else None,
            end_date=end_dt.strftime("%Y%m%d") if end_dt else None,
        )
        if df.empty:
            print("接口无返回数据")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        # Drop rows in the covered range first to avoid duplicates.
        start = (start_dt or result_df["trade_date"].min()).strftime("%Y-%m-%d")
        end = (end_dt or result_df["trade_date"].max()).strftime("%Y-%m-%d")
        with self.engine.begin() as conn:
            delete_sql = text(
                f"""
                DELETE FROM {self.table_name}
                WHERE symbol = :symbol AND trade_date BETWEEN :start AND :end
                """
            )
            affected = conn.execute(
                delete_sql,
                {
                    "symbol": self.convert_tushare_code_to_db(ts_code),
                    "start": start,
                    "end": end,
                },
            )
            print(f"已删除 {symbol} {start}~{end} 旧数据 {affected.rowcount}")

        self.save_dataframe(result_df)
def collect_moneyflow_ths(
    db_url: str,
    tushare_token: str,
    mode: str = "daily",
    date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    symbol: Optional[str] = None,
):
    """Entry point: build a collector and run the workflow selected by ``mode``.

    Modes:
        - ``daily``: whole-market data for ``date`` (today when omitted).
        - ``range``: whole-market data for each day from ``start_date`` to
          ``end_date``; both are required.
        - ``stock``: data for one ``symbol`` (required) over the optional
          ``start_date`` / ``end_date`` range.
        - ``full``: full rebuild of the target table.

    Raises:
        ValueError: If ``mode`` is unknown, or a required argument for the
            chosen mode is missing.
    """
    # The collector is created before the mode is checked.
    job = MoneyflowTHSCollector(db_url, tushare_token)

    if mode == "daily":
        job.run_daily_collection(date)
        return

    if mode == "range":
        if not (start_date and end_date):
            raise ValueError("range 模式需要提供 start_date 和 end_date格式 YYYY-MM-DD")
        job.run_range_collection(start_date, end_date)
        return

    if mode == "stock":
        if not symbol:
            raise ValueError("stock 模式需要提供 symbol 参数")
        job.run_stock_collection(symbol, start_date, end_date)
        return

    if mode == "full":
        job.run_full_collection()
        return

    raise ValueError(f"未知的采集模式: {mode}")
if __name__ == "__main__":
    # NOTE(security): the fallback connection string embeds database
    # credentials in source control. Set the MONEYFLOW_THS_DB_URL environment
    # variable to supply the URL (including alternative databases) instead;
    # the literal is kept only so existing invocations keep working and
    # should be removed once every environment sets the variable.
    DB_URL = os.environ.get(
        "MONEYFLOW_THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8pattern@192.168.16.150:3306/factordb_mysql",
    )
    TOKEN = TUSHARE_TOKEN

    # Usage examples:
    #
    # 1. Whole-market data for today, or a full rebuild of the table:
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily")
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="full")
    #
    # 2. Whole-market data for a specific day (the call that runs below):
    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-17")
    #
    # 3. Whole-market data for a date range, one day at a time:
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="range", start_date="2025-11-01", end_date="2025-11-07")
    #
    # 4. One stock over a date range:
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="stock", symbol="600519.SH", start_date="2025-11-01", end_date="2025-11-07")