344 lines
12 KiB
Python
344 lines
12 KiB
Python
|
|
# coding:utf-8
|
|||
|
|
"""
|
|||
|
|
个股资金流向(同花顺)采集工具
|
|||
|
|
功能:从 Tushare 获取 moneyflow_ths 数据并落库
|
|||
|
|
API 文档: https://tushare.pro/document/2?doc_id=348
|
|||
|
|
说明:接口每日盘后更新,需至少 5000 积分
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
import pandas as pd
|
|||
|
|
import tushare as ts
|
|||
|
|
from sqlalchemy import create_engine, text
|
|||
|
|
|
|||
|
|
# 添加项目根目录到路径,确保能够读取配置
|
|||
|
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
sys.path.append(PROJECT_ROOT)
|
|||
|
|
|
|||
|
|
from src.scripts.config import TUSHARE_TOKEN
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MoneyflowTHSCollector:
    """Collector that pulls Tushare ``moneyflow_ths`` rows and stores them in a SQL table.

    The target table name is interpolated directly into the DELETE/TRUNCATE
    statements issued by this class, so it must come from trusted
    configuration, never from user input.
    """

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "gp_moneyflow_ths"):
        """Create the database engine and the Tushare client.

        Args:
            db_url: SQLAlchemy database URL.
            tushare_token: Tushare API token.
            table_name: Target table name, ``gp_moneyflow_ths`` by default.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("个股资金流向(THS)采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def _parse_date(value: str) -> datetime:
        """Parse a date given as ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If ``value`` matches neither format.
        """
        for fmt in ("%Y-%m-%d", "%Y%m%d"):
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        raise ValueError(f"无法解析日期: {value!r},应为 YYYY-MM-DD 或 YYYYMMDD")

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        Values that are empty, not strings, or contain no dot are returned
        unchanged.
        """
        if not isinstance(ts_code, str) or "." not in ts_code:
            return ts_code
        # Split on the LAST dot so an unexpected extra dot cannot raise.
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    @staticmethod
    def convert_db_code_to_tushare(symbol: str) -> str:
        """Convert a database code (``SH600000``) to the Tushare form (``600000.SH``).

        The two-letter market prefix is matched case-insensitively. Values
        that are empty, not strings, or lack a recognised prefix followed by
        a code are returned unchanged.
        """
        if not isinstance(symbol, str) or not symbol:
            return symbol
        prefix = symbol[:2].upper()
        code = symbol[2:]
        if code and prefix in {"SH", "SZ", "BJ"}:
            return f"{code}.{prefix}"
        return symbol

    def fetch_data(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Call the Tushare ``moneyflow_ths`` endpoint.

        The arguments are forwarded unchanged to the endpoint.

        Returns:
            The frame returned by Tushare, or an empty frame when the call
            raises or returns ``None``. Failures are printed, not raised.
        """
        try:
            df = self.pro.moneyflow_ths(
                ts_code=ts_code,
                trade_date=trade_date,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as exc:
            # Best effort by design: callers treat an empty frame as "nothing to store".
            print(f"Tushare moneyflow_ths 调用失败: {exc}")
            return pd.DataFrame()
        return pd.DataFrame() if df is None else df

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape a Tushare result frame into the column layout written to the table.

        ``ts_code`` and ``trade_date`` are required columns; every other
        column is copied when present and left as ``None`` when absent.
        """
        if df.empty:
            return pd.DataFrame()

        passthrough_columns = (
            "name",
            "pct_change",
            "latest",
            "net_amount",
            "net_d5_amount",
            "buy_lg_amount",
            "buy_lg_amount_rate",
            "buy_md_amount",
            "buy_md_amount_rate",
            "buy_sm_amount",
            "buy_sm_amount_rate",
        )

        result = pd.DataFrame()
        result["symbol"] = df["ts_code"].apply(self.convert_tushare_code_to_db)
        result["ts_code"] = df["ts_code"]
        result["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        for column in passthrough_columns:
            result[column] = df.get(column)

        # One timestamp for both audit columns so they are equal on insert.
        now = datetime.now()
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append ``df`` to the target table; an empty frame is skipped."""
        if df.empty:
            print("无数据需要保存")
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
        print(f"成功写入 {len(df)} 条记录")

    def delete_by_date(self, date_str: str) -> None:
        """Delete every row of the target table whose ``trade_date`` equals ``date_str``."""
        delete_sql = text(f"DELETE FROM {self.table_name} WHERE trade_date = :trade_date")
        with self.engine.begin() as conn:
            affected = conn.execute(delete_sql, {"trade_date": date_str})
        print(f"已删除 {date_str} 旧数据 {affected.rowcount} 条")

    def run_daily_collection(self, date: Optional[str] = None) -> None:
        """Collect whole-market rows for one trading day (today by default).

        Args:
            date: ``YYYY-MM-DD`` or ``YYYYMMDD``; ``None`` means today.

        Raises:
            ValueError: If ``date`` cannot be parsed.
        """
        day = self._parse_date(date) if date else datetime.now()
        target_date = day.strftime("%Y-%m-%d")
        trade_date = day.strftime("%Y%m%d")

        print(f"开始采集 {target_date} ({trade_date}) 个股资金流向数据")

        df = self.fetch_data(trade_date=trade_date)
        if df.empty:
            print("接口未返回数据(可能为非交易日或权限不足)")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        # Delete only once replacement rows are in hand: fetch_data swallows
        # API errors, so deleting first would wipe the day on a failed call.
        self.delete_by_date(target_date)
        self.save_dataframe(result_df)
        print("每日采集完成")

    def run_full_collection(self) -> None:
        """Rebuild the target table from scratch.

        Reads all stock codes from ``gp_code_all_copy``, truncates the target
        table, then fetches and stores the full history of each stock. Errors
        are printed rather than raised, and the engine is disposed at the end.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(moneyflow_ths)")
        print("=" * 60)

        try:
            # Read the code list BEFORE truncating so a failed or empty
            # lookup cannot leave the target table wiped.
            codes_df = pd.read_sql("SELECT gp_code FROM gp_code_all_copy", self.engine)
            codes = codes_df["gp_code"].tolist()
            print(f"共获取到 {len(codes)} 只股票")
            if not codes:
                print("未获取到股票代码,已取消全量采集(目标表未清空)")
                return

            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")

            total_records = 0
            success_count = 0
            failed_count = 0

            for symbol in codes:
                try:
                    ts_code = self.convert_db_code_to_tushare(symbol)
                    # A missing code must not reach the API: ts_code=None
                    # would request data with no stock filter at all.
                    if not isinstance(ts_code, str) or not ts_code:
                        failed_count += 1
                        continue

                    df = self.fetch_data(ts_code=ts_code)
                    if df.empty:
                        failed_count += 1
                        continue

                    result_df = self.transform_data(df)
                    if result_df.empty:
                        failed_count += 1
                        continue

                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                except Exception as exc:
                    # One bad stock must not abort the whole run.
                    print(f"\n采集 {symbol} 失败: {exc}")
                    failed_count += 1

            print("=" * 60)
            print("全量覆盖采集完成")
            print(f"总股票数: {len(codes)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
        finally:
            self.engine.dispose()

    def run_range_collection(self, start_date: str, end_date: str) -> None:
        """Collect whole-market rows for every calendar day in an inclusive range.

        Each day is handled by ``run_daily_collection``; a failure on one day
        is printed and the loop moves on.

        Args:
            start_date: First day, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Last day, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If a date cannot be parsed or start is after end.
        """
        start = self._parse_date(start_date)
        end = self._parse_date(end_date)
        if start > end:
            raise ValueError("start_date 不能大于 end_date")

        current = start
        while current <= end:
            target = current.strftime("%Y-%m-%d")
            try:
                self.run_daily_collection(target)
            except Exception as exc:
                print(f"{target} 数据采集失败: {exc}")
            current += pd.Timedelta(days=1)

    def run_stock_collection(
        self,
        symbol: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """Collect rows for a single stock, optionally limited to a date range.

        Existing rows for the stock inside the affected range are deleted
        before the new rows are inserted, so re-running does not duplicate
        data. A missing bound falls back to the earliest/latest ``trade_date``
        actually returned by the API.

        Args:
            symbol: Database form (``SH600000``) or Tushare ``ts_code``.
            start_date: Optional first day, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Optional last day, ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: If a supplied date cannot be parsed.
        """
        ts_code = symbol if "." in symbol else self.convert_db_code_to_tushare(symbol)

        start_dt = self._parse_date(start_date) if start_date else None
        end_dt = self._parse_date(end_date) if end_date else None

        # Send compact YYYYMMDD dates to the API, the same form
        # run_daily_collection uses for this endpoint.
        df = self.fetch_data(
            ts_code=ts_code,
            start_date=start_dt.strftime("%Y%m%d") if start_dt else None,
            end_date=end_dt.strftime("%Y%m%d") if end_dt else None,
        )
        if df.empty:
            print("接口无返回数据")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        # Always clear the range being rewritten, even when no bounds were given.
        start = (start_dt or result_df["trade_date"].min()).strftime("%Y-%m-%d")
        end = (end_dt or result_df["trade_date"].max()).strftime("%Y-%m-%d")

        delete_sql = text(
            f"""
            DELETE FROM {self.table_name}
            WHERE symbol = :symbol AND trade_date BETWEEN :start AND :end
            """
        )
        with self.engine.begin() as conn:
            affected = conn.execute(
                delete_sql,
                {
                    "symbol": self.convert_tushare_code_to_db(ts_code),
                    "start": start,
                    "end": end,
                },
            )
        print(f"已删除 {symbol} {start}~{end} 旧数据 {affected.rowcount} 条")

        self.save_dataframe(result_df)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def collect_moneyflow_ths(
    db_url: str,
    tushare_token: str,
    mode: str = "daily",
    date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    symbol: Optional[str] = None,
):
    """Entry point: build a collector and run one collection mode.

    Modes:
        - ``daily``: whole-market rows for ``date`` (today when omitted).
        - ``range``: whole-market rows for each day from ``start_date`` to
          ``end_date``; both are required.
        - ``stock``: rows for one ``symbol`` (required), optionally bounded
          by ``start_date`` / ``end_date``.
        - ``full``: truncate the target table and reload every stock.

    Raises:
        ValueError: If ``mode`` is unknown or a required argument for the
            chosen mode is missing.
    """
    # Validate before building the collector so a bad call does not open a
    # database engine or configure the Tushare client.
    if mode not in {"daily", "range", "stock", "full"}:
        raise ValueError(f"未知的采集模式: {mode}")
    if mode == "range" and (not start_date or not end_date):
        raise ValueError("range 模式需要提供 start_date 和 end_date,格式 YYYY-MM-DD")
    if mode == "stock" and not symbol:
        raise ValueError("stock 模式需要提供 symbol 参数")

    collector = MoneyflowTHSCollector(db_url, tushare_token)

    if mode == "daily":
        collector.run_daily_collection(date)
    elif mode == "range":
        collector.run_range_collection(start_date, end_date)
    elif mode == "stock":
        collector.run_stock_collection(symbol, start_date, end_date)
    else:
        collector.run_full_collection()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # NOTE(review): database credentials are embedded in source control.
    # The literal below is kept only as a backward-compatible default;
    # set MONEYFLOW_THS_DB_URL instead, then rotate the password and
    # remove the literal.
    DB_URL = os.environ.get(
        "MONEYFLOW_THS_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8pattern@192.168.16.150:3306/factordb_mysql",
    )
    TOKEN = TUSHARE_TOKEN

    # Usage examples:
    #
    # 1. Whole market for today:
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily")
    #
    #    Full rebuild (truncates the target table first):
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="full")
    #
    # 2. Whole market for a given date (the call that runs below).
    #
    # 3. Whole market for a date range, day by day:
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="range", start_date="2025-11-01", end_date="2025-11-07")
    #
    # 4. One stock over a date range:
    #    collect_moneyflow_ths(DB_URL, TOKEN, mode="stock", symbol="600519.SH", start_date="2025-11-01", end_date="2025-11-07")
    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-17")
|
|||
|
|
|