# coding:utf-8
"""
Per-stock money-flow (THS / Tonghuashun) collection tool.

Purpose: fetch ``moneyflow_ths`` data from Tushare and store it in the database.
API docs: https://tushare.pro/document/2?doc_id=348
Note: the endpoint is refreshed daily after market close and requires at
least 5000 Tushare points.
"""

import os
|
||
import sys
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
import pandas as pd
|
||
import tushare as ts
|
||
from sqlalchemy import create_engine, text
|
||
|
||
# 添加项目根目录到路径,确保能够读取配置
|
||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
sys.path.append(PROJECT_ROOT)
|
||
|
||
from src.scripts.config import TUSHARE_TOKEN
|
||
|
||
|
||
class MoneyflowTHSCollector:
    """Collect per-stock money-flow rows from Tushare ``moneyflow_ths`` into a SQL table.

    Stock codes appear in two spellings: the Tushare form (``600000.SH``) that is
    sent to the API, and the database form (``SH600000``) stored in the ``symbol``
    column. Dates are sent to Tushare as ``YYYYMMDD`` and used in SQL filters as
    ``YYYY-MM-DD``; user-supplied dates may use either spelling.
    """

    # Optional Tushare columns copied verbatim into the output frame, in output order.
    _VALUE_COLUMNS = (
        "name",
        "pct_change",
        "latest",
        "net_amount",
        "net_d5_amount",
        "buy_lg_amount",
        "buy_lg_amount_rate",
        "buy_md_amount",
        "buy_md_amount_rate",
        "buy_sm_amount",
        "buy_sm_amount_rate",
    )

    # Accepted spellings for user-supplied dates.
    _DATE_FORMATS = ("%Y-%m-%d", "%Y%m%d")

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "gp_moneyflow_ths"):
        """Create the DB engine and the Tushare client, then print a banner.

        Args:
            db_url: SQLAlchemy database URL.
            tushare_token: Tushare API token.
            table_name: Target table name, ``gp_moneyflow_ths`` by default.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("个股资金流向(THS)采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        Values that are empty or contain no dot are returned unchanged.
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        # Split on the LAST dot so a code with several dots cannot raise on unpacking.
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    @staticmethod
    def convert_db_code_to_tushare(symbol: str) -> str:
        """Convert a database code (``SH600000``) to the Tushare form (``600000.SH``).

        Only the prefixes ``SH``, ``SZ`` and ``BJ`` are recognised; anything else
        (including empty values) is returned unchanged.
        """
        if not symbol:
            return symbol
        prefix, code = symbol[:2], symbol[2:]
        if prefix in {"SH", "SZ", "BJ"}:
            return f"{code}.{prefix}"
        return symbol

    @classmethod
    def _parse_date(cls, value: Optional[str]) -> Optional[datetime]:
        """Parse ``YYYY-MM-DD`` or ``YYYYMMDD``; return ``None`` for an empty value.

        Raises:
            ValueError: If the value matches neither accepted format.
        """
        if not value:
            return None
        for fmt in cls._DATE_FORMATS:
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        raise ValueError(f"无法解析日期: {value}(需要 YYYY-MM-DD 或 YYYYMMDD)")

    def fetch_data(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Call the Tushare ``moneyflow_ths`` endpoint.

        Arguments are forwarded unchanged to the API (dates in ``YYYYMMDD``).
        This is deliberately best-effort: any API error is printed and an empty
        DataFrame is returned, so callers must treat "empty" as "nothing usable".
        """
        try:
            df = self.pro.moneyflow_ths(
                ts_code=ts_code,
                trade_date=trade_date,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as exc:
            print(f"Tushare moneyflow_ths 调用失败: {exc}")
            return pd.DataFrame()
        # Guard against a client that hands back None instead of an empty frame.
        return df if df is not None else pd.DataFrame()

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape a Tushare result into the layout written to the database.

        Adds ``symbol`` (database-style code), parses ``trade_date`` from
        ``YYYYMMDD``, copies the optional value columns (a column missing from
        the input becomes all-null) and stamps ``created_at`` / ``updated_at``.
        Returns an empty DataFrame for empty input.
        """
        if df is None or df.empty:
            return pd.DataFrame()

        # One timestamp for both audit columns so they are equal on insert.
        now = datetime.now()

        result = pd.DataFrame()
        result["symbol"] = df["ts_code"].apply(self.convert_tushare_code_to_db)
        result["ts_code"] = df["ts_code"]
        result["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        for column in self._VALUE_COLUMNS:
            result[column] = df.get(column)
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append the frame to the target table; print a notice and return if it is empty."""
        if df.empty:
            print("无数据需要保存")
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
        print(f"成功写入 {len(df)} 条记录")

    def delete_by_date(self, date_str: str) -> None:
        """Delete every row of the target table whose ``trade_date`` equals ``date_str``."""
        delete_sql = text(f"DELETE FROM {self.table_name} WHERE trade_date = :trade_date")
        with self.engine.begin() as conn:
            affected = conn.execute(delete_sql, {"trade_date": date_str})
            print(f"已删除 {date_str} 旧数据 {affected.rowcount} 条")

    def _replace_rows(self, df: pd.DataFrame, where_clause: str, params: dict) -> int:
        """Delete the rows matching ``where_clause`` and append ``df`` in ONE transaction.

        If the insert fails the delete is rolled back, so existing data is never
        lost to a half-finished refresh. ``where_clause`` must be a constant
        written in this class (values go through bound ``params``).

        Returns:
            The number of rows deleted.
        """
        delete_sql = text(f"DELETE FROM {self.table_name} WHERE {where_clause}")
        with self.engine.begin() as conn:
            deleted = conn.execute(delete_sql, params).rowcount
            df.to_sql(self.table_name, conn, if_exists="append", index=False)
        return deleted

    def run_daily_collection(self, date: Optional[str] = None) -> None:
        """Refresh the whole-market rows of one trading day (today by default).

        Data is fetched and transformed BEFORE anything is deleted: because
        ``fetch_data`` reports API failures as an empty frame, deleting first
        would wipe the day's rows whenever the API call failed or the day was
        not a trading day.

        Args:
            date: ``YYYY-MM-DD`` (or ``YYYYMMDD``); defaults to the current date.

        Raises:
            ValueError: If ``date`` cannot be parsed.
        """
        day = self._parse_date(date) or datetime.now()
        target_date = day.strftime("%Y-%m-%d")
        trade_date = day.strftime("%Y%m%d")

        print(f"开始采集 {target_date} ({trade_date}) 个股资金流向数据")

        df = self.fetch_data(trade_date=trade_date)
        if df.empty:
            print("接口未返回数据(可能为非交易日或权限不足)")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        deleted = self._replace_rows(
            result_df,
            "trade_date = :trade_date",
            {"trade_date": target_date},
        )
        print(f"已删除 {target_date} 旧数据 {deleted} 条")
        print(f"成功写入 {len(result_df)} 条记录")
        print("每日采集完成")

    def run_full_collection(self) -> None:
        """Rebuild the table from scratch, one stock at a time.

        Reads every ``gp_code`` from ``gp_code_all_copy``, truncates the target
        table, then fetches and appends the history of each stock. The code list
        is loaded before truncating so that a failed or empty code query cannot
        leave the table empty. The engine is disposed when the run ends.

        NOTE(review): calls are issued back-to-back with no throttling — confirm
        this stays within the Tushare rate limit of the account in use.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(moneyflow_ths)")
        print("=" * 60)

        try:
            codes_df = pd.read_sql("SELECT gp_code FROM gp_code_all_copy", self.engine)
            codes = codes_df["gp_code"].tolist()
            print(f"共获取到 {len(codes)} 只股票")
            if not codes:
                print("全量采集失败: 股票代码列表为空,未清空数据表")
                return

            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")

            total_records = 0
            success_count = 0
            failed_count = 0

            for symbol in codes:
                # A null/empty code would turn into an unfiltered API request.
                if not symbol:
                    failed_count += 1
                    continue

                ts_code = self.convert_db_code_to_tushare(symbol)
                try:
                    result_df = self.transform_data(self.fetch_data(ts_code=ts_code))
                    if result_df.empty:
                        failed_count += 1
                        continue

                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                except Exception as exc:
                    # Keep going: one bad stock must not abort the whole rebuild.
                    print(f"\n采集 {symbol} 失败: {exc}")
                    failed_count += 1

            print("=" * 60)
            print("全量覆盖采集完成")
            print(f"总股票数: {len(codes)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
        finally:
            self.engine.dispose()

    def run_range_collection(self, start_date: str, end_date: str) -> None:
        """Run the daily collection for every calendar day in ``[start_date, end_date]``.

        A failure on one day is printed and the loop continues with the next day.

        Raises:
            ValueError: If a date is missing or unparsable, or start is after end.
        """
        start = self._parse_date(start_date)
        end = self._parse_date(end_date)
        if start is None or end is None:
            raise ValueError("range 模式需要提供 start_date 和 end_date,格式 YYYY-MM-DD")
        if start > end:
            raise ValueError("start_date 不能大于 end_date")

        for day in pd.date_range(start, end, freq="D"):
            target = day.strftime("%Y-%m-%d")
            try:
                self.run_daily_collection(target)
            except Exception as exc:
                print(f"{target} 数据采集失败: {exc}")

    def run_stock_collection(
        self,
        symbol: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """Refresh the rows of a single stock, optionally limited to a date range.

        Dates are normalised once: ``YYYYMMDD`` for the API call and
        ``YYYY-MM-DD`` for the SQL filter. Existing rows of the stock inside the
        covered range are replaced atomically, so re-running never duplicates
        data; a missing bound falls back to the min/max date actually fetched.

        Args:
            symbol: Database-style code (``SH600000``) or Tushare ``ts_code``.
            start_date: Optional lower bound, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Optional upper bound, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If a supplied date cannot be parsed.
        """
        ts_code = symbol if "." in symbol else self.convert_db_code_to_tushare(symbol)
        start_dt = self._parse_date(start_date)
        end_dt = self._parse_date(end_date)

        df = self.fetch_data(
            ts_code=ts_code,
            start_date=start_dt.strftime("%Y%m%d") if start_dt else None,
            end_date=end_dt.strftime("%Y%m%d") if end_dt else None,
        )
        if df.empty:
            print("接口无返回数据")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        start = (start_dt or result_df["trade_date"].min()).strftime("%Y-%m-%d")
        end = (end_dt or result_df["trade_date"].max()).strftime("%Y-%m-%d")

        deleted = self._replace_rows(
            result_df,
            "symbol = :symbol AND trade_date BETWEEN :start AND :end",
            {
                "symbol": self.convert_tushare_code_to_db(ts_code),
                "start": start,
                "end": end,
            },
        )
        print(f"已删除 {symbol} {start}~{end} 旧数据 {deleted} 条")
        print(f"成功写入 {len(result_df)} 条记录")
def collect_moneyflow_ths(
    db_url: str,
    tushare_token: str,
    mode: str = "daily",
    date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    symbol: Optional[str] = None,
) -> None:
    """Entry point: build a collector and run it in the requested mode.

    Modes:
        - ``daily``: whole-market data for ``date`` (today when omitted)
        - ``range``: whole-market data for every day in ``start_date``..``end_date``
        - ``stock``: data of one ``symbol``, optionally limited by the date bounds
        - ``full``: clear the target table and reload every stock's history

    Arguments are validated before the collector is created, so a bad call
    fails fast without opening a DB engine or configuring the Tushare client.

    Raises:
        ValueError: Unknown ``mode``, ``range`` without both dates, or
            ``stock`` without ``symbol``.
    """
    if mode not in ("daily", "range", "stock", "full"):
        raise ValueError(f"未知的采集模式: {mode}")
    if mode == "range" and not (start_date and end_date):
        raise ValueError("range 模式需要提供 start_date 和 end_date,格式 YYYY-MM-DD")
    if mode == "stock" and not symbol:
        raise ValueError("stock 模式需要提供 symbol 参数")

    collector = MoneyflowTHSCollector(db_url, tushare_token)

    if mode == "daily":
        collector.run_daily_collection(date)
    elif mode == "range":
        collector.run_range_collection(start_date, end_date)
    elif mode == "stock":
        collector.run_stock_collection(symbol, start_date, end_date)
    else:  # mode == "full"
        collector.run_full_collection()
if __name__ == "__main__":
    # NOTE(review): the fallback URL embeds a database password in source control.
    # Set MONEYFLOW_THS_DB_URL in the environment instead, then rotate the
    # credential and remove the literal below.
    DB_URL = os.environ.get(
        "MONEYFLOW_THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8pattern@192.168.16.150:3306/factordb_mysql",
    )
    TOKEN = TUSHARE_TOKEN

    # Usage examples:
    #
    # 1. Whole market, today:
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="daily")
    #
    # 2. Whole market, one specific day:
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-07")
    #
    # 3. Whole market, every day in a date range:
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="range", start_date="2025-11-01", end_date="2025-11-07")
    #
    # 4. One stock over a date range:
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="stock", symbol="600519.SH", start_date="2025-11-01", end_date="2025-11-07")
    #
    # 5. Full rebuild (clears the target table first):
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="full")

    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-17")