From b3332eb9206cadc87a0cccb3f588e392473f8eb3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BB=A1=E8=84=B8=E5=B0=8F=E6=98=9F=E6=98=9F?=
Date: Fri, 28 Nov 2025 15:34:10 +0800
Subject: [PATCH] commit;

---
 src/app.py                                           |  76 +++++-
 .../average_distance_factor.py                       |  44 +++-
 .../company_lifecycle_factor.py                      |  17 +-
 .../chip_distribution_collector.py                   |   6 +-
 .../ths_concept_member_collector.py                  | 242 ++++++++++++++++++
 src/tushare_scripts/ths_index_collector.py           | 225 ++++++++++++++++
 src/tushare_scripts/trade_cal_collector.py           | 236 +++++++++++++++++
 7 files changed, 828 insertions(+), 18 deletions(-)
 create mode 100644 src/tushare_scripts/ths_concept_member_collector.py
 create mode 100644 src/tushare_scripts/ths_index_collector.py
 create mode 100644 src/tushare_scripts/trade_cal_collector.py

diff --git a/src/app.py b/src/app.py
index bbfbf50..722aa2f 100644
--- a/src/app.py
+++ b/src/app.py
@@ -2,6 +2,7 @@
 import sys
 import os
 from datetime import datetime, timedelta
+from typing import List
 import pandas as pd
 import uuid
 import json
@@ -74,6 +75,9 @@ from src.tushare_scripts.chip_distribution_collector import collect_chip_distrib
 from src.tushare_scripts.stock_factor_collector import collect_stock_factor
 from src.tushare_scripts.stock_factor_pro_collector import collect_stock_factor_pro
 
+# Import the tech-theme fundamental factor stock-selection strategy
+from src.quantitative_analysis.tech_fundamental_factor_strategy_v3 import TechFundamentalFactorStrategy
+
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
@@ -460,7 +464,7 @@ def run_chip_distribution_collection():
     #     mode='full'
     # )
     collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
-                              start_date='2022-01-03', end_date='2022-12-31')
+                              start_date='2021-01-03', end_date='2021-09-30')
     # collect_chip_distribution(
     #     db_url=db_url,
     #     tushare_token=TUSHARE_TOKEN,
@@ -3780,6 +3784,76 @@ def analyze_stock_overlap():
         }), 500
 
 
+@app.route('/scheduler/techFundamentalStrategy/batch', methods=['GET', 'POST'])
+def run_tech_fundamental_strategy_batch():
+    """Run the tech-theme fundamental factor strategy for each date in a range and persist the results."""
+    try:
+        # Read parameters from the query string, falling back to the JSON body
+        start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date')
+        end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date')
+
+        if not start_date or not end_date:
+            return jsonify({"status": "error", "message": "Missing parameters: start_date and end_date"}), 400
+
+        # Validate the dates
+        start_dt = datetime.strptime(start_date, '%Y-%m-%d')
+        end_dt = datetime.strptime(end_date, '%Y-%m-%d')
+
+        # Database connection (used to check whether a trade date has data);
+        # `text` is imported here as well since it is used by the check query below
+        from sqlalchemy import create_engine, text
+        db_url = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
+        engine = create_engine(db_url)
+
+        # Iterate over the date range
+        success_count = 0
+        failed_count = 0
+        skipped_count = 0
+        current_date = start_dt
+
+        while current_date <= end_dt:
+            trade_date = current_date.strftime('%Y-%m-%d')
+            trade_datetime = f"{trade_date} 00:00:00"
+
+            # Check whether daily data exists for this date
+            check_query = text("""
+                SELECT COUNT(0) as cnt
+                FROM gp_day_data
+                WHERE `timestamp` = :trade_datetime
+            """)
+
+            with engine.connect() as conn:
+                result = conn.execute(check_query, {"trade_datetime": trade_datetime}).fetchone()
+                count = result[0] if result else 0
+
+            # Only run the strategy when the date has data
+            if count > 0:
+                try:
+                    strategy = TechFundamentalFactorStrategy(target_date=trade_date)
+                    strategy.run_strategy()
+                    strategy.close_connections()
+                    success_count += 1
+                except Exception as e:
+                    failed_count += 1
+                    logger.error(f"Processing failed for date {trade_date}: {str(e)}")
+            else:
+                skipped_count += 1
+
+            current_date += timedelta(days=1)
+
+        engine.dispose()
+
+        return jsonify({
+            "status": "success",
+            "success_count": success_count,
+            "failed_count": failed_count,
+            "skipped_count": skipped_count
+        })
+
+    except Exception as e:
+        logger.error(f"Batch run failed: {str(e)}")
+        return jsonify({"status": "error", "message": str(e)}), 500
+
+
 if __name__ == '__main__':
     # Start the web server
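For reference, a minimal client-side sketch of how the new batch endpoint could be exercised. The host, port, and payload values are assumptions for illustration; only the route and the start_date/end_date parameters come from the patch.

# Hypothetical smoke test for /scheduler/techFundamentalStrategy/batch.
import requests

resp = requests.post(
    "http://localhost:5000/scheduler/techFundamentalStrategy/batch",
    json={"start_date": "2021-01-04", "end_date": "2021-01-08"},
    timeout=3600,  # the endpoint runs the strategy once per trade date, so allow a long timeout
)
print(resp.status_code, resp.json())
# expected shape: {"status": "success", "success_count": ..., "failed_count": ..., "skipped_count": ...}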
diff --git a/src/quantitative_analysis/average_distance_factor.py b/src/quantitative_analysis/average_distance_factor.py
index a6be34a..03f8680 100644
--- a/src/quantitative_analysis/average_distance_factor.py
+++ b/src/quantitative_analysis/average_distance_factor.py
@@ -55,16 +55,32 @@ class AverageDistanceFactor:
             print(f"Failed to fetch stock list: {e}")
             return []
 
-    def get_stock_data(self, symbols, days=20):
-        """Fetch historical stock data"""
+    def get_stock_data(self, symbols, days=20, end_date=None):
+        """
+        Fetch historical stock data.
+
+        Args:
+            symbols: list of stock codes
+            days: number of days of data to fetch
+            end_date: end date (datetime object or string); today is used when None
+        """
         if not symbols:
             return pd.DataFrame()
 
-        # Compute the start date
-        end_date = datetime.now()
-        start_date = end_date - timedelta(days=days * 2)  # fetch extra data to cover holidays
+        # Resolve the end date (daily data uses the target date itself)
+        if end_date is None:
+            end_date = datetime.now()
+        elif isinstance(end_date, str):
+            end_date = datetime.strptime(end_date, '%Y-%m-%d')
 
-        # Build the SQL query
+        # Daily bar data should include the target date itself.
+        # E.g. for a target date of 2025-11-01, use data on and before 2025-11-01.
+        query_end_date = end_date
+
+        # Compute the start date (fetch extra days to cover holidays)
+        start_date = query_end_date - timedelta(days=days * 2)
+
+        # Build the SQL query (<= keeps the target date itself)
         symbols_str = "', '".join(symbols)
         query = f"""
             SELECT symbol, timestamp, volume, open, high, low, close,
@@ -72,28 +88,36 @@ class AverageDistanceFactor:
             FROM gp_day_data
             WHERE symbol IN ('{symbols_str}')
             AND timestamp >= '{start_date.strftime('%Y-%m-%d')}'
+            AND timestamp <= '{end_date.strftime('%Y-%m-%d')}'
             ORDER BY symbol, timestamp DESC
         """
 
         try:
             df = pd.read_sql(query, self.engine)
-            print(f"Fetched {len(df)} rows of historical data")
             return df
         except Exception as e:
-            print(f"Failed to fetch historical data: {e}")
             return pd.DataFrame()
 
     def calculate_technical_indicators(self, df, days=20):
-        """Compute technical indicators"""
+        """
+        Compute technical indicators.
+
+        Args:
+            df: DataFrame of historical stock data (already filtered by target date)
+            days: number of days used for the indicators (the N days up to the target date)
+        """
         result_data = []
 
         for symbol in df['symbol'].unique():
             stock_data = df[df['symbol'] == symbol].copy()
+            # Sort ascending by time to guarantee chronological order
            stock_data = stock_data.sort_values('timestamp')
 
-            # Keep only the most recent N days of data
+            # Keep only the N trading days up to the target date (inclusive)
+            # tail(N) takes the last N rows, i.e. the N trading days closest to (and including) the target date
             stock_data = stock_data.tail(days)
 
+            # Make sure there is enough data (at least `days` trading days)
             if len(stock_data) < days:
                 continue  # not enough data, skip
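A small self-contained sketch of the windowing behavior the new end_date parameter enables: after sorting ascending, tail(days) keeps the N trading days up to and including the target date. The data below is synthetic, for illustration only.

import pandas as pd

df = pd.DataFrame({
    "symbol": ["SH600000"] * 5,
    "timestamp": pd.to_datetime(
        ["2021-01-04", "2021-01-05", "2021-01-06", "2021-01-07", "2021-01-08"]),
    "close": [10.0, 10.2, 10.1, 10.4, 10.3],
})

days = 3
# Sort ascending, then take the trailing window ending at the target date
window = df.sort_values("timestamp").tail(days)
assert list(window["timestamp"].dt.strftime("%Y-%m-%d")) == [
    "2021-01-06", "2021-01-07", "2021-01-08"]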
diff --git a/src/quantitative_analysis/company_lifecycle_factor.py b/src/quantitative_analysis/company_lifecycle_factor.py
index 5cc228b..b5f7082 100644
--- a/src/quantitative_analysis/company_lifecycle_factor.py
+++ b/src/quantitative_analysis/company_lifecycle_factor.py
@@ -8,7 +8,11 @@ import sys
 import os
 
 # Add the project root to the path
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# __file__ is this file's path, e.g. /app/src/quantitative_analysis/company_lifecycle_factor.py
+# We need the project root, i.e. /app
+current_file_dir = os.path.dirname(os.path.abspath(__file__))  # /app/src/quantitative_analysis
+src_dir = os.path.dirname(current_file_dir)  # /app/src
+project_root = os.path.dirname(src_dir)  # /app (project root)
 sys.path.append(project_root)
 
 # Import configuration
@@ -16,7 +20,7 @@ try:
     from valuation_analysis.config import MONGO_CONFIG2
 except ImportError:
     import importlib.util
-    config_path = os.path.join(project_root, 'valuation_analysis', 'config.py')
+    config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
     spec = importlib.util.spec_from_file_location("config", config_path)
     config_module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(config_module)
@@ -27,7 +31,12 @@ try:
     from tools.stock_code_formatter import StockCodeFormatter
 except ImportError:
     import importlib.util
-    formatter_path = os.path.join(os.path.dirname(project_root), 'tools', 'stock_code_formatter.py')
+    # project_root is already the project root, so join the tools directory directly
+    formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
+
+    if not os.path.exists(formatter_path):
+        raise ImportError(f"Cannot find stock_code_formatter.py at: {formatter_path}")
+
     spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
     formatter_module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(formatter_module)
@@ -120,7 +129,7 @@ class CompanyLifecycleFactor:
         annual_data = self.collection.find_one(query)
 
         if annual_data:
-            logger.info(f"Found annual report data: {stock_code} (normalized: {normalized_code}) - {report_date}")
+            # logger.info(f"Found annual report data: {stock_code} (normalized: {normalized_code}) - {report_date}")
             return annual_data
         else:
             logger.warning(f"No annual report data found: {stock_code} (normalized: {normalized_code}) - {report_date}")
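The corrected path arithmetic can be checked in isolation; a sketch using the example layout from the comments above (POSIX separators assumed):

import os.path

file_path = "/app/src/quantitative_analysis/company_lifecycle_factor.py"
current_file_dir = os.path.dirname(file_path)  # /app/src/quantitative_analysis
src_dir = os.path.dirname(current_file_dir)    # /app/src
project_root = os.path.dirname(src_dir)        # /app

assert project_root == "/app"
# tools/ now resolves inside the project root rather than one level above it
assert os.path.join(project_root, "tools", "stock_code_formatter.py") == \
    "/app/tools/stock_code_formatter.py"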
diff --git a/src/tushare_scripts/chip_distribution_collector.py b/src/tushare_scripts/chip_distribution_collector.py
index 670cddf..a4ea8ea 100644
--- a/src/tushare_scripts/chip_distribution_collector.py
+++ b/src/tushare_scripts/chip_distribution_collector.py
@@ -464,11 +464,11 @@ if __name__ == "__main__":
     # collect_chip_distribution(db_url, tushare_token, mode='full')
 
     # 3. Collect data for a specific date
-    collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-24')
+    # collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-25')
 
     # 4. Collect data for a date range
-    # collect_chip_distribution(db_url, tushare_token, mode='full',
-    #                           start_date='2021-11-01', end_date='2021-11-30')
+    collect_chip_distribution(db_url, tushare_token, mode='full',
+                              start_date='2021-01-03', end_date='2021-09-30')
 
     # 5. Adjust the batch insert size (default: 100 stocks per batch)
     # collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200)

diff --git a/src/tushare_scripts/ths_concept_member_collector.py b/src/tushare_scripts/ths_concept_member_collector.py
new file mode 100644
index 0000000..4c1f54f
--- /dev/null
+++ b/src/tushare_scripts/ths_concept_member_collector.py
@@ -0,0 +1,242 @@
+# coding:utf-8
+"""
+Tonghuashun (THS) concept board constituent collector
+Purpose: fetch THS concept board constituent data from Tushare and persist it
+API docs: https://tushare.pro/document/2?doc_id=261
+Note: every run does a full overwrite; no daily incremental updates are needed
+"""
+
+import os
+import sys
+from datetime import datetime
+from typing import Optional
+
+import pandas as pd
+import tushare as ts
+from sqlalchemy import create_engine, text
+
+# Add the project root to the path so the config can be imported
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(PROJECT_ROOT)
+
+from src.scripts.config import TUSHARE_TOKEN
+
+
+class THSConceptMemberCollector:
+    """Collector for THS concept board constituents"""
+
+    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_concept_member"):
+        """
+        Args:
+            db_url: database connection URL
+            tushare_token: Tushare token
+            table_name: target table name, default ths_concept_member
+        """
+        self.engine = create_engine(
+            db_url,
+            pool_size=5,
+            max_overflow=10,
+            pool_recycle=3600,
+        )
+        self.table_name = table_name
+
+        ts.set_token(tushare_token)
+        self.pro = ts.pro_api()
+
+        print("=" * 60)
+        print("THS concept board constituent collector")
+        print(f"Target table: {self.table_name}")
+        print("=" * 60)
+
+    @staticmethod
+    def convert_tushare_code_to_db(ts_code: str) -> str:
+        """
+        Convert a Tushare code (600000.SH) to the database code format (SH600000)
+        """
+        if not ts_code or "." not in ts_code:
+            return ts_code
+        base, market = ts_code.split(".")
+        return f"{market}{base}"
+
+    def fetch_concept_index_list(self) -> pd.DataFrame:
+        """
+        Fetch the list of all THS concept boards
+        API: ths_index
+        Docs: https://tushare.pro/document/2?doc_id=259
+
+        Returns:
+            DataFrame with the concept board list
+        """
+        try:
+            print("Fetching the THS concept board list...")
+            df = self.pro.ths_index(
+                exchange='A',
+                type='N',  # N = concept board
+            )
+            if df.empty:
+                print("No concept board data returned")
+                return pd.DataFrame()
+            print(f"Fetched {len(df)} concept boards")
+            return df
+        except Exception as exc:
+            print(f"Failed to fetch the concept board list: {exc}")
+            return pd.DataFrame()
+
+    def fetch_concept_members(self, ts_code: str) -> pd.DataFrame:
+        """
+        Fetch the constituents of one concept board
+        API: ths_member
+        Docs: https://tushare.pro/document/2?doc_id=261
+
+        Args:
+            ts_code: concept board code, e.g. '885800.TI'
+
+        Returns:
+            DataFrame with the constituent list
+        """
+        try:
+            df = self.pro.ths_member(ts_code=ts_code)
+            return df
+        except Exception as exc:
+            print(f"Failed to fetch constituents of concept board {ts_code}: {exc}")
+            return pd.DataFrame()
+
+    def transform_data(self, member_df: pd.DataFrame, concept_name: str = "") -> pd.DataFrame:
+        """
+        Convert the data returned by Tushare into the database row format
+
+        Args:
+            member_df: constituent data (contains ts_code, con_code, con_name, ...)
+            concept_name: concept board name (taken from the board list)
+
+        Returns:
+            the converted DataFrame
+        """
+        if member_df.empty:
+            return pd.DataFrame()
+
+        # Check the required columns
+        if "ts_code" not in member_df.columns:
+            print(f"Warning: concept board code column (ts_code) missing; available columns: {member_df.columns.tolist()}")
+            return pd.DataFrame()
+        if "con_code" not in member_df.columns:
+            print(f"Warning: stock code column (con_code) missing; available columns: {member_df.columns.tolist()}")
+            return pd.DataFrame()
+
+        result = pd.DataFrame()
+        # ts_code is the concept board code
+        result["concept_code"] = member_df["ts_code"]
+        result["concept_name"] = concept_name
+        # con_code is the stock code
+        result["stock_ts_code"] = member_df["con_code"]
+        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
+        # con_name is the stock name
+        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
+        # is_new flags the latest membership (returned by the API; weight, in_date and
+        # out_date currently carry no data, so they are not stored)
+        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
+        result["created_at"] = datetime.now()
+        result["updated_at"] = datetime.now()
+        return result
+
+    def save_dataframe(self, df: pd.DataFrame) -> None:
+        """
+        Write the data to the database
+        """
+        if df.empty:
+            return
+        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
+
+    def run_full_collection(self) -> None:
+        """
+        Run a full overwrite collection:
+        - truncate the target table
+        - fetch the list of all concept boards
+        - fetch the constituents of each board
+        - write everything to the database
+        """
+        print("=" * 60)
+        print("Starting full overwrite collection (THS concept board constituents)")
+        print("=" * 60)
+
+        try:
+            # Truncate the table
+            with self.engine.begin() as conn:
+                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
+                print(f"{self.table_name} truncated")
+
+            # Fetch the list of all concept boards
+            concept_list_df = self.fetch_concept_index_list()
+            if concept_list_df.empty:
+                print("No concept board list fetched, aborting collection")
+                return
+
+            total_records = 0
+            success_count = 0
+            failed_count = 0
+
+            # Fetch the constituents of each concept board
+            for idx, row in concept_list_df.iterrows():
+                concept_code = row["ts_code"]
+                concept_name = row["name"]
+
+                print(f"\n[{idx + 1}/{len(concept_list_df)}] Collecting: {concept_name} ({concept_code})")
+
+                try:
+                    # Fetch this board's constituents
+                    member_df = self.fetch_concept_members(concept_code)
+                    if member_df.empty:
+                        print(f"  Concept board {concept_name} has no constituent data")
+                        failed_count += 1
+                        continue
+
+                    # Convert the data (passing the board name along)
+                    result_df = self.transform_data(member_df, concept_name)
+                    if result_df.empty:
+                        print(f"  Data conversion failed for concept board {concept_name}")
+                        failed_count += 1
+                        continue
+
+                    # Save the data
+                    self.save_dataframe(result_df)
+                    total_records += len(result_df)
+                    success_count += 1
+                    print(f"  Collected {len(result_df)} constituents")
+
+                except Exception as exc:
+                    print(f"  Failed to collect {concept_name} ({concept_code}): {exc}")
+                    failed_count += 1
+                    continue
+
+            print("\n" + "=" * 60)
+            print("Full overwrite collection finished")
+            print(f"Total concept boards: {len(concept_list_df)}")
+            print(f"Collected successfully: {success_count}")
+            print(f"Failed: {failed_count}")
+            print(f"Total constituent records: {total_records}")
+            print("=" * 60)
+        except Exception as exc:
+            print(f"Full collection failed: {exc}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            self.engine.dispose()
+
+
+def collect_ths_concept_member(
+    db_url: str,
+    tushare_token: str,
+):
+    """
+    Collection entry point - full overwrite
+    """
+    collector = THSConceptMemberCollector(db_url, tushare_token)
+    collector.run_full_collection()
+
+
+if __name__ == "__main__":
+    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
+    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
+    TOKEN = TUSHARE_TOKEN
+
+    # Run the full overwrite collection
+    collect_ths_concept_member(DB_URL, TOKEN)
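The patch assumes the target table already exists. A plausible MySQL DDL sketch, inferred from the columns transform_data() writes; the types, lengths, and indexes are assumptions, not part of the commit:

from sqlalchemy import create_engine, text

# Hypothetical schema for ths_concept_member; adjust to the real table definition.
DDL = """
CREATE TABLE IF NOT EXISTS ths_concept_member (
    id             BIGINT AUTO_INCREMENT PRIMARY KEY,
    concept_code   VARCHAR(16) NOT NULL,  -- e.g. 885800.TI
    concept_name   VARCHAR(64) NOT NULL,
    stock_ts_code  VARCHAR(16) NOT NULL,  -- e.g. 600000.SH
    stock_symbol   VARCHAR(16) NOT NULL,  -- e.g. SH600000
    stock_name     VARCHAR(64),
    is_new         VARCHAR(4),
    created_at     DATETIME,
    updated_at     DATETIME,
    KEY idx_concept_code (concept_code),
    KEY idx_stock_symbol (stock_symbol)
)
"""

def ensure_table(db_url: str) -> None:
    engine = create_engine(db_url)
    with engine.begin() as conn:
        conn.execute(text(DDL))
    engine.dispose()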
diff --git a/src/tushare_scripts/ths_index_collector.py b/src/tushare_scripts/ths_index_collector.py
new file mode 100644
index 0000000..64a1bd3
--- /dev/null
+++ b/src/tushare_scripts/ths_index_collector.py
@@ -0,0 +1,225 @@
+# coding:utf-8
+"""
+Tonghuashun (THS) concept and industry index collector
+Purpose: fetch THS concept and industry index data from Tushare and persist it
+API docs: https://tushare.pro/document/2?doc_id=259
+Note: every run does a full overwrite; no daily incremental updates are needed
+"""
+
+import os
+import sys
+from datetime import datetime
+from typing import Optional
+
+import pandas as pd
+import tushare as ts
+from sqlalchemy import create_engine, text
+
+# Add the project root to the path so the config can be imported
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(PROJECT_ROOT)
+
+from src.scripts.config import TUSHARE_TOKEN
+
+
+class THSIndexCollector:
+    """Collector for THS concept and industry indexes"""
+
+    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_index"):
+        """
+        Args:
+            db_url: database connection URL
+            tushare_token: Tushare token
+            table_name: target table name, default ths_index
+        """
+        self.engine = create_engine(
+            db_url,
+            pool_size=5,
+            max_overflow=10,
+            pool_recycle=3600,
+        )
+        self.table_name = table_name
+
+        ts.set_token(tushare_token)
+        self.pro = ts.pro_api()
+
+        print("=" * 60)
+        print("THS concept and industry index collector")
+        print(f"Target table: {self.table_name}")
+        print("=" * 60)
+
+    def fetch_index_list(
+        self,
+        exchange: Optional[str] = None,
+        index_type: Optional[str] = None,
+    ) -> pd.DataFrame:
+        """
+        Fetch the THS concept and industry index list
+        API: ths_index
+        Docs: https://tushare.pro/document/2?doc_id=259
+
+        Args:
+            exchange: market type, A = A-shares, HK = Hong Kong, US = United States; None = all
+            index_type: index type, N = concept, I = industry, R = region, S = THS special,
+                        ST = THS style, TH = THS theme, BB = THS broad-based; None = all
+
+        Returns:
+            DataFrame with the index list
+        """
+        try:
+            params = {}
+            if exchange:
+                params['exchange'] = exchange
+            if index_type:
+                params['type'] = index_type
+
+            print("Fetching the THS index list...")
+            if params:
+                print(f"Filters: {params}")
+
+            df = self.pro.ths_index(**params)
+            if df.empty:
+                print("No index data returned")
+                return pd.DataFrame()
+            print(f"Fetched {len(df)} indexes")
+            return df
+        except Exception as exc:
+            print(f"Failed to fetch the index list: {exc}")
+            return pd.DataFrame()
+
+    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Convert the data returned by Tushare into the database row format
+
+        Args:
+            df: index data DataFrame
+
+        Returns:
+            the converted DataFrame
+        """
+        if df.empty:
+            return pd.DataFrame()
+
+        result = pd.DataFrame()
+        result["ts_code"] = df["ts_code"]
+        result["name"] = df["name"]
+        result["count"] = df.get("count", None)
+        result["exchange"] = df.get("exchange", None)
+        result["list_date"] = pd.to_datetime(df["list_date"], format="%Y%m%d", errors='coerce') if "list_date" in df.columns else None
+        result["type"] = df.get("type", None)
+        result["created_at"] = datetime.now()
+        result["updated_at"] = datetime.now()
+        return result
+
+    def save_dataframe(self, df: pd.DataFrame) -> None:
+        """
+        Write the data to the database
+        """
+        if df.empty:
+            return
+        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
+
+    def run_full_collection(
+        self,
+        exchange: Optional[str] = None,
+        index_type: Optional[str] = None,
+    ) -> None:
+        """
+        Run a full overwrite collection:
+        - clear the target table
+        - fetch all index data
+        - write everything to the database
+
+        Args:
+            exchange: market type, None = all
+            index_type: index type, None = all
+        """
+        print("=" * 60)
+        print("Starting full overwrite collection (THS concept and industry indexes)")
+        print("=" * 60)
+
+        try:
+            # Clear the table
+            with self.engine.begin() as conn:
+                # With filters, delete only the matching rows; otherwise truncate everything
+                if exchange or index_type:
+                    delete_conditions = []
+                    params = {}
+                    if exchange:
+                        delete_conditions.append("exchange = :exchange")
+                        params["exchange"] = exchange
+                    if index_type:
+                        delete_conditions.append("type = :type")
+                        params["type"] = index_type
+                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
+                    conn.execute(delete_sql, params)
+                    print(f"{self.table_name}: old rows matching the filters deleted")
+                else:
+                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
+                    print(f"{self.table_name} truncated")
+
+            # Fetch the index list
+            index_df = self.fetch_index_list(exchange=exchange, index_type=index_type)
+            if index_df.empty:
+                print("No index data fetched, aborting collection")
+                return
+
+            # Convert the data
+            result_df = self.transform_data(index_df)
+            if result_df.empty:
+                print("Data conversion failed")
+                return
+
+            # Save the data
+            self.save_dataframe(result_df)
+            print(f"Wrote {len(result_df)} records")
+
+            # Statistics by type
+            if "type" in result_df.columns:
+                type_stats = result_df.groupby("type").size()
+                print("\nBy type:")
+                for idx, count in type_stats.items():
+                    print(f"  {idx}: {count} indexes")
+
+            print("\n" + "=" * 60)
+            print("Full overwrite collection finished")
+            print("=" * 60)
+        except Exception as exc:
+            print(f"Full collection failed: {exc}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            self.engine.dispose()
+
+
+def collect_ths_index(
+    db_url: str,
+    tushare_token: str,
+    exchange: Optional[str] = None,
+    index_type: Optional[str] = None,
+):
+    """
+    Collection entry point - full overwrite
+
+    Args:
+        db_url: database connection URL
+        tushare_token: Tushare token
+        exchange: market type, A = A-shares, HK = Hong Kong, US = United States; None = all
+        index_type: index type, N = concept, I = industry, R = region, S = THS special,
+                    ST = THS style, TH = THS theme, BB = THS broad-based; None = all
+    """
+    collector = THSIndexCollector(db_url, tushare_token)
+    collector.run_full_collection(exchange=exchange, index_type=index_type)
+
+
+if __name__ == "__main__":
+    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
+    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
+    TOKEN = TUSHARE_TOKEN
+
+    # Run the full overwrite collection (all index types)
+    collect_ths_index(DB_URL, TOKEN)
+
+    # To fetch only a specific subset, call e.g.:
+    # collect_ths_index(DB_URL, TOKEN, exchange='A', index_type='N')  # A-share concept indexes only
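Because ths_index stores the board list and ths_concept_member stores memberships keyed by the same code, the two tables join naturally. A read-side sketch, assuming both collectors have populated their default tables:

import pandas as pd
from sqlalchemy import create_engine

def load_concept_members(db_url: str) -> pd.DataFrame:
    """List the members of every A-share concept index (type 'N')."""
    engine = create_engine(db_url)
    query = """
        SELECT i.ts_code AS concept_code, i.name AS concept_name,
               m.stock_symbol, m.stock_name
        FROM ths_index i
        JOIN ths_concept_member m ON m.concept_code = i.ts_code
        WHERE i.type = 'N'
        ORDER BY i.ts_code, m.stock_symbol
    """
    try:
        return pd.read_sql(query, engine)
    finally:
        engine.dispose()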
diff --git a/src/tushare_scripts/trade_cal_collector.py b/src/tushare_scripts/trade_cal_collector.py
new file mode 100644
index 0000000..30f6fd3
--- /dev/null
+++ b/src/tushare_scripts/trade_cal_collector.py
@@ -0,0 +1,236 @@
+# coding:utf-8
+"""
+Trading calendar collector
+Purpose: fetch trading calendar data from Tushare and persist it
+API docs: https://tushare.pro/document/2?doc_id=26
+Note: every run does a full overwrite; no daily incremental updates are needed
+"""
+
+import os
+import sys
+from datetime import datetime
+from typing import Optional
+
+import pandas as pd
+import tushare as ts
+from sqlalchemy import create_engine, text
+
+# Add the project root to the path so the config can be imported
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(PROJECT_ROOT)
+
+from src.scripts.config import TUSHARE_TOKEN
+
+
+class TradeCalCollector:
+    """Trading calendar collector"""
+
+    def __init__(self, db_url: str, tushare_token: str, table_name: str = "trade_cal"):
+        """
+        Args:
+            db_url: database connection URL
+            tushare_token: Tushare token
+            table_name: target table name, default trade_cal
+        """
+        self.engine = create_engine(
+            db_url,
+            pool_size=5,
+            max_overflow=10,
+            pool_recycle=3600,
+        )
+        self.table_name = table_name
+
+        ts.set_token(tushare_token)
+        self.pro = ts.pro_api()
+
+        print("=" * 60)
+        print("Trading calendar collector")
+        print(f"Target table: {self.table_name}")
+        print("=" * 60)
+
+    def fetch_trade_cal(
+        self,
+        exchange: Optional[str] = None,
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+    ) -> pd.DataFrame:
+        """
+        Fetch trading calendar data
+        API: trade_cal
+        Docs: https://tushare.pro/document/2?doc_id=26
+
+        Args:
+            exchange: exchange code, SSE = Shanghai, SZSE = Shenzhen, BSE = Beijing; None = all
+            start_date: start date in YYYYMMDD format; None = no lower bound
+            end_date: end date in YYYYMMDD format; None = no upper bound
+
+        Returns:
+            trading calendar DataFrame
+        """
+        try:
+            params = {}
+            if exchange:
+                params['exchange'] = exchange
+            if start_date:
+                params['start_date'] = start_date
+            if end_date:
+                params['end_date'] = end_date
+
+            print("Fetching trading calendar data...")
+            if params:
+                print(f"Filters: {params}")
+
+            df = self.pro.trade_cal(**params)
+            if df.empty:
+                print("No trading calendar data returned")
+                return pd.DataFrame()
+            print(f"Fetched {len(df)} trading calendar records")
+            return df
+        except Exception as exc:
+            print(f"Failed to fetch trading calendar data: {exc}")
+            return pd.DataFrame()
+    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Convert the data returned by Tushare into the database row format
+
+        Args:
+            df: trading calendar DataFrame
+
+        Returns:
+            the converted DataFrame
+        """
+        if df.empty:
+            return pd.DataFrame()
+
+        result = pd.DataFrame()
+        result["exchange"] = df["exchange"]
+        result["cal_date"] = pd.to_datetime(df["cal_date"], format="%Y%m%d", errors='coerce')
+        result["is_open"] = df.get("is_open", None)
+        result["pretrade_date"] = pd.to_datetime(df["pretrade_date"], format="%Y%m%d", errors='coerce') if "pretrade_date" in df.columns else None
+        result["created_at"] = datetime.now()
+        result["updated_at"] = datetime.now()
+        return result
+
+    def save_dataframe(self, df: pd.DataFrame) -> None:
+        """
+        Write the data to the database
+        """
+        if df.empty:
+            return
+        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
+
+    def run_full_collection(
+        self,
+        exchange: Optional[str] = None,
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+    ) -> None:
+        """
+        Run a full overwrite collection:
+        - clear the target table
+        - fetch the trading calendar data
+        - write everything to the database
+
+        Args:
+            exchange: exchange code, None = all
+            start_date: start date in YYYYMMDD format; None = no lower bound
+            end_date: end date in YYYYMMDD format; None = no upper bound
+        """
+        print("=" * 60)
+        print("Starting full overwrite collection (trading calendar)")
+        print("=" * 60)
+
+        try:
+            # Clear the table
+            with self.engine.begin() as conn:
+                # With filters, delete only the matching rows; otherwise truncate everything
+                if exchange or start_date or end_date:
+                    delete_conditions = []
+                    params = {}
+                    if exchange:
+                        delete_conditions.append("exchange = :exchange")
+                        params["exchange"] = exchange
+                    if start_date:
+                        delete_conditions.append("cal_date >= :start_date")
+                        params["start_date"] = datetime.strptime(start_date, "%Y%m%d").date()
+                    if end_date:
+                        delete_conditions.append("cal_date <= :end_date")
+                        params["end_date"] = datetime.strptime(end_date, "%Y%m%d").date()
+                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
+                    conn.execute(delete_sql, params)
+                    print(f"{self.table_name}: old rows matching the filters deleted")
+                else:
+                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
+                    print(f"{self.table_name} truncated")
+
+            # Fetch the trading calendar data
+            cal_df = self.fetch_trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
+            if cal_df.empty:
+                print("No trading calendar data fetched, aborting collection")
+                return
+
+            # Convert the data
+            result_df = self.transform_data(cal_df)
+            if result_df.empty:
+                print("Data conversion failed")
+                return
+
+            # Save the data
+            self.save_dataframe(result_df)
+            print(f"Wrote {len(result_df)} records")
+
+            # Statistics by exchange
+            if "exchange" in result_df.columns:
+                print("\nBy exchange:")
+                for exchange_name in result_df["exchange"].unique():
+                    exchange_data = result_df[result_df["exchange"] == exchange_name]
+                    total_days = len(exchange_data)
+                    trade_days = int(exchange_data["is_open"].sum()) if "is_open" in exchange_data.columns and exchange_data["is_open"].notna().any() else 0
+                    min_date = exchange_data["cal_date"].min()
+                    max_date = exchange_data["cal_date"].max()
+                    print(f"  {exchange_name}: total days={total_days}, trading days={trade_days}, date range={min_date} ~ {max_date}")
+
+            print("\n" + "=" * 60)
+            print("Full overwrite collection finished")
+            print("=" * 60)
+        except Exception as exc:
+            print(f"Full collection failed: {exc}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            self.engine.dispose()
+
+
+def collect_trade_cal(
+    db_url: str,
+    tushare_token: str,
+    exchange: Optional[str] = None,
+    start_date: Optional[str] = None,
+    end_date: Optional[str] = None,
+):
+    """
+    Collection entry point - full overwrite
+
+    Args:
+        db_url: database connection URL
+        tushare_token: Tushare token
+        exchange: exchange code, SSE = Shanghai, SZSE = Shenzhen, BSE = Beijing; None = all
+        start_date: start date in YYYYMMDD format; None = no lower bound
+        end_date: end date in YYYYMMDD format; None = no upper bound
+    """
+    collector = TradeCalCollector(db_url, tushare_token)
+    collector.run_full_collection(exchange=exchange, start_date=start_date, end_date=end_date)
+
+
+if __name__ == "__main__":
+    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
+    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
+    TOKEN = TUSHARE_TOKEN
+
+    # Run the full overwrite collection (all exchanges, default date range)
+    collect_trade_cal(DB_URL, TOKEN)
+
+    # To fetch only one exchange, call e.g.:
+    # collect_trade_cal(DB_URL, TOKEN, exchange='SSE')
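A natural consumer of this table is the batch endpoint added in app.py, which currently probes gp_day_data one calendar day at a time. Below is a sketch of a helper that reads trading days from trade_cal instead; the column names follow transform_data() above, and this is a possible follow-up rather than part of the commit:

from typing import List
from sqlalchemy import create_engine, text

def get_trading_days(db_url: str, start_date: str, end_date: str) -> List[str]:
    """Return the SSE trading days between start_date and end_date (inclusive), as 'YYYY-MM-DD'."""
    engine = create_engine(db_url)
    query = text("""
        SELECT cal_date
        FROM trade_cal
        WHERE exchange = 'SSE'
          AND is_open = 1
          AND cal_date BETWEEN :start_date AND :end_date
        ORDER BY cal_date
    """)
    try:
        with engine.connect() as conn:
            rows = conn.execute(query, {"start_date": start_date, "end_date": end_date})
            return [row[0].strftime('%Y-%m-%d') for row in rows]
    finally:
        engine.dispose()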