This commit is contained in:
满脸小星星 2025-11-28 15:34:10 +08:00
parent b4c8b6d1b4
commit b3332eb920
7 changed files with 828 additions and 18 deletions

View File

@ -2,6 +2,7 @@
import sys import sys
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List
import pandas as pd import pandas as pd
import uuid import uuid
import json import json
@ -74,6 +75,9 @@ from src.tushare_scripts.chip_distribution_collector import collect_chip_distrib
from src.tushare_scripts.stock_factor_collector import collect_stock_factor from src.tushare_scripts.stock_factor_collector import collect_stock_factor
from src.tushare_scripts.stock_factor_pro_collector import collect_stock_factor_pro from src.tushare_scripts.stock_factor_pro_collector import collect_stock_factor_pro
# 导入科技主题基本面因子选股策略
from src.quantitative_analysis.tech_fundamental_factor_strategy_v3 import TechFundamentalFactorStrategy
# 设置日志 # 设置日志
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -460,7 +464,7 @@ def run_chip_distribution_collection():
# mode='full' # mode='full'
# ) # )
collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full', collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
start_date='2022-01-03', end_date='2022-12-31') start_date='2021-01-03', end_date='2021-09-30')
# collect_chip_distribution( # collect_chip_distribution(
# db_url=db_url, # db_url=db_url,
# tushare_token=TUSHARE_TOKEN, # tushare_token=TUSHARE_TOKEN,
@ -3780,6 +3784,76 @@ def analyze_stock_overlap():
}), 500 }), 500
@app.route('/scheduler/techFundamentalStrategy/batch', methods=['GET', 'POST'])
def run_tech_fundamental_strategy_batch():
    """Batch-run the tech-theme fundamental-factor stock-selection strategy.

    Accepts ``start_date`` / ``end_date`` (format ``YYYY-MM-DD``) either as
    query-string parameters or in a JSON body, iterates every calendar day in
    the inclusive range, and runs the strategy for each day that has rows in
    ``gp_day_data``. Days without quote data are counted as skipped.

    Returns:
        JSON summary with success/failed/skipped counts, or an error payload
        with HTTP 400 (missing/malformed parameters) / 500 (unexpected failure).
    """
    try:
        # Parse the JSON body at most once. request.get_json() may return
        # None even when is_json is true (e.g. empty body), which made the
        # original `.get(...)` chain raise AttributeError — guard with `or {}`.
        body = (request.get_json(silent=True) or {}) if request.is_json else {}
        start_date = request.args.get('start_date') or body.get('start_date')
        end_date = request.args.get('end_date') or body.get('end_date')

        if not start_date or not end_date:
            return jsonify({"status": "error", "message": "缺少参数: start_date 和 end_date"}), 400

        # Reject malformed dates with 400 instead of letting ValueError
        # surface through the generic handler as an opaque 500.
        try:
            start_dt = datetime.strptime(start_date, '%Y-%m-%d')
            end_dt = datetime.strptime(end_date, '%Y-%m-%d')
        except ValueError:
            return jsonify({"status": "error", "message": "日期格式错误, 应为 YYYY-MM-DD"}), 400

        # Dedicated engine used only to check whether a given day has trading
        # data. NOTE(review): credentials are hard-coded here — consider
        # moving them into configuration.
        from sqlalchemy import create_engine
        db_url = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
        engine = create_engine(db_url)

        success_count = 0
        failed_count = 0
        skipped_count = 0

        # Loop-invariant query: build it once instead of per iteration.
        check_query = text("""
            SELECT COUNT(0) as cnt
            FROM gp_day_data
            WHERE `timestamp` = :trade_datetime
        """)

        try:
            current_date = start_dt
            while current_date <= end_dt:
                trade_date = current_date.strftime('%Y-%m-%d')
                trade_datetime = f"{trade_date} 00:00:00"

                # Only run the strategy on days that actually have quotes.
                with engine.connect() as conn:
                    row = conn.execute(check_query, {"trade_datetime": trade_datetime}).fetchone()
                    count = row[0] if row else 0

                if count > 0:
                    try:
                        strategy = TechFundamentalFactorStrategy(target_date=trade_date)
                        strategy.run_strategy()
                        strategy.close_connections()
                        success_count += 1
                    except Exception as e:
                        # One bad day must not abort the whole batch.
                        failed_count += 1
                        logger.error(f"日期 {trade_date} 处理失败: {str(e)}")
                else:
                    skipped_count += 1

                current_date += timedelta(days=1)
        finally:
            # Always release the connection pool, even if the loop raises.
            engine.dispose()

        return jsonify({
            "status": "success",
            "success_count": success_count,
            "failed_count": failed_count,
            "skipped_count": skipped_count
        })
    except Exception as e:
        logger.error(f"批量运行失败: {str(e)}")
        return jsonify({"status": "error", "message": str(e)}), 500
if __name__ == '__main__': if __name__ == '__main__':
# 启动Web服务器 # 启动Web服务器

View File

@ -55,16 +55,32 @@ class AverageDistanceFactor:
print(f"获取股票列表失败: {e}") print(f"获取股票列表失败: {e}")
return [] return []
def get_stock_data(self, symbols, days=20): def get_stock_data(self, symbols, days=20, end_date=None):
"""获取股票的历史数据""" """
获取股票的历史数据
Args:
symbols: 股票代码列表
days: 需要获取的天数
end_date: 结束日期datetime对象或字符串如果为None则使用今天
"""
if not symbols: if not symbols:
return pd.DataFrame() return pd.DataFrame()
# 计算开始日期 # 计算结束日期(及时数据使用目标日期当天)
if end_date is None:
end_date = datetime.now() end_date = datetime.now()
start_date = end_date - timedelta(days=days * 2) # 多取一些数据以防节假日 elif isinstance(end_date, str):
end_date = datetime.strptime(end_date, '%Y-%m-%d')
# 构建SQL查询 # 及时数据(日线数据)应该包含目标日期当天的数据
# 例如如果目标日期是2025-11-01则使用2025-11-01及之前的数据
query_end_date = end_date
# 计算开始日期(多取一些数据以防节假日)
start_date = query_end_date - timedelta(days=days * 2)
# 构建SQL查询使用 <= 包含目标日期当天)
symbols_str = "', '".join(symbols) symbols_str = "', '".join(symbols)
query = f""" query = f"""
SELECT symbol, timestamp, volume, open, high, low, close, SELECT symbol, timestamp, volume, open, high, low, close,
@ -72,28 +88,36 @@ class AverageDistanceFactor:
FROM gp_day_data FROM gp_day_data
WHERE symbol IN ('{symbols_str}') WHERE symbol IN ('{symbols_str}')
AND timestamp >= '{start_date.strftime('%Y-%m-%d')}' AND timestamp >= '{start_date.strftime('%Y-%m-%d')}'
AND timestamp <= '{end_date.strftime('%Y-%m-%d')}'
ORDER BY symbol, timestamp DESC ORDER BY symbol, timestamp DESC
""" """
try: try:
df = pd.read_sql(query, self.engine) df = pd.read_sql(query, self.engine)
print(f"获取到 {len(df)} 条历史数据")
return df return df
except Exception as e: except Exception as e:
print(f"获取历史数据失败: {e}")
return pd.DataFrame() return pd.DataFrame()
def calculate_technical_indicators(self, df, days=20): def calculate_technical_indicators(self, df, days=20):
"""计算技术指标""" """
计算技术指标
Args:
df: 股票历史数据DataFrame应已按目标日期筛选
days: 计算指标使用的天数取目标日期前N天
"""
result_data = [] result_data = []
for symbol in df['symbol'].unique(): for symbol in df['symbol'].unique():
stock_data = df[df['symbol'] == symbol].copy() stock_data = df[df['symbol'] == symbol].copy()
# 按时间升序排序,确保时间顺序正确
stock_data = stock_data.sort_values('timestamp') stock_data = stock_data.sort_values('timestamp')
# 只取最近N天的数据 # 只取目标日期前N个交易日的数据包括目标日期当天
# tail取最后N行即最接近目标日期的N个交易日包括目标日期当天
stock_data = stock_data.tail(days) stock_data = stock_data.tail(days)
# 检查数据是否足够需要至少有days个交易日的数据
if len(stock_data) < days: if len(stock_data) < days:
continue # 数据不足,跳过 continue # 数据不足,跳过

View File

@ -8,7 +8,11 @@ import sys
import os import os
# 添加项目根目录到路径 # 添加项目根目录到路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # __file__ 是当前文件路径,例如: /app/src/quantitative_analysis/company_lifecycle_factor.py
# 需要获取项目根目录: /app
current_file_dir = os.path.dirname(os.path.abspath(__file__)) # /app/src/quantitative_analysis
src_dir = os.path.dirname(current_file_dir) # /app/src
project_root = os.path.dirname(src_dir) # /app (项目根目录)
sys.path.append(project_root) sys.path.append(project_root)
# 导入配置 # 导入配置
@ -16,7 +20,7 @@ try:
from valuation_analysis.config import MONGO_CONFIG2 from valuation_analysis.config import MONGO_CONFIG2
except ImportError: except ImportError:
import importlib.util import importlib.util
config_path = os.path.join(project_root, 'valuation_analysis', 'config.py') config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
spec = importlib.util.spec_from_file_location("config", config_path) spec = importlib.util.spec_from_file_location("config", config_path)
config_module = importlib.util.module_from_spec(spec) config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module) spec.loader.exec_module(config_module)
@ -27,7 +31,12 @@ try:
from tools.stock_code_formatter import StockCodeFormatter from tools.stock_code_formatter import StockCodeFormatter
except ImportError: except ImportError:
import importlib.util import importlib.util
formatter_path = os.path.join(os.path.dirname(project_root), 'tools', 'stock_code_formatter.py') # project_root 已经是项目根目录,直接拼接 tools 目录
formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
if not os.path.exists(formatter_path):
raise ImportError(f"无法找到 stock_code_formatter.py路径: {formatter_path}")
spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path) spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
formatter_module = importlib.util.module_from_spec(spec) formatter_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(formatter_module) spec.loader.exec_module(formatter_module)
@ -120,7 +129,7 @@ class CompanyLifecycleFactor:
annual_data = self.collection.find_one(query) annual_data = self.collection.find_one(query)
if annual_data: if annual_data:
logger.info(f"找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}") # logger.info(f"找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
return annual_data return annual_data
else: else:
logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}") logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")

View File

@ -464,11 +464,11 @@ if __name__ == "__main__":
# collect_chip_distribution(db_url, tushare_token, mode='full') # collect_chip_distribution(db_url, tushare_token, mode='full')
# 3. 采集指定日期的数据 # 3. 采集指定日期的数据
collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-24') # collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-25')
# 4. 采集指定日期范围的数据 # 4. 采集指定日期范围的数据
# collect_chip_distribution(db_url, tushare_token, mode='full', collect_chip_distribution(db_url, tushare_token, mode='full',
# start_date='2021-11-01', end_date='2021-11-30') start_date='2021-01-03', end_date='2021-09-30')
# 5. 调整批量入库大小默认100只股票一批 # 5. 调整批量入库大小默认100只股票一批
# collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200) # collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200)

View File

@ -0,0 +1,242 @@
# coding:utf-8
"""
同花顺概念板块成分股采集工具
功能 Tushare 获取同花顺概念板块成分股数据并落库
API 文档: https://tushare.pro/document/2?doc_id=261
说明每次调用时全量覆盖不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class THSConceptMemberCollector:
    """Collector for Tonghuashun (THS) concept-board constituent stocks.

    Fetches every concept board via the Tushare ``ths_index`` API, then pulls
    each board's member stocks via ``ths_member`` and writes them to a MySQL
    table using a full truncate-and-reload strategy (no daily incremental
    updates).
    """
    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_concept_member"):
        """
        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name (defaults to ``ths_concept_member``).
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,  # recycle connections hourly to avoid MySQL idle timeouts
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺概念板块成分股采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)
    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """Convert a Tushare code (``600000.SH``) to the DB form (``SH600000``).

        Codes without a ``.`` separator are returned unchanged.
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        base, market = ts_code.split(".")
        return f"{market}{base}"
    def fetch_concept_index_list(self) -> pd.DataFrame:
        """Fetch the list of all THS concept boards.

        API: ths_index — https://tushare.pro/document/2?doc_id=259

        Returns:
            DataFrame of concept boards; empty DataFrame on failure.
        """
        try:
            print("正在获取同花顺概念板块列表...")
            df = self.pro.ths_index(
                exchange='A',
                type='N',  # N = concept boards
            )
            if df.empty:
                print("未获取到概念板块数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个概念板块")
            return df
        except Exception as exc:
            print(f"获取概念板块列表失败: {exc}")
            return pd.DataFrame()
    def fetch_concept_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the constituent stocks of one concept board.

        API: ths_member — https://tushare.pro/document/2?doc_id=261

        Args:
            ts_code: Concept-board code, e.g. ``'885800.TI'``.

        Returns:
            DataFrame of member stocks; empty DataFrame on failure.
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
            return df
        except Exception as exc:
            print(f"获取概念板块 {ts_code} 成分股失败: {exc}")
            return pd.DataFrame()
    def transform_data(self, member_df: pd.DataFrame, concept_name: str = "") -> pd.DataFrame:
        """Reshape the Tushare member payload into the DB table layout.

        Args:
            member_df: Member rows containing ``ts_code``, ``con_code``, ``con_name``.
            concept_name: Board name, taken from the concept list (not in member rows).

        Returns:
            DataFrame ready for ``to_sql``; empty DataFrame if required columns
            are missing.
        """
        if member_df.empty:
            return pd.DataFrame()
        # Validate the columns the mapping below depends on.
        if "ts_code" not in member_df.columns:
            print(f"警告: 未找到概念板块代码字段(ts_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"警告: 未找到股票代码字段(con_code),可用字段: {member_df.columns.tolist()}")
            return pd.DataFrame()
        result = pd.DataFrame()
        # ts_code in the member payload is the CONCEPT code, not a stock code.
        result["concept_code"] = member_df["ts_code"]
        result["concept_name"] = concept_name
        # con_code is the stock code; keep both Tushare and DB formats.
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        # con_name is the stock name.
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        # is_new is returned by the API; weight/in_date/out_date currently
        # carry no data upstream, so they are intentionally not stored.
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result
    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append *df* to the target table; no-op for an empty frame."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
    def run_full_collection(self) -> None:
        """Run a full-overwrite collection.

        Steps:
            - truncate the target table
            - fetch the list of all concept boards
            - fetch each board's members and append them
        The engine is disposed in ``finally`` regardless of outcome.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺概念板块成分股)")
        print("=" * 60)
        try:
            # Truncate first: the table is rebuilt from scratch on every run.
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                print(f"表 {self.table_name} 已清空")
            # Fetch the board list; without it there is nothing to iterate.
            concept_list_df = self.fetch_concept_index_list()
            if concept_list_df.empty:
                print("未获取到概念板块列表,采集终止")
                return
            total_records = 0
            success_count = 0
            failed_count = 0
            # One API call per board; failures are counted but do not abort.
            for idx, row in concept_list_df.iterrows():
                concept_code = row["ts_code"]
                concept_name = row["name"]
                print(f"\n[{idx + 1}/{len(concept_list_df)}] 正在采集: {concept_name} ({concept_code})")
                try:
                    member_df = self.fetch_concept_members(concept_code)
                    if member_df.empty:
                        print(f"  概念板块 {concept_name} 无成分股数据")
                        failed_count += 1
                        continue
                    # Pass the board name through — member rows don't carry it.
                    result_df = self.transform_data(member_df, concept_name)
                    if result_df.empty:
                        print(f"  概念板块 {concept_name} 数据转换失败")
                        failed_count += 1
                        continue
                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f"  成功采集 {len(result_df)} 只成分股")
                except Exception as exc:
                    print(f"  采集 {concept_name} ({concept_code}) 失败: {exc}")
                    failed_count += 1
                    continue
            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print(f"总概念板块数: {len(concept_list_df)}")
            print(f"成功采集: {success_count}")
            print(f"失败: {failed_count}")
            print(f"累计成分股记录: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()
def collect_ths_concept_member(
    db_url: str,
    tushare_token: str,
):
    """Entry point: full-overwrite collection of THS concept-board members."""
    THSConceptMemberCollector(db_url, tushare_token).run_full_collection()
if __name__ == "__main__":
# DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
TOKEN = TUSHARE_TOKEN
# 执行全量覆盖采集
collect_ths_concept_member(DB_URL, TOKEN)

View File

@ -0,0 +1,225 @@
# coding:utf-8
"""
同花顺概念和行业指数采集工具
功能 Tushare 获取同花顺概念和行业指数数据并落库
API 文档: https://tushare.pro/document/2?doc_id=259
说明每次调用时全量覆盖不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class THSIndexCollector:
    """Collector for Tonghuashun (THS) concept and industry indexes.

    Fetches index metadata from the Tushare ``ths_index`` API and writes it to
    a MySQL table using a full-overwrite strategy: when filters are given only
    matching rows are deleted first, otherwise the table is truncated.
    """
    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_index"):
        """
        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name (defaults to ``ths_index``).
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,  # recycle connections hourly to avoid MySQL idle timeouts
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("同花顺概念和行业指数采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)
    def fetch_index_list(
        self,
        exchange: Optional[str] = None,
        index_type: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch the THS concept/industry index list.

        API: ths_index — https://tushare.pro/document/2?doc_id=259

        Args:
            exchange: Market filter — ``A`` (A-shares), ``HK``, ``US``; ``None`` = all.
            index_type: Index type — ``N`` concept, ``I`` industry, ``R`` region,
                ``S`` THS special, ``ST`` THS style, ``TH`` THS theme,
                ``BB`` THS broad-based; ``None`` = all.

        Returns:
            DataFrame of indexes; empty DataFrame on failure.
        """
        try:
            # Only pass filters that were actually supplied.
            params = {}
            if exchange:
                params['exchange'] = exchange
            if index_type:
                params['type'] = index_type
            print(f"正在获取同花顺指数列表...")
            if params:
                print(f"筛选条件: {params}")
            df = self.pro.ths_index(**params)
            if df.empty:
                print("未获取到指数数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 个指数")
            return df
        except Exception as exc:
            print(f"获取指数列表失败: {exc}")
            return pd.DataFrame()
    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape the Tushare index payload into the DB table layout.

        Args:
            df: Index metadata DataFrame as returned by ``ths_index``.

        Returns:
            DataFrame ready for ``to_sql``; empty DataFrame for empty input.
        """
        if df.empty:
            return pd.DataFrame()
        result = pd.DataFrame()
        result["ts_code"] = df["ts_code"]
        result["name"] = df["name"]
        # df.get(...) yields None (a column of NULLs) when the field is absent.
        result["count"] = df.get("count", None)
        result["exchange"] = df.get("exchange", None)
        # list_date arrives as YYYYMMDD strings; coerce bad values to NaT.
        result["list_date"] = pd.to_datetime(df["list_date"], format="%Y%m%d", errors='coerce') if "list_date" in df.columns else None
        result["type"] = df.get("type", None)
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result
    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append *df* to the target table; no-op for an empty frame."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
    def run_full_collection(
        self,
        exchange: Optional[str] = None,
        index_type: Optional[str] = None,
    ) -> None:
        """Run a full-overwrite collection.

        Steps:
            - clear the target table (scoped DELETE when filters are given,
              otherwise TRUNCATE)
            - fetch the index list, transform, and append it

        Args:
            exchange: Market filter; ``None`` = all.
            index_type: Index-type filter; ``None`` = all.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(同花顺概念和行业指数)")
        print("=" * 60)
        try:
            with self.engine.begin() as conn:
                # With filters, only replace the matching slice so unrelated
                # rows survive; without filters, rebuild the whole table.
                if exchange or index_type:
                    delete_conditions = []
                    params = {}
                    if exchange:
                        delete_conditions.append("exchange = :exchange")
                        params["exchange"] = exchange
                    if index_type:
                        delete_conditions.append("type = :type")
                        params["type"] = index_type
                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
                    conn.execute(delete_sql, params)
                    print(f"表 {self.table_name} 已删除符合条件的旧数据")
                else:
                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                    print(f"表 {self.table_name} 已清空")
            index_df = self.fetch_index_list(exchange=exchange, index_type=index_type)
            if index_df.empty:
                print("未获取到指数数据,采集终止")
                return
            result_df = self.transform_data(index_df)
            if result_df.empty:
                print("数据转换失败")
                return
            self.save_dataframe(result_df)
            print(f"成功写入 {len(result_df)} 条记录")
            # Per-type summary for operator visibility.
            if "type" in result_df.columns:
                type_stats = result_df.groupby("type").size()
                print("\n按类型统计:")
                for idx, count in type_stats.items():
                    print(f"  {idx}: {count} 个")
            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()
def collect_ths_index(
    db_url: str,
    tushare_token: str,
    exchange: Optional[str] = None,
    index_type: Optional[str] = None,
):
    """Entry point: full-overwrite collection of THS concept/industry indexes.

    Args:
        db_url: SQLAlchemy database connection URL.
        tushare_token: Tushare API token.
        exchange: Market filter — ``A`` (A-shares), ``HK``, ``US``; ``None`` = all.
        index_type: Index type — ``N`` concept, ``I`` industry, ``R`` region,
            ``S`` THS special, ``ST`` THS style, ``TH`` THS theme,
            ``BB`` THS broad-based; ``None`` = all.
    """
    THSIndexCollector(db_url, tushare_token).run_full_collection(
        exchange=exchange, index_type=index_type
    )
if __name__ == "__main__":
# DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
TOKEN = TUSHARE_TOKEN
# 执行全量覆盖采集(获取所有类型的指数)
collect_ths_index(DB_URL, TOKEN)
# 如果只想获取特定类型,可以这样调用:
# collect_ths_index(DB_URL, TOKEN, exchange='A', index_type='N') # 只获取A股概念指数

View File

@ -0,0 +1,236 @@
# coding:utf-8
"""
交易日历采集工具
功能 Tushare 获取交易日历数据并落库
API 文档: https://tushare.pro/document/2?doc_id=26
说明每次调用时全量覆盖不需要每日更新
"""
import os
import sys
from datetime import datetime
from typing import Optional
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# 添加项目根目录到路径,确保能够读取配置
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)
from src.scripts.config import TUSHARE_TOKEN
class TradeCalCollector:
    """Collector for exchange trading calendars.

    Fetches trading-calendar rows from the Tushare ``trade_cal`` API and
    writes them to a MySQL table using a full-overwrite strategy: when filters
    are given only matching rows are deleted first, otherwise the table is
    truncated.
    """
    def __init__(self, db_url: str, tushare_token: str, table_name: str = "trade_cal"):
        """
        Args:
            db_url: SQLAlchemy database connection URL.
            tushare_token: Tushare API token.
            table_name: Target table name (defaults to ``trade_cal``).
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,  # recycle connections hourly to avoid MySQL idle timeouts
        )
        self.table_name = table_name
        ts.set_token(tushare_token)
        self.pro = ts.pro_api()
        print("=" * 60)
        print("交易日历采集工具")
        print(f"目标数据表: {self.table_name}")
        print("=" * 60)
    def fetch_trade_cal(
        self,
        exchange: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch trading-calendar data.

        API: trade_cal — https://tushare.pro/document/2?doc_id=26

        Args:
            exchange: Exchange code — ``SSE`` Shanghai, ``SZSE`` Shenzhen,
                ``BSE`` Beijing; ``None`` = all.
            start_date: Start date ``YYYYMMDD``; ``None`` = unbounded.
            end_date: End date ``YYYYMMDD``; ``None`` = unbounded.

        Returns:
            Calendar DataFrame; empty DataFrame on failure.
        """
        try:
            # Only pass filters that were actually supplied.
            params = {}
            if exchange:
                params['exchange'] = exchange
            if start_date:
                params['start_date'] = start_date
            if end_date:
                params['end_date'] = end_date
            print("正在获取交易日历数据...")
            if params:
                print(f"筛选条件: {params}")
            df = self.pro.trade_cal(**params)
            if df.empty:
                print("未获取到交易日历数据")
                return pd.DataFrame()
            print(f"成功获取 {len(df)} 条交易日历记录")
            return df
        except Exception as exc:
            print(f"获取交易日历数据失败: {exc}")
            return pd.DataFrame()
    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape the Tushare calendar payload into the DB table layout.

        Args:
            df: Calendar DataFrame as returned by ``trade_cal``.

        Returns:
            DataFrame ready for ``to_sql``; empty DataFrame for empty input.
        """
        if df.empty:
            return pd.DataFrame()
        result = pd.DataFrame()
        result["exchange"] = df["exchange"]
        # Dates arrive as YYYYMMDD strings; coerce bad values to NaT.
        result["cal_date"] = pd.to_datetime(df["cal_date"], format="%Y%m%d", errors='coerce')
        result["is_open"] = df.get("is_open", None)
        result["pretrade_date"] = pd.to_datetime(df["pretrade_date"], format="%Y%m%d", errors='coerce') if "pretrade_date" in df.columns else None
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result
    def save_dataframe(self, df: pd.DataFrame) -> None:
        """Append *df* to the target table; no-op for an empty frame."""
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)
    def run_full_collection(
        self,
        exchange: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """Run a full-overwrite collection.

        Steps:
            - clear the target table (scoped DELETE when filters are given,
              otherwise TRUNCATE)
            - fetch the calendar, transform, and append it

        Args:
            exchange: Exchange filter; ``None`` = all.
            start_date: Start date ``YYYYMMDD``; ``None`` = unbounded.
            end_date: End date ``YYYYMMDD``; ``None`` = unbounded.
        """
        print("=" * 60)
        print("开始执行全量覆盖采集(交易日历)")
        print("=" * 60)
        try:
            with self.engine.begin() as conn:
                # With filters, only replace the matching slice so unrelated
                # rows survive; without filters, rebuild the whole table.
                if exchange or start_date or end_date:
                    delete_conditions = []
                    params = {}
                    if exchange:
                        delete_conditions.append("exchange = :exchange")
                        params["exchange"] = exchange
                    if start_date:
                        delete_conditions.append("cal_date >= :start_date")
                        params["start_date"] = datetime.strptime(start_date, "%Y%m%d").date()
                    if end_date:
                        delete_conditions.append("cal_date <= :end_date")
                        params["end_date"] = datetime.strptime(end_date, "%Y%m%d").date()
                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
                    conn.execute(delete_sql, params)
                    print(f"表 {self.table_name} 已删除符合条件的旧数据")
                else:
                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                    print(f"表 {self.table_name} 已清空")
            cal_df = self.fetch_trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
            if cal_df.empty:
                print("未获取到交易日历数据,采集终止")
                return
            result_df = self.transform_data(cal_df)
            if result_df.empty:
                print("数据转换失败")
                return
            self.save_dataframe(result_df)
            print(f"成功写入 {len(result_df)} 条记录")
            # Per-exchange summary for operator visibility.
            if "exchange" in result_df.columns:
                print("\n按交易所统计:")
                for exchange_name in result_df["exchange"].unique():
                    exchange_data = result_df[result_df["exchange"] == exchange_name]
                    total_days = len(exchange_data)
                    # is_open == 1 marks trading days; guard against a missing/empty column.
                    trade_days = int(exchange_data["is_open"].sum()) if "is_open" in exchange_data.columns and exchange_data["is_open"].notna().any() else 0
                    min_date = exchange_data["cal_date"].min()
                    max_date = exchange_data["cal_date"].max()
                    print(f"  {exchange_name}: 总天数={total_days}, 交易日={trade_days}, 日期范围={min_date} ~ {max_date}")
            print("\n" + "=" * 60)
            print("全量覆盖采集完成")
            print("=" * 60)
        except Exception as exc:
            print(f"全量采集失败: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()
def collect_trade_cal(
    db_url: str,
    tushare_token: str,
    exchange: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
):
    """Entry point: full-overwrite collection of exchange trading calendars.

    Args:
        db_url: SQLAlchemy database connection URL.
        tushare_token: Tushare API token.
        exchange: Exchange code — ``SSE`` Shanghai, ``SZSE`` Shenzhen,
            ``BSE`` Beijing; ``None`` = all.
        start_date: Start date ``YYYYMMDD``; ``None`` = unbounded.
        end_date: End date ``YYYYMMDD``; ``None`` = unbounded.
    """
    TradeCalCollector(db_url, tushare_token).run_full_collection(
        exchange=exchange, start_date=start_date, end_date=end_date
    )
if __name__ == "__main__":
# DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
TOKEN = TUSHARE_TOKEN
# 执行全量覆盖采集(获取所有交易所的交易日历,使用默认日期范围)
collect_trade_cal(DB_URL, TOKEN)
# 如果只想获取特定交易所,可以这样调用:
# collect_trade_cal(DB_URL, TOKEN, exchange='SSE')