# coding:utf-8
"""
Per-stock money-flow (THS / Tonghuashun) collection tool.

Purpose: fetch ``moneyflow_ths`` data from Tushare and persist it to a database.
API docs: https://tushare.pro/document/2?doc_id=348
Note: the endpoint is refreshed daily after market close and requires at least
5000 Tushare points.
"""

import os
import sys
from datetime import datetime, timedelta
from typing import Optional

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

# Add the project root to sys.path so the config module can be imported.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN


class MoneyflowTHSCollector:
    """Collector for per-stock money-flow (THS) data."""

    # Columns copied verbatim from the Tushare frame into the output frame.
    _PASSTHROUGH_COLUMNS = (
        "name",
        "pct_change",
        "latest",
        "net_amount",
        "net_d5_amount",
        "buy_lg_amount",
        "buy_lg_amount_rate",
        "buy_md_amount",
        "buy_md_amount_rate",
        "buy_sm_amount",
        "buy_sm_amount_rate",
    )

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "gp_moneyflow_ths"):
        """
        Create the database engine and the Tushare client, then print a banner.

        Args:
            db_url: Database connection URL passed to ``create_engine``.
            tushare_token: Tushare token.
            table_name: Target table name, ``gp_moneyflow_ths`` by default.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("个股资金流向(THS)采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """
        Convert a Tushare code (``600000.SH``) to the database form (``SH600000``).

        Empty values and values without a dot are returned unchanged.
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        # Split on the LAST dot so an unexpected extra dot cannot raise on unpacking.
        base, _, market = ts_code.rpartition(".")
        return f"{market}{base}"

    @staticmethod
    def convert_db_code_to_tushare(symbol: str) -> str:
        """
        Convert a database code (``SH600000``) to the Tushare form (``600000.SH``).

        Values that are empty or do not start with SH/SZ/BJ are returned unchanged.
        """
        if not symbol:
            return symbol
        prefix, code = symbol[:2], symbol[2:]
        if prefix in {"SH", "SZ", "BJ"}:
            return f"{code}.{prefix}"
        return symbol

    @staticmethod
    def _parse_date(value: str) -> datetime:
        """
        Parse a date given as ``YYYY-MM-DD`` or ``YYYYMMDD``.

        Raises:
            ValueError: if ``value`` matches neither format.
        """
        for fmt in ("%Y-%m-%d", "%Y%m%d"):
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        raise ValueError(f"无法解析日期: {value}(需为 YYYY-MM-DD 或 YYYYMMDD)")

    def fetch_data(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Call the Tushare ``moneyflow_ths`` endpoint.

        All arguments are forwarded unchanged to the endpoint (see the API docs
        linked in the module docstring).

        Returns:
            The frame returned by Tushare, or an empty DataFrame when the call
            raises or returns ``None`` (the error is printed, not re-raised).
        """
        try:
            df = self.pro.moneyflow_ths(
                ts_code=ts_code,
                trade_date=trade_date,
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as exc:
            # Best-effort by design: callers treat an empty frame as "no data".
            print(f"Tushare moneyflow_ths 调用失败: {exc}")
            return pd.DataFrame()
        return df if df is not None else pd.DataFrame()

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Convert a Tushare result frame into the layout written to the database.

        Adds ``symbol`` (database-style code), parses ``trade_date`` from
        ``YYYYMMDD``, copies the pass-through columns (a column missing from
        ``df`` becomes all-null) and stamps ``created_at`` / ``updated_at``.

        Returns:
            The transformed frame, or an empty DataFrame when ``df`` is empty.
        """
        if df.empty:
            return pd.DataFrame()

        result = pd.DataFrame()
        result["symbol"] = df["ts_code"].apply(self.convert_tushare_code_to_db)
        result["ts_code"] = df["ts_code"]
        result["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        for column in self._PASSTHROUGH_COLUMNS:
            result[column] = df.get(column)

        # One timestamp for both columns so a fresh row has created_at == updated_at.
        now = datetime.now()
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append ``df`` to the target table; print and return when it is empty."""
        if df.empty:
            print("无数据需要保存")
            return

        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
        print(f"成功写入 {len(df)} 条记录")

    def delete_by_date(self, date_str: str) -> None:
        """Delete all rows of the target table whose ``trade_date`` equals ``date_str``."""
        with self.engine.begin() as conn:
            delete_sql = text(f"DELETE FROM {self.table_name} WHERE trade_date = :trade_date")
            affected = conn.execute(delete_sql, {"trade_date": date_str})
            print(f"已删除 {date_str} 旧数据 {affected.rowcount} 条")

    def _replace_rows(self, df: pd.DataFrame, delete_sql, params: dict, label: str) -> None:
        """
        Run ``delete_sql`` and append ``df`` inside a single transaction.

        If the insert fails the delete is rolled back with it, so existing rows
        are never lost without their replacement being written.

        Args:
            df: Non-empty frame to append to the target table.
            delete_sql: SQLAlchemy ``text`` DELETE statement removing the rows
                that ``df`` replaces.
            params: Bind parameters for ``delete_sql``.
            label: Human-readable description of the deleted range, used in the
                progress message.
        """
        with self.engine.begin() as conn:
            deleted = conn.execute(delete_sql, params).rowcount
            # Passing the open connection makes pandas write inside this transaction.
            df.to_sql(self.table_name, conn, if_exists="append", index=False)
        print(f"已删除 {label} 旧数据 {deleted} 条")
        print(f"成功写入 {len(df)} 条记录")

    def run_daily_collection(self, date: Optional[str] = None) -> None:
        """
        Collect whole-market money-flow data for one trading day.

        Args:
            date: ``YYYY-MM-DD`` (``YYYYMMDD`` is also accepted); defaults to today.

        Existing rows for the day are replaced only after new data has been
        fetched and transformed successfully, so an API failure or a
        non-trading day leaves previously stored data untouched.

        Raises:
            ValueError: if ``date`` cannot be parsed.
        """
        day = self._parse_date(date) if date else datetime.now()
        target_date = day.strftime("%Y-%m-%d")
        trade_date = day.strftime("%Y%m%d")

        print(f"开始采集 {target_date} ({trade_date}) 个股资金流向数据")

        df = self.fetch_data(trade_date=trade_date)
        if df.empty:
            print("接口未返回数据(可能为非交易日或权限不足)")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        delete_sql = text(f"DELETE FROM {self.table_name} WHERE trade_date = :trade_date")
        self._replace_rows(result_df, delete_sql, {"trade_date": target_date}, target_date)
        print("每日采集完成")

    def run_full_collection(self) -> None:
        """
        Run a full overwrite collection.

        Reads every stock code from ``gp_code_all_copy``, truncates the target
        table, then fetches and stores the data returned for each stock.
        A stock whose fetch or transform yields no rows is counted as failed.
        The engine is disposed when the run ends, whether it succeeded or not.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(moneyflow_ths)")
        print("=" * 60)

        try:
            # Read the code list BEFORE truncating, so a failing or empty code
            # query cannot leave the target table wiped with nothing to reload.
            codes_df = pd.read_sql("SELECT gp_code FROM gp_code_all_copy", self.engine)
            codes = codes_df["gp_code"].tolist()
            print(f"共获取到 {len(codes)} 只股票")
            if not codes:
                print("未获取到任何股票代码,终止全量采集")
                return

            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
            print(f"{self.table_name} 已清空")

            total_records = 0
            success_count = 0
            failed_count = 0

            for symbol in codes:
                ts_code = self.convert_db_code_to_tushare(symbol)
                try:
                    df = self.fetch_data(ts_code=ts_code)
                    if df.empty:
                        failed_count += 1
                        continue

                    result_df = self.transform_data(df)
                    if result_df.empty:
                        failed_count += 1
                        continue

                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                except Exception as exc:
                    # Keep going: one bad stock must not abort the whole run.
                    print(f"\n采集 {symbol} 失败: {exc}")
                    failed_count += 1
                    continue

            print("=" * 60)
            print("全量覆盖采集完成")
            print(f"总股票数: {len(codes)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
        finally:
            self.engine.dispose()

    def run_range_collection(self, start_date: str, end_date: str) -> None:
        """
        Collect whole-market data for every calendar day in ``[start_date, end_date]``.

        Each day is handled by ``run_daily_collection``; a failure on one day is
        printed and the loop continues with the next day.

        Args:
            start_date: First day, ``YYYY-MM-DD`` (``YYYYMMDD`` also accepted).
            end_date: Last day (inclusive), same formats.

        Raises:
            ValueError: if a date cannot be parsed or ``start_date > end_date``.
        """
        start = self._parse_date(start_date)
        end = self._parse_date(end_date)
        if start > end:
            raise ValueError("start_date 不能大于 end_date")

        current = start
        one_day = timedelta(days=1)
        while current <= end:
            target = current.strftime("%Y-%m-%d")
            try:
                self.run_daily_collection(target)
            except Exception as exc:
                print(f"{target} 数据采集失败: {exc}")
            current += one_day

    def run_stock_collection(
        self,
        symbol: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """
        Collect money-flow data for a single stock over an optional date range.

        Args:
            symbol: Database-style code (``SH600000``, ``SZ000001`` ...) or a
                Tushare ``ts_code`` (anything containing a dot is used as is).
            start_date: Optional range start, ``YYYY-MM-DD`` or ``YYYYMMDD``.
            end_date: Optional range end, same formats.

        Dates are sent to Tushare as ``YYYYMMDD``. Rows already stored for the
        stock inside the covered range are replaced in the same transaction as
        the insert; a missing bound falls back to the earliest / latest
        ``trade_date`` in the fetched data, so repeated calls never duplicate rows.

        Raises:
            ValueError: if a supplied date cannot be parsed.
        """
        ts_code = symbol if "." in symbol else self.convert_db_code_to_tushare(symbol)
        start_dt = self._parse_date(start_date) if start_date else None
        end_dt = self._parse_date(end_date) if end_date else None

        df = self.fetch_data(
            ts_code=ts_code,
            start_date=start_dt.strftime("%Y%m%d") if start_dt else None,
            end_date=end_dt.strftime("%Y%m%d") if end_dt else None,
        )
        if df.empty:
            print("接口无返回数据")
            return

        result_df = self.transform_data(df)
        if result_df.empty:
            print("数据转换失败")
            return

        # Replace whatever is already stored for the covered range to avoid duplicates.
        range_start = start_dt if start_dt is not None else result_df["trade_date"].min()
        range_end = end_dt if end_dt is not None else result_df["trade_date"].max()
        start = range_start.strftime("%Y-%m-%d")
        end = range_end.strftime("%Y-%m-%d")

        delete_sql = text(
            f"""
            DELETE FROM {self.table_name}
            WHERE symbol = :symbol
              AND trade_date BETWEEN :start AND :end
            """
        )
        params = {
            "symbol": self.convert_tushare_code_to_db(ts_code),
            "start": start,
            "end": end,
        }
        self._replace_rows(result_df, delete_sql, params, f"{symbol} {start}~{end}")


def collect_moneyflow_ths(
    db_url: str,
    tushare_token: str,
    mode: str = "daily",
    date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    symbol: Optional[str] = None,
):
    """
    Collection entry point.

    Args:
        db_url: Database connection URL.
        tushare_token: Tushare token.
        mode: One of
            - ``daily``: whole-market data for ``date`` (today by default)
            - ``range``: whole-market data for each day in ``start_date``..``end_date``
            - ``stock``: data for one stock (``symbol``), optionally date-bounded
            - ``full``: truncate the table and reload every stock
        date: Day for ``daily`` mode, ``YYYY-MM-DD``.
        start_date: Range start for ``range`` / ``stock`` mode.
        end_date: Range end for ``range`` / ``stock`` mode.
        symbol: Stock code for ``stock`` mode.

    Raises:
        ValueError: on an unknown ``mode`` or when a mode's required arguments
            are missing.
    """
    collector = MoneyflowTHSCollector(db_url, tushare_token)

    if mode == "daily":
        collector.run_daily_collection(date)
    elif mode == "range":
        if not start_date or not end_date:
            raise ValueError("range 模式需要提供 start_date 和 end_date,格式 YYYY-MM-DD")
        collector.run_range_collection(start_date, end_date)
    elif mode == "stock":
        if not symbol:
            raise ValueError("stock 模式需要提供 symbol 参数")
        collector.run_stock_collection(symbol, start_date, end_date)
    elif mode == "full":
        collector.run_full_collection()
    else:
        raise ValueError(f"未知的采集模式: {mode}")


if __name__ == "__main__":
    # NOTE(review): credentials should not live in source control. Set the
    # MONEYFLOW_DB_URL environment variable to override the fallback below.
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = os.environ.get(
        "MONEYFLOW_DB_URL",
        "mysql+pymysql://fac_pattern:Chlry$%.8pattern@192.168.16.150:3306/factordb_mysql",
    )
    TOKEN = TUSHARE_TOKEN

    # Examples:
    # 1. Collect today's whole-market data
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="daily")
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="full")

    # 2. Collect whole-market data for a given day
    collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-17")

    # 3. Collect whole-market data for a date range (day by day)
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="range", start_date="2025-11-01", end_date="2025-11-07")

    # 4. Collect a single stock over a date range
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="stock", symbol="600519.SH", start_date="2025-11-01", end_date="2025-11-07")
    # collect_moneyflow_ths(DB_URL, TOKEN, mode="daily", date="2025-11-07")