This commit is contained in:
满脸小星星 2025-11-28 15:34:10 +08:00
parent b4c8b6d1b4
commit b3332eb920
7 changed files with 828 additions and 18 deletions

View File

@ -2,6 +2,7 @@
import sys
import os
from datetime import datetime, timedelta
from typing import List
import pandas as pd
import uuid
import json
@ -74,6 +75,9 @@ from src.tushare_scripts.chip_distribution_collector import collect_chip_distrib
from src.tushare_scripts.stock_factor_collector import collect_stock_factor
from src.tushare_scripts.stock_factor_pro_collector import collect_stock_factor_pro
# 导入科技主题基本面因子选股策略
from src.quantitative_analysis.tech_fundamental_factor_strategy_v3 import TechFundamentalFactorStrategy
# 设置日志
logging.basicConfig(
level=logging.INFO,
@ -460,7 +464,7 @@ def run_chip_distribution_collection():
# mode='full'
# )
collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
start_date='2022-01-03', end_date='2022-12-31')
start_date='2021-01-03', end_date='2021-09-30')
# collect_chip_distribution(
# db_url=db_url,
# tushare_token=TUSHARE_TOKEN,
@ -3780,6 +3784,76 @@ def analyze_stock_overlap():
}), 500
@app.route('/scheduler/techFundamentalStrategy/batch', methods=['GET', 'POST'])
def run_tech_fundamental_strategy_batch():
    """Batch-run the tech-theme fundamental factor strategy over a date range.

    Accepts ``start_date`` / ``end_date`` (``YYYY-MM-DD``) either as query
    parameters or in a JSON body.  For every calendar day in the range that
    has rows in ``gp_day_data`` the strategy is run (it persists its own
    results); days without quote data are skipped.

    Returns:
        JSON with success / failed / skipped counters; HTTP 400 on missing
        or malformed parameters, HTTP 500 on unexpected failure.
    """
    try:
        # Read the JSON body at most once.  silent=True yields None instead of
        # raising on a missing/invalid body; the original called get_json()
        # twice and crashed with AttributeError when it returned None.
        payload = (request.get_json(silent=True) or {}) if request.is_json else {}
        start_date = request.args.get('start_date') or payload.get('start_date')
        end_date = request.args.get('end_date') or payload.get('end_date')
        if not start_date or not end_date:
            return jsonify({"status": "error", "message": "缺少参数: start_date 和 end_date"}), 400
        # Validate the date format explicitly so a malformed date yields a 400
        # instead of falling through to the generic 500 handler below.
        try:
            start_dt = datetime.strptime(start_date, '%Y-%m-%d')
            end_dt = datetime.strptime(end_date, '%Y-%m-%d')
        except ValueError:
            return jsonify({"status": "error", "message": "日期格式错误, 应为 YYYY-MM-DD"}), 400
        # Database connection (only used to check whether a day has data).
        from sqlalchemy import create_engine
        db_url = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
        engine = create_engine(db_url)
        success_count = 0
        failed_count = 0
        skipped_count = 0
        try:
            current_date = start_dt
            while current_date <= end_dt:
                trade_date = current_date.strftime('%Y-%m-%d')
                trade_datetime = f"{trade_date} 00:00:00"
                # Only run the strategy for days that actually have quote data.
                check_query = text("""
                    SELECT COUNT(0) as cnt
                    FROM gp_day_data
                    WHERE `timestamp` = :trade_datetime
                """)
                with engine.connect() as conn:
                    row = conn.execute(check_query, {"trade_datetime": trade_datetime}).fetchone()
                count = row[0] if row else 0
                if count > 0:
                    try:
                        strategy = TechFundamentalFactorStrategy(target_date=trade_date)
                        strategy.run_strategy()
                        strategy.close_connections()
                        success_count += 1
                    except Exception as e:
                        failed_count += 1
                        logger.error(f"日期 {trade_date} 处理失败: {str(e)}")
                else:
                    skipped_count += 1
                current_date += timedelta(days=1)
        finally:
            # Always release the connection pool, even on unexpected errors
            # (the original leaked it when an exception escaped the loop).
            engine.dispose()
        return jsonify({
            "status": "success",
            "success_count": success_count,
            "failed_count": failed_count,
            "skipped_count": skipped_count
        })
    except Exception as e:
        logger.error(f"批量运行失败: {str(e)}")
        return jsonify({"status": "error", "message": str(e)}), 500
if __name__ == '__main__':
# 启动Web服务器

View File

@ -55,16 +55,32 @@ class AverageDistanceFactor:
print(f"获取股票列表失败: {e}")
return []
def get_stock_data(self, symbols, days=20):
"""获取股票的历史数据"""
def get_stock_data(self, symbols, days=20, end_date=None):
"""
获取股票的历史数据
Args:
symbols: 股票代码列表
days: 需要获取的天数
end_date: 结束日期datetime对象或字符串如果为None则使用今天
"""
if not symbols:
return pd.DataFrame()
# 计算开始日期
end_date = datetime.now()
start_date = end_date - timedelta(days=days * 2) # 多取一些数据以防节假日
# 计算结束日期(及时数据使用目标日期当天)
if end_date is None:
end_date = datetime.now()
elif isinstance(end_date, str):
end_date = datetime.strptime(end_date, '%Y-%m-%d')
# 构建SQL查询
# 及时数据(日线数据)应该包含目标日期当天的数据
# 例如如果目标日期是2025-11-01则使用2025-11-01及之前的数据
query_end_date = end_date
# 计算开始日期(多取一些数据以防节假日)
start_date = query_end_date - timedelta(days=days * 2)
# 构建SQL查询使用 <= 包含目标日期当天)
symbols_str = "', '".join(symbols)
query = f"""
SELECT symbol, timestamp, volume, open, high, low, close,
@ -72,28 +88,36 @@ class AverageDistanceFactor:
FROM gp_day_data
WHERE symbol IN ('{symbols_str}')
AND timestamp >= '{start_date.strftime('%Y-%m-%d')}'
AND timestamp <= '{end_date.strftime('%Y-%m-%d')}'
ORDER BY symbol, timestamp DESC
"""
try:
df = pd.read_sql(query, self.engine)
print(f"获取到 {len(df)} 条历史数据")
return df
except Exception as e:
print(f"获取历史数据失败: {e}")
return pd.DataFrame()
def calculate_technical_indicators(self, df, days=20):
"""计算技术指标"""
"""
计算技术指标
Args:
df: 股票历史数据DataFrame应已按目标日期筛选
days: 计算指标使用的天数取目标日期前N天
"""
result_data = []
for symbol in df['symbol'].unique():
stock_data = df[df['symbol'] == symbol].copy()
# 按时间升序排序,确保时间顺序正确
stock_data = stock_data.sort_values('timestamp')
# 只取最近N天的数据
# 只取目标日期前N个交易日的数据包括目标日期当天
# tail取最后N行即最接近目标日期的N个交易日包括目标日期当天
stock_data = stock_data.tail(days)
# 检查数据是否足够需要至少有days个交易日的数据
if len(stock_data) < days:
continue # 数据不足,跳过

View File

@ -8,7 +8,11 @@ import sys
import os
# 添加项目根目录到路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# __file__ 是当前文件路径,例如: /app/src/quantitative_analysis/company_lifecycle_factor.py
# 需要获取项目根目录: /app
current_file_dir = os.path.dirname(os.path.abspath(__file__)) # /app/src/quantitative_analysis
src_dir = os.path.dirname(current_file_dir) # /app/src
project_root = os.path.dirname(src_dir) # /app (项目根目录)
sys.path.append(project_root)
# 导入配置
@ -16,7 +20,7 @@ try:
from valuation_analysis.config import MONGO_CONFIG2
except ImportError:
import importlib.util
config_path = os.path.join(project_root, 'valuation_analysis', 'config.py')
config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
spec = importlib.util.spec_from_file_location("config", config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
@ -27,7 +31,12 @@ try:
from tools.stock_code_formatter import StockCodeFormatter
except ImportError:
import importlib.util
formatter_path = os.path.join(os.path.dirname(project_root), 'tools', 'stock_code_formatter.py')
# project_root 已经是项目根目录,直接拼接 tools 目录
formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
if not os.path.exists(formatter_path):
raise ImportError(f"无法找到 stock_code_formatter.py路径: {formatter_path}")
spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
formatter_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(formatter_module)
@ -120,7 +129,7 @@ class CompanyLifecycleFactor:
annual_data = self.collection.find_one(query)
if annual_data:
logger.info(f"找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
# logger.info(f"找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
return annual_data
else:
logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")

View File

@ -464,11 +464,11 @@ if __name__ == "__main__":
# collect_chip_distribution(db_url, tushare_token, mode='full')
# 3. 采集指定日期的数据
collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-24')
# collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-25')
# 4. 采集指定日期范围的数据
# collect_chip_distribution(db_url, tushare_token, mode='full',
# start_date='2021-11-01', end_date='2021-11-30')
collect_chip_distribution(db_url, tushare_token, mode='full',
start_date='2021-01-03', end_date='2021-09-30')
# 5. 调整批量入库大小默认100只股票一批
# collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200)

View File

@ -0,0 +1,242 @@
# coding:utf-8
"""
同花顺概念板块成分股采集工具
功能 Tushare 获取同花顺概念板块成分股数据并落库
API 文档: https://tushare.pro/document/2?doc_id=261
说明每次调用时全量覆盖不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class THSConceptMemberCollector:
    """Collector for Tonghuashun (THS) concept-board constituent stocks.

    Fetches every concept board via the Tushare ``ths_index`` API, then each
    board's members via ``ths_member``, and rewrites the target table from
    scratch (full overwrite; no incremental daily updates).
    API doc: https://tushare.pro/document/2?doc_id=261
    """

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_concept_member"):
        """
        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, defaults to ``ths_concept_member``.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺概念板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (e.g. ``600000.SH``) to DB form (``SH600000``).

        Inputs without a ``.`` separator are returned unchanged.  Splitting on
        the LAST dot keeps this from raising on malformed codes containing
        several dots (the previous two-way unpack of ``split(".")`` crashed).
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        base, market = ts_code.rsplit(".", 1)
        return f"{market}{base}"

    def fetch_concept_index_list(self) -> pd.DataFrame:
        """Fetch the full list of THS concept boards.

        API: ths_index — https://tushare.pro/document/2?doc_id=259

        Returns:
            Concept-board list DataFrame (empty on error or no data).
        """
        try:
            print("正在获取同花顺概念板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='N',  # N = concept boards
            )
            if df.empty:
                print("未获取到概念板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个概念板块")
            return df
        except Exception as exc:
            print(f"获取概念板块列表失败: {exc}")
            return pd.DataFrame()

    def fetch_concept_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the member stocks of one concept board.

        API: ths_member — https://tushare.pro/document/2?doc_id=261

        Args:
            ts_code: Concept-board code, e.g. ``885800.TI``.
        Returns:
            Member-stock DataFrame (empty on error).
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
            return df
        except Exception as exc:
            print(f"获取概念板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, concept_name: str = "") -> pd.DataFrame:
        """Convert the raw Tushare member frame into the DB insert layout.

        Args:
            member_df: Member rows containing ``ts_code``, ``con_code``, ``con_name``.
            concept_name: Board name (taken from the board-list response).
        Returns:
            Transformed DataFrame (empty when input is empty or malformed).
        """
        if member_df.empty:
            return pd.DataFrame()
        # Guard against API schema drift before indexing columns.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到概念板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        result = pd.DataFrame()
        # ts_code here is the concept-board code, con_code the member stock.
        result["concept_code"] = member_df["ts_code"]
        result["concept_name"] = concept_name
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        # is_new is returned by the API; weight/in_date/out_date currently
        # carry no data and are intentionally not stored.
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        # One timestamp for the whole batch so created_at == updated_at
        # (previously two datetime.now() calls produced slightly different values).
        now = datetime.now()
        result["created_at"] = now
        result["updated_at"] = now
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append the frame to the target table; no-op when empty."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(self) -> None:
        """Run a full-overwrite collection.

        - Truncates the target table.
        - Fetches the full concept-board list.
        - Fetches and writes each board's members.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺概念板块成分股)")
        print("=" * 60)
        try:
            # Clear the table first, inside one transaction.
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                print(f"表 {self.table_name} 已清空")
            concept_list_df = self.fetch_concept_index_list()
            if concept_list_df.empty:
                print("未获取到概念板块列表,采集终止")
                return
            total_records = 0
            success_count = 0
            failed_count = 0
            # enumerate gives a true 1-based position for the progress display;
            # the iterrows index is a label and need not be sequential.
            for pos, (_, row) in enumerate(concept_list_df.iterrows(), start=1):
                concept_code = row["ts_code"]
                concept_name = row["name"]
                print(f"\n[{pos}/{len(concept_list_df)}] 正在采集: {concept_name} ({concept_code})")
                try:
                    member_df = self.fetch_concept_members(concept_code)
                    if member_df.empty:
                        print(f" 概念板块 {concept_name} 无成分股数据")
                        failed_count += 1
                        continue
                    result_df = self.transform_data(member_df, concept_name)
                    if result_df.empty:
                        print(f" 概念板块 {concept_name} 数据转换失败")
                        failed_count += 1
                        continue
                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f" 成功采集 {len(result_df)} 只成分股")
                except Exception as exc:
                    # Keep going: one broken board must not abort the batch.
                    print(f" 采集 {concept_name} ({concept_code}) 失败: {exc}")
                    failed_count += 1
                    continue
            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print(f"总概念板块数: {len(concept_list_df)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            # Always release the connection pool.
            self.engine.dispose()
def collect_ths_concept_member(
    db_url: str,
    tushare_token: str,
):
    """Entry point: run a full-overwrite collection of THS concept members.

    Args:
        db_url: SQLAlchemy database connection URL.
        tushare_token: Tushare API token.
    """
    THSConceptMemberCollector(db_url, tushare_token).run_full_collection()
if __name__ == "__main__":
    # Legacy/alternate database target, kept for reference:
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN
    # Run the full-overwrite collection of concept-board members.
    collect_ths_concept_member(DB_URL, TOKEN)

View File

@ -0,0 +1,225 @@
# coding:utf-8
"""
同花顺概念和行业指数采集工具
功能 Tushare 获取同花顺概念和行业指数数据并落库
API 文档: https://tushare.pro/document/2?doc_id=259
说明每次调用时全量覆盖不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class THSIndexCollector:
    """Collector for Tonghuashun (THS) concept and industry index metadata.

    Pulls the index list from Tushare and rewrites the target table (full
    overwrite, optionally restricted to one market / index type).
    API doc: https://tushare.pro/document/2?doc_id=259
    """
    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_index"):
        """
        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, defaults to ``ths_index``.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺概念和行业指数采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)
    def fetch_index_list(
        self,
        exchange: Optional[str] = None,
        index_type: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch the THS index list.

        API: ths_index — https://tushare.pro/document/2?doc_id=259

        Args:
            exchange: Market type: A (A-shares), HK, US; None = all.
            index_type: Index type: N (concept), I (industry), R (region),
                S (THS special), ST (THS style), TH (THS theme),
                BB (THS broad-based); None = all.
        Returns:
            Index-list DataFrame (empty on error or no data).
        """
        try:
            # Only pass filters that were actually supplied.
            params = {}
            if exchange:
                params['exchange'] = exchange
            if index_type:
                params['type'] = index_type
            print(f"正在获取同花顺指数列表...")
            if params:
                print(f"筛选条件: {params}")
            df = self.pro.ths_index(**params)
            if df.empty:
                print("未获取到指数数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个指数")
            return df
        except Exception as exc:
            print(f"获取指数列表失败: {exc}")
            return pd.DataFrame()
    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert the raw Tushare frame into the DB insert layout.

        Args:
            df: Raw index-info DataFrame.
        Returns:
            Transformed DataFrame (empty when input is empty).
        """
        if df.empty:
            return pd.DataFrame()
        result = pd.DataFrame()
        result["ts_code"] = df["ts_code"]
        result["name"] = df["name"]
        # DataFrame.get returns the column when present, else fills with None.
        result["count"] = df.get("count", None)
        result["exchange"] = df.get("exchange", None)
        # list_date arrives as YYYYMMDD strings; unparsable values become NaT.
        result["list_date"] = pd.to_datetime(df["list_date"], format="%Y%m%d", errors='coerce') if "list_date" in df.columns else None
        result["type"] = df.get("type", None)
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result
    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append the frame to the target table; no-op when empty."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
    def run_full_collection(
        self,
        exchange: Optional[str] = None,
        index_type: Optional[str] = None,
    ) -> None:
        """Run a full-overwrite collection.

        - Clears the target table (or only the filtered slice).
        - Fetches all matching index rows.
        - Writes everything back in one batch.

        Args:
            exchange: Market filter; None = all.
            index_type: Index-type filter; None = all.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺概念和行业指数)")
        print("=" * 60)
        try:
            # Clear first, inside one transaction, so a refresh never deletes
            # more than the slice it is about to rewrite.
            with self.engine.begin() as conn:
                # With filters: delete only matching rows; otherwise truncate all.
                if exchange or index_type:
                    delete_conditions = []
                    params = {}
                    if exchange:
                        delete_conditions.append("exchange = :exchange")
                        params["exchange"] = exchange
                    if index_type:
                        delete_conditions.append("type = :type")
                        params["type"] = index_type
                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
                    conn.execute(delete_sql, params)
                    print(f"表 {self.table_name} 已删除符合条件的旧数据")
                else:
                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                    print(f"表 {self.table_name} 已清空")
            # Fetch the index list.
            index_df = self.fetch_index_list(exchange=exchange, index_type=index_type)
            if index_df.empty:
                print("未获取到指数数据,采集终止")
                return
            # Convert to DB layout.
            result_df = self.transform_data(index_df)
            if result_df.empty:
                print("数据转换失败")
                return
            # Persist.
            self.save_dataframe(result_df)
            print(f"成功写入 {len(result_df)} 条记录")
            # Per-type summary for the operator.
            if "type" in result_df.columns:
                type_stats = result_df.groupby("type").size()
                print("\n按类型统计:")
                for idx, count in type_stats.items():
                    print(f" {idx}: {count}")
            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            # Always release the connection pool.
            self.engine.dispose()
def collect_ths_index(
    db_url: str,
    tushare_token: str,
    exchange: Optional[str] = None,
    index_type: Optional[str] = None,
):
    """Entry point: full-overwrite collection of THS concept/industry indexes.

    Args:
        db_url: SQLAlchemy database connection URL.
        tushare_token: Tushare API token.
        exchange: Market type: A (A-shares), HK, US; None = all markets.
        index_type: Index type: N (concept), I (industry), R (region),
            S (THS special), ST (THS style), TH (THS theme),
            BB (THS broad-based); None = all types.
    """
    THSIndexCollector(db_url, tushare_token).run_full_collection(
        exchange=exchange,
        index_type=index_type,
    )
if __name__ == "__main__":
    # Legacy/alternate database target, kept for reference:
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN
    # Full-overwrite collection (all index types).
    collect_ths_index(DB_URL, TOKEN)
    # To restrict to a specific slice, e.g. A-share concept indexes only:
    # collect_ths_index(DB_URL, TOKEN, exchange='A', index_type='N')

View File

@ -0,0 +1,236 @@
# coding:utf-8
"""
交易日历采集工具
功能 Tushare 获取交易日历数据并落库
API 文档: https://tushare.pro/document/2?doc_id=26
说明每次调用时全量覆盖不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class TradeCalCollector:
    """Collector for the exchange trading calendar.

    Pulls calendar rows from Tushare and rewrites the target table (full
    overwrite, optionally restricted by exchange and/or date range).
    API doc: https://tushare.pro/document/2?doc_id=26
    """
    def __init__(self, db_url: str, tushare_token: str, table_name: str = "trade_cal"):
        """
        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name, defaults to ``trade_cal``.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("交易日历采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)
    def fetch_trade_cal(
        self,
        exchange: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch trading-calendar rows.

        API: trade_cal — https://tushare.pro/document/2?doc_id=26

        Args:
            exchange: Exchange code: SSE (Shanghai), SZSE (Shenzhen),
                BSE (Beijing); None = all.
            start_date: Start date, YYYYMMDD; None = unbounded.
            end_date: End date, YYYYMMDD; None = unbounded.
        Returns:
            Calendar DataFrame (empty on error or no data).
        """
        try:
            # Only pass filters that were actually supplied.
            params = {}
            if exchange:
                params['exchange'] = exchange
            if start_date:
                params['start_date'] = start_date
            if end_date:
                params['end_date'] = end_date
            print("正在获取交易日历数据...")
            if params:
                print(f"筛选条件: {params}")
            df = self.pro.trade_cal(**params)
            if df.empty:
                print("未获取到交易日历数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 条交易日历记录")
            return df
        except Exception as exc:
            print(f"获取交易日历数据失败: {exc}")
            return pd.DataFrame()
    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert the raw Tushare frame into the DB insert layout.

        Args:
            df: Raw calendar DataFrame.
        Returns:
            Transformed DataFrame (empty when input is empty).
        """
        if df.empty:
            return pd.DataFrame()
        result = pd.DataFrame()
        result["exchange"] = df["exchange"]
        # Dates arrive as YYYYMMDD strings; unparsable values become NaT.
        result["cal_date"] = pd.to_datetime(df["cal_date"], format="%Y%m%d", errors='coerce')
        result["is_open"] = df.get("is_open", None)
        result["pretrade_date"] = pd.to_datetime(df["pretrade_date"], format="%Y%m%d", errors='coerce') if "pretrade_date" in df.columns else None
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result
    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append the frame to the target table; no-op when empty."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
    def run_full_collection(
        self,
        exchange: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """Run a full-overwrite collection.

        - Clears the target table (or only the filtered slice).
        - Fetches the matching calendar rows.
        - Writes everything back in one batch.

        Args:
            exchange: Exchange filter; None = all.
            start_date: Start date, YYYYMMDD; None = unbounded.
            end_date: End date, YYYYMMDD; None = unbounded.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(交易日历)")
        print("=" * 60)
        try:
            # Clear first, inside one transaction, so a refresh never deletes
            # more than the slice it is about to rewrite.
            with self.engine.begin() as conn:
                # With filters: delete only matching rows; otherwise truncate all.
                if exchange or start_date or end_date:
                    delete_conditions = []
                    params = {}
                    if exchange:
                        delete_conditions.append("exchange = :exchange")
                        params["exchange"] = exchange
                    if start_date:
                        delete_conditions.append("cal_date >= :start_date")
                        params["start_date"] = datetime.strptime(start_date, "%Y%m%d").date()
                    if end_date:
                        delete_conditions.append("cal_date <= :end_date")
                        params["end_date"] = datetime.strptime(end_date, "%Y%m%d").date()
                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
                    conn.execute(delete_sql, params)
                    print(f"表 {self.table_name} 已删除符合条件的旧数据")
                else:
                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                    print(f"表 {self.table_name} 已清空")
            # Fetch the calendar rows.
            cal_df = self.fetch_trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
            if cal_df.empty:
                print("未获取到交易日历数据,采集终止")
                return
            # Convert to DB layout.
            result_df = self.transform_data(cal_df)
            if result_df.empty:
                print("数据转换失败")
                return
            # Persist.
            self.save_dataframe(result_df)
            print(f"成功写入 {len(result_df)} 条记录")
            # Per-exchange summary for the operator.
            if "exchange" in result_df.columns:
                print("\n按交易所统计:")
                for exchange_name in result_df["exchange"].unique():
                    exchange_data = result_df[result_df["exchange"] == exchange_name]
                    total_days = len(exchange_data)
                    # is_open is 1 for trading days; guard against an all-NaN column.
                    trade_days = int(exchange_data["is_open"].sum()) if "is_open" in exchange_data.columns and exchange_data["is_open"].notna().any() else 0
                    min_date = exchange_data["cal_date"].min()
                    max_date = exchange_data["cal_date"].max()
                    print(f" {exchange_name}: 总天数={total_days}, 交易日={trade_days}, 日期范围={min_date} ~ {max_date}")
            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            # Always release the connection pool.
            self.engine.dispose()
def collect_trade_cal(
    db_url: str,
    tushare_token: str,
    exchange: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
):
    """Entry point: full-overwrite collection of the trading calendar.

    Args:
        db_url: SQLAlchemy database connection URL.
        tushare_token: Tushare API token.
        exchange: Exchange code: SSE (Shanghai), SZSE (Shenzhen),
            BSE (Beijing); None = all exchanges.
        start_date: Start date, YYYYMMDD; None = unbounded.
        end_date: End date, YYYYMMDD; None = unbounded.
    """
    TradeCalCollector(db_url, tushare_token).run_full_collection(
        exchange=exchange,
        start_date=start_date,
        end_date=end_date,
    )
if __name__ == "__main__":
    # Legacy/alternate database target, kept for reference:
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN
    # Full-overwrite collection: every exchange, default (unbounded) date range.
    collect_trade_cal(DB_URL, TOKEN)
    # To restrict to a single exchange:
    # collect_trade_cal(DB_URL, TOKEN, exchange='SSE')