commit;
This commit is contained in:
parent
28a1d20529
commit
16f3efe3ae
|
|
@ -743,9 +743,9 @@ def create_clearance_strategy_callback(xt_trader, acc, logger):
|
||||||
sell_volume = current_position
|
sell_volume = current_position
|
||||||
|
|
||||||
if sell_volume > 0:
|
if sell_volume > 0:
|
||||||
# 清仓策略:使用限价单,价格设置为当前价的95%(向下取整到分),确保能成交
|
# 清仓策略:使用限价单,价格设置为当前价的98%(向下取整到分),确保能成交
|
||||||
# 对于卖出,价格越低越容易成交
|
# 对于卖出,价格越低越容易成交
|
||||||
aggressive_price = round(current_price * 0.95, 2)
|
aggressive_price = round(current_price * 0.98, 2)
|
||||||
# 确保价格至少为0.01元
|
# 确保价格至少为0.01元
|
||||||
aggressive_price = max(aggressive_price, 0.01)
|
aggressive_price = max(aggressive_price, 0.01)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,423 @@
|
||||||
|
# backend/app/data_collection/market_classifier_etl.py
|
||||||
|
"""
|
||||||
|
Created on 2025/11/12
|
||||||
|
---------
|
||||||
|
@summary: 市场阶段分片 ,多指数综合市场分类器
|
||||||
|
---------
|
||||||
|
@author: NBR
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import tushare as ts
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
|
||||||
|
class _MultiIndexMarketClassifier:
    """Multi-index composite market classifier (pure ETL).

    Responsibilities: fetch index daily data from Tushare, compute a
    weighted bull/bear vote across several market indices, segment the
    history into market phases, and persist them to MySQL.

    Note: the import-time PROXY_DB_URL check and class-level engine were
    removed — the instance always uses the engine injected via __init__,
    and the entry function validates the environment itself.
    """

    # Index universe and the weight each index contributes to the
    # composite bull/bear vote (weights sum to 1.0).
    MARKET_INDICES = {
        '000001.SH': {'name': '上证指数', 'weight': 0.3},
        '000300.SH': {'name': '沪深300', 'weight': 0.3},
        '399006.SZ': {'name': '创业板指', 'weight': 0.2},
        '000905.SH': {'name': '中证500', 'weight': 0.2}
    }

    def __init__(self, token: str, db_engine):
        """
        :param token: Tushare API token.
        :param db_engine: SQLAlchemy engine pointing at the target MySQL DB.
        """
        ts.set_token(token)
        self.pro = ts.pro_api()
        self.engine = db_engine
        self._init_database()

    def _init_database(self):
        """Ensure the target table exists (created via the DB proxy)."""
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS gp_market_category (
            id INT AUTO_INCREMENT PRIMARY KEY,
            market_type VARCHAR(20) NOT NULL,
            start_date DATE NOT NULL,
            end_date DATE NOT NULL,
            duration_days INT NOT NULL,
            analysis_date DATE,
            created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_market_period (market_type, start_date, end_date)
        ) ENGINE=InnoDB COMMENT='市场阶段分类表'
        """
        try:
            with self.engine.connect() as connection:
                # begin() commits on clean exit and rolls back on error;
                # no explicit commit() is needed.
                with connection.begin():
                    connection.execute(text(create_table_sql))
            print("✅ 数据库表'gp_market_category'检查/创建完成。")
        except SQLAlchemyError as e:
            print(f"❌ 数据库表初始化失败: {e}")
            raise

    def get_all_index_data(self, start_date: str, end_date: str) -> Dict[str, pd.DataFrame]:
        """Fetch daily bars for every configured index and add indicators.

        Returns a dict keyed by ts_code; indices whose download fails or
        returns no rows are simply omitted.
        """
        print("📈 获取多指数数据...")

        all_data = {}
        for ts_code, info in self.MARKET_INDICES.items():
            try:
                df = self.pro.index_daily(ts_code=ts_code,
                                          start_date=start_date,
                                          end_date=end_date)
                if not df.empty:
                    df['trade_date'] = pd.to_datetime(df['trade_date'])
                    df = df.sort_values('trade_date').reset_index(drop=True)

                    # Technical indicators consumed by the phase classifier.
                    df['ma20'] = df['close'].rolling(window=20).mean()
                    df['ma60'] = df['close'].rolling(window=60).mean()
                    df['ma120'] = df['close'].rolling(window=120).mean()
                    df['high_60'] = df['close'].rolling(window=60).max()
                    df['low_60'] = df['close'].rolling(window=60).min()
                    df['volume_ma20'] = df['vol'].rolling(window=20).mean()

                    all_data[ts_code] = df
                    print(f" ✅ 获取 {info['name']} 数据: {len(df)} 个交易日")
                else:
                    print(f" ❌ 获取 {info['name']} 数据为空")

            except Exception as e:
                print(f" ❌ 获取 {info['name']} 数据失败: {e}")

        return all_data

    def calculate_composite_indicators(self, all_data: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """Compute a per-day composite market state across all indices.

        For each common trading day, each index casts a weighted bull/bear
        vote based on simple technical signals; the weighted ratios decide
        the day's market_type ('bull' / 'bear' / 'consolidation').
        """
        print("📊 计算综合市场指标...")

        # Intersect trading days so every index contributes to every row.
        common_dates = None
        for ts_code, df in all_data.items():
            if common_dates is None:
                common_dates = set(df['trade_date'])
            else:
                common_dates = common_dates.intersection(set(df['trade_date']))

        if not common_dates:
            raise Exception("没有共同的交易日数据")

        common_dates = sorted(common_dates)
        composite_data = []

        for date in common_dates:
            date_metrics = {'trade_date': date}

            # Accumulate each index's state on this date.
            bull_signals = 0
            bear_signals = 0
            total_weight = 0

            for ts_code, df in all_data.items():
                weight = self.MARKET_INDICES[ts_code]['weight']
                row = df[df['trade_date'] == date]

                if not row.empty:
                    close = row['close'].iloc[0]
                    ma20 = row['ma20'].iloc[0]
                    ma60 = row['ma60'].iloc[0]
                    ma120 = row['ma120'].iloc[0]
                    high_60 = row['high_60'].iloc[0]
                    low_60 = row['low_60'].iloc[0]
                    volume_ratio = row['vol'].iloc[0] / row['volume_ma20'].iloc[0] if row['volume_ma20'].iloc[
                        0] > 0 else 1

                    # Per-index signal counts.
                    index_bull_signals = 0
                    index_bear_signals = 0

                    # Bull signals
                    if close > ma60 > ma120:  # bullish MA alignment
                        index_bull_signals += 1
                    if close > high_60 * 0.95:  # near the recent high
                        index_bull_signals += 1
                    if volume_ratio > 1.2:  # volume expansion
                        index_bull_signals += 1
                    if (close - low_60) / low_60 > 0.2:  # >20% above the 60-day low
                        index_bull_signals += 1

                    # Bear signals
                    if close < ma60 < ma120:  # bearish MA alignment
                        index_bear_signals += 1
                    if close < low_60 * 1.05:  # near the recent low
                        index_bear_signals += 1
                    if volume_ratio > 1.1 and close < ma20:  # high-volume decline
                        index_bear_signals += 1
                    if (high_60 - close) / high_60 > 0.15:  # >15% below the 60-day high
                        index_bear_signals += 1

                    # Weighted vote: an index votes when >=2 of its signals fire.
                    if index_bull_signals >= 2:
                        bull_signals += weight
                    if index_bear_signals >= 2:
                        bear_signals += weight

                    total_weight += weight

            # Fix: always define the ratios. Previously they were only
            # assigned when total_weight > 0 but were still referenced in
            # date_metrics.update below — a NameError (first iteration) or
            # stale-value bug on the zero-weight path.
            bull_ratio = 0.0
            bear_ratio = 0.0

            if total_weight > 0:
                bull_ratio = bull_signals / total_weight
                bear_ratio = bear_signals / total_weight

                if bull_ratio > 0.6:  # >60% of weight shows bull signals
                    market_type = 'bull'
                elif bear_ratio > 0.6:  # >60% of weight shows bear signals
                    market_type = 'bear'
                else:
                    market_type = 'consolidation'
            else:
                market_type = 'consolidation'

            date_metrics.update({
                'market_type': market_type,
                'bull_ratio': bull_ratio,
                'bear_ratio': bear_ratio
            })

            composite_data.append(date_metrics)

        return pd.DataFrame(composite_data)

    def detect_market_phases(self, composite_df: pd.DataFrame) -> List[Dict]:
        """Segment the daily composite states into contiguous phases.

        Phases shorter than 30 days are dropped as noise.
        """
        print("🔍 检测市场阶段...")

        phases = []
        current_phase = None

        for i, row in composite_df.iterrows():
            date = row['trade_date'].date()
            market_type = row['market_type']

            if current_phase is None:
                current_phase = {
                    'market_type': market_type,
                    'start_date': date,
                    'end_date': date,
                    'dates': [date]
                }
            elif current_phase['market_type'] == market_type:
                # Same phase: extend it.
                current_phase['end_date'] = date
                current_phase['dates'].append(date)
            else:
                # Phase change: keep the finished phase if long enough.
                phase_duration = (current_phase['end_date'] - current_phase['start_date']).days
                if phase_duration >= 30:  # keep only phases of >= 30 days
                    phases.append(current_phase)

                # Start the new phase.
                current_phase = {
                    'market_type': market_type,
                    'start_date': date,
                    'end_date': date,
                    'dates': [date]
                }

        # Flush the trailing phase.
        if current_phase:
            phase_duration = (current_phase['end_date'] - current_phase['start_date']).days
            if phase_duration >= 30:
                phases.append(current_phase)

        return phases

    def merge_same_type_phases(self, phases: List[Dict]) -> List[Dict]:
        """Merge adjacent phases of the same type separated by <= 60 days."""
        if len(phases) <= 1:
            return phases

        merged_phases = []
        current_phase = phases[0]

        for i in range(1, len(phases)):
            next_phase = phases[i]

            # Merge when of the same type and the gap is small enough.
            gap_days = (next_phase['start_date'] - current_phase['end_date']).days
            if (current_phase['market_type'] == next_phase['market_type'] and
                    gap_days <= 60):  # gaps of <= 60 days can be bridged
                current_phase['end_date'] = next_phase['end_date']
                current_phase['dates'].extend(next_phase['dates'])
            else:
                merged_phases.append(current_phase)
                current_phase = next_phase

        merged_phases.append(current_phase)

        print(f"🔄 阶段合并: {len(phases)} → {len(merged_phases)}")
        return merged_phases

    def format_phases_for_storage(self, phases: List[Dict]) -> List[Dict]:
        """Shape phases into dicts matching the gp_market_category columns."""
        formatted_phases = []

        for phase in phases:
            duration_days = (phase['end_date'] - phase['start_date']).days + 1

            formatted_phase = {
                'market_type': phase['market_type'],
                'start_date': phase['start_date'],
                'end_date': phase['end_date'],
                'duration_days': duration_days,
                # Add a serialized date set here if it ever needs storing, e.g.
                # 'date_set': json.dumps([d.strftime('%Y-%m-%d') for d in sorted(set(phase['dates']))], ensure_ascii=False)
            }

            formatted_phases.append(formatted_phase)

        return formatted_phases

    def save_phases_to_mysql(self, phases: List[Dict]):
        """Replace the gp_market_category contents with the given phases.

        :raises SQLAlchemyError: re-raised after logging on any DB failure.
        """
        if not phases:
            print("ℹ️ 没有市场阶段数据需要保存。")
            return

        print("💾 保存市场阶段数据到MySQL...")
        phases_df = pd.DataFrame(phases)
        phases_df['analysis_date'] = datetime.now().date()
        table_name = 'gp_market_category'

        try:
            with self.engine.connect() as connection:
                # begin() commits on clean exit and rolls back on error.
                # NOTE(review): TRUNCATE is DDL in MySQL and implicitly
                # commits, so it cannot actually be rolled back if the
                # subsequent insert fails — confirm this is acceptable for
                # a full-replace table (or switch to DELETE FROM).
                with connection.begin():
                    connection.execute(text(f"TRUNCATE TABLE {table_name}"))
                    phases_df.to_sql(
                        name=table_name,
                        con=connection,
                        if_exists='append',
                        index=False,
                        chunksize=1000
                    )
            print(f"✅ 保存完成: 成功写入 {len(phases_df)} 个市场阶段。")
        except SQLAlchemyError as e:
            print(f"❌ 保存数据到 {table_name} 失败: {e}")
            raise

    def run_comprehensive_analysis(self, start_date: str, end_date: str):
        """Run the full ETL pipeline: fetch -> classify -> segment -> store."""
        print("🎯 开始多指数市场分析 (ETL)...")
        print("=" * 60)
        try:
            all_data = self.get_all_index_data(start_date, end_date)
            if not all_data:
                print("❌ 无法获取指数数据,ETL终止。")
                return

            composite_df = self.calculate_composite_indicators(all_data)
            if composite_df.empty:
                print("❌ 无法计算综合指标,ETL终止。")
                return

            raw_phases = self.detect_market_phases(composite_df)
            merged_phases = self.merge_same_type_phases(raw_phases)
            final_phases = self.format_phases_for_storage(merged_phases)

            self.save_phases_to_mysql(final_phases)
            self._print_analysis_results(final_phases)

            print(f"\n🎉 ETL分析完成: 共处理了 {len(final_phases)} 个市场阶段。")
        except Exception as e:
            print(f"❌ ETL分析过程中出错: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _print_analysis_results(self, phases: List[Dict]):
        """Pretty-print the phase table plus per-type counts."""
        print("\n📋 市场阶段分析结果:")
        print("=" * 70)
        print(f"{'序号':<3} {'市场类型':<10} {'开始日期':<12} {'结束日期':<12} {'持续时间':<8}")
        print("-" * 70)

        bull_count = bear_count = consolidation_count = 0

        for i, phase in enumerate(phases, 1):
            market_type = phase['market_type']
            if market_type == 'bull':
                bull_count += 1
                type_display = '🐂 牛市'
            elif market_type == 'bear':
                bear_count += 1
                type_display = '🐻 熊市'
            else:
                consolidation_count += 1
                type_display = '➰ 震荡市'

            print(f"{i:<3} {type_display:<10} {phase['start_date']} {phase['end_date']} "
                  f"{phase['duration_days']:>6} 天")

        print("-" * 70)
        print(f"📈 统计: 牛市{bull_count}个, 熊市{bear_count}个, 震荡市{consolidation_count}个")
|
||||||
|
|
||||||
|
|
||||||
|
# --- 唯一的公共入口函数 ---
|
||||||
|
def run_market_classifier_etl(start_date: str, end_date: str):
    """
    ETL entry point, intended to be called from a Celery task.

    :param start_date: inclusive start, 'YYYYMMDD'.
    :param end_date: inclusive end, 'YYYYMMDD'.
    :return: dict with a human-readable completion message.
    :raises ValueError: when required environment variables are missing.
    """
    print(f"🚀 [ETL START] Market Classifier for period {start_date} to {end_date}")

    # 1. Read configuration from the environment.
    tushare_token = os.environ.get('TUSHARE_TOKEN')
    proxy_db_url = os.environ.get('PROXY_DB_URL')

    if not tushare_token or not proxy_db_url:
        raise ValueError("TUSHARE_TOKEN and PROXY_DB_URL environment variables must be set.")

    # 2. Create the database engine.
    engine = create_engine(proxy_db_url, pool_pre_ping=True)

    # 3. Instantiate and run.
    try:
        classifier = _MultiIndexMarketClassifier(tushare_token, engine)
        classifier.run_comprehensive_analysis(start_date, end_date)
        print("✅ [ETL SUCCESS] Market Classifier finished.")
        return {"message": f"Market classifier ETL completed for period {start_date} to {end_date}."}
    except Exception as e:
        print(f"❌ [ETL FAILED] Market Classifier failed: {e}")
        # Re-raising matters inside a Celery task: it marks the task FAILURE.
        raise
    finally:
        # Fix: release pooled connections — the engine was previously
        # created per call and never disposed.
        engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
# --- 用于独立测试的入口 ---
|
||||||
|
if __name__ == "__main__":

    # Standalone test entry: load environment variables from the project
    # root .env (two directories up) when python-dotenv is available.
    try:
        from dotenv import load_dotenv

        dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env')
        load_dotenv(dotenv_path=dotenv_path)
        print(f"Loaded .env file from: {dotenv_path}")
    except ImportError:
        print("Warning: dotenv not installed. Please set environment variables manually for standalone testing.")

    # Date range for the test run.
    run_market_classifier_etl('20140101', '20251130')
|
||||||
|
|
@ -0,0 +1,342 @@
|
||||||
|
# /backend/app/data_collection/stock_grouper_etl.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
Created on 2025/11/12
|
||||||
|
---------
|
||||||
|
@summary: 股票分组脚本 (终极修复版 V5)
|
||||||
|
1. 解决 OOM 问题 (按年分批)
|
||||||
|
2. 解决 跨批次重复问题 (前序状态合并)
|
||||||
|
3. 解决 KeyError: symbol 问题 (手动生成 symbol)
|
||||||
|
---------
|
||||||
|
@author: NBR
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import gc
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import tushare as ts
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
PROXY_DB_URL = os.environ.get('PROXY_DB_URL')
|
||||||
|
TUSHARE_TOKEN = os.environ.get('TUSHARE_TOKEN')
|
||||||
|
|
||||||
|
|
||||||
|
class _StockGroupManager:
    """Stock grouping manager.

    Classifies every stock per trade date (industry, ST flag, index
    membership) and writes only real state changes to MySQL, batching by
    date range to bound memory use.
    """

    # Pause between Tushare requests to stay under the API rate limit.
    REQUEST_INTERVAL = 0.05

    # Display name -> Tushare index code, used for membership tagging.
    INDEX_MAPPING = {
        '沪深300': '000300.SH',
        '中证500': '000905.SH',
        '中证1000': '000852.SH',
        '创业板': '399006.SZ',
        '科创板': '000688.SH',
        '上证指数': '000001.SH',
        '深证成指': '399001.SZ'
    }

    def __init__(self, token: str, db_engine, test_mode: bool = False):
        """
        :param token: Tushare API token.
        :param db_engine: SQLAlchemy engine for the target MySQL DB.
        :param test_mode: restrict processing to a small stock sample.
        """
        self.test_mode = test_mode
        self.pro = ts.pro_api(token)
        if self.pro is None:
            raise Exception("Tushare pro_api initialization failed.")

        self.engine = db_engine
        # month key 'YYYYMM' -> {index name -> set of constituent ts_codes}
        self.index_constituents_cache = {}
        self.TEST_STOCKS = ['300750.SZ', '688981.SH']
        self._init_database()

    def _init_database(self):
        """Ensure the gp_stock_category table exists."""
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS gp_stock_category (
            id INT AUTO_INCREMENT PRIMARY KEY,
            ts_code VARCHAR(20) NOT NULL,
            symbol VARCHAR(20) NOT NULL,
            name VARCHAR(100) NOT NULL,
            trade_date DATE NOT NULL,
            industry VARCHAR(50),
            is_st BOOLEAN DEFAULT FALSE,
            index_series TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
            UNIQUE KEY unique_stock_date (ts_code, trade_date)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
        """
        with self.engine.connect() as connection:
            # begin() commits on clean exit; no explicit commit needed.
            with connection.begin():
                connection.execute(text(create_table_sql))

    # --- helpers ---
    def _get_month_start_end_dates(self, year_month: str):
        """Return (first_day, last_day) of 'YYYYMM' as 'YYYYMMDD' strings."""
        try:
            year = int(year_month[:4])
            month = int(year_month[4:6])
            month_start = datetime(year, month, 1)
            if month == 12:
                next_month = datetime(year + 1, 1, 1)
            else:
                next_month = datetime(year, month + 1, 1)
            month_end = next_month - timedelta(days=1)
            return month_start.strftime('%Y%m%d'), month_end.strftime('%Y%m%d')
        except ValueError:
            # Fix: was a bare `except:`. NOTE(review): the '31' fallback can
            # yield a non-existent calendar date for short months — confirm
            # it is only ever used as a string range bound for the API.
            return year_month + '01', year_month + '31'

    def _get_previous_month_key(self, month_key: str) -> str:
        """Return the 'YYYYMM' key of the month preceding month_key."""
        try:
            year = int(month_key[:4])
            month = int(month_key[4:])
            if month == 1:
                return f"{year - 1:04d}12"
            else:
                return f"{year:04d}{month - 1:02d}"
        except ValueError:  # fix: was a bare `except:`
            return month_key

    def _cache_month_index_constituents(self, month_key: str):
        """Populate the constituents cache for 'YYYYMM'.

        When Tushare returns nothing (or fails) for an index, the previous
        month's constituents are forward-filled.
        """
        if month_key in self.index_constituents_cache:
            return

        self.index_constituents_cache[month_key] = {}
        prev_month_key = self._get_previous_month_key(month_key)
        prev_data = self.index_constituents_cache.get(prev_month_key, {})
        month_start, month_end = self._get_month_start_end_dates(month_key)

        for index_name, index_code in self.INDEX_MAPPING.items():
            current_stocks = set()
            try:
                df = self.pro.index_weight(index_code=index_code, start_date=month_start, end_date=month_end)
                if not df.empty:
                    current_stocks = set(df['con_code'].tolist())
                else:
                    # Forward-fill from the previous month.
                    if index_name in prev_data:
                        current_stocks = prev_data[index_name]

                self.index_constituents_cache[month_key][index_name] = current_stocks
                time.sleep(self.REQUEST_INTERVAL)
            except Exception as e:
                print(f" 缓存{index_name}成分股失败: {e}")
                if index_name in prev_data:
                    self.index_constituents_cache[month_key][index_name] = prev_data[index_name]

    def _get_index_series_for_stock(self, ts_code: str, trade_date: str) -> str:
        """Return a comma-joined list of index names containing ts_code on
        trade_date ('YYYYMMDD'); empty string when none or on failure."""
        try:
            index_series = []
            month_key = trade_date[:6]
            if month_key not in self.index_constituents_cache:
                self._cache_month_index_constituents(month_key)

            month_data = self.index_constituents_cache.get(month_key, {})
            for index_name, stock_set in month_data.items():
                if ts_code in stock_set:
                    index_series.append(index_name)
            return ','.join(index_series) if index_series else ''
        except Exception:  # fix: was a bare `except:`
            return ''

    def _is_st_stock(self, stock_name: str) -> bool:
        """True when the stock name flags special treatment (ST/*ST)."""
        if pd.isna(stock_name):
            return False
        # '*ST' contains 'ST', so one substring test covers both variants
        # (the original's `or '*ST' in name` was redundant).
        return 'ST' in stock_name

    def _format_code_to_prefix_style(self, ts_code: str) -> str:
        """Convert '000001.SZ' -> 'SZ000001'; pass anything else through."""
        if '.' in ts_code:
            parts = ts_code.split('.')
            if len(parts) == 2:
                return f"{parts[1].upper()}{parts[0]}"
        return ts_code

    def get_latest_states_before(self, before_date_str: str) -> pd.DataFrame:
        """Fetch each stock's most recent stored state strictly before
        before_date_str ('YYYY-MM-DD'), used for cross-batch deduplication.

        Returns an empty DataFrame on any failure (e.g. first run).
        """
        print(f" 🔍 查询 {before_date_str} 之前的基准状态 (用于跨年去重)...")
        query = text("""
            SELECT t1.ts_code, t1.industry, t1.is_st, t1.index_series, t1.trade_date, t1.name
            FROM gp_stock_category t1
            JOIN (
                SELECT ts_code, MAX(trade_date) as max_dt
                FROM gp_stock_category
                WHERE trade_date < :before_date
                GROUP BY ts_code
            ) t2 ON t1.ts_code = t2.ts_code AND t1.trade_date = t2.max_dt
        """)

        try:
            df = pd.read_sql(query, self.engine, params={'before_date': before_date_str})
            if not df.empty:
                df['trade_date'] = pd.to_datetime(df['trade_date'])

                def _revert_code(c):
                    # Stored codes are prefix-style ('SZ000001'); revert to
                    # Tushare style ('000001.SZ') for in-memory comparison.
                    if len(c) > 2 and c[:2] in ['SZ', 'SH']:
                        return f"{c[2:]}.{c[:2]}"
                    return c

                df['ts_code'] = df['ts_code'].apply(_revert_code)

            print(f" ✅ 获取到 {len(df)} 条基准记录。")
            return df
        except Exception as e:
            print(f" ⚠️ 获取基准状态失败 (可能是第一次运行): {e}")
            return pd.DataFrame()

    def get_all_bak_basic_in_range(self, start_date: str, end_date: str) -> pd.DataFrame:
        """Download bak_basic rows for every business day in the range and
        concatenate them; returns an empty frame when nothing was fetched."""
        print(f" 正在获取 {start_date} 到 {end_date} 的日线数据...")
        all_dates = pd.date_range(start=start_date, end=end_date, freq='B')
        all_dfs = []

        for i, date in enumerate(all_dates):
            date_str = date.strftime('%Y%m%d')
            try:
                df = self.pro.bak_basic(trade_date=date_str)
                if df is not None and not df.empty:
                    all_dfs.append(df)
                time.sleep(self.REQUEST_INTERVAL)
            except Exception as e:
                print(f" 获取 {date_str} 数据失败: {e}")

        if not all_dfs:
            return pd.DataFrame()
        return pd.concat(all_dfs, ignore_index=True)

    def update_stock_category_data(self, start_date: str, end_date: str):
        """Process one date batch: compute attributes, detect changes, persist.

        Only rows whose (industry, is_st, index_series) differ from the
        stock's previous state are written, so the table stores change
        points rather than daily snapshots.
        """
        # 1. Raw daily basics.
        daily_data_df = self.get_all_bak_basic_in_range(start_date, end_date)
        if daily_data_df.empty:
            print(" 无数据,本批次结束。")
            return

        # Ensure a `symbol` column exists — the merge in step 7 needs it.
        if 'symbol' not in daily_data_df.columns:
            # Derive from ts_code (e.g. 000001.SZ -> 000001).
            daily_data_df['symbol'] = daily_data_df['ts_code'].apply(lambda x: str(x).split('.')[0])

        print(f" 批量计算属性 ({len(daily_data_df)} 行)...")
        daily_data_df['trade_date'] = pd.to_datetime(daily_data_df['trade_date'], format='%Y%m%d')
        daily_data_df['is_st'] = daily_data_df['name'].apply(self._is_st_stock)

        # 2. Pre-load index-constituent caches for every month in the batch.
        unique_months = sorted(daily_data_df['trade_date'].dt.strftime('%Y%m').unique())
        for ym in unique_months:
            self._cache_month_index_constituents(ym)

        # 3. Index membership per row.
        daily_data_df['index_series'] = daily_data_df.apply(
            lambda row: self._get_index_series_for_stock(row['ts_code'], row['trade_date'].strftime('%Y%m%d')),
            axis=1
        )

        # 4. Prepend the last known state so the change detection below
        #    also suppresses duplicates across batch boundaries.
        prev_state_df = self.get_latest_states_before(pd.to_datetime(start_date).strftime('%Y-%m-%d'))

        if not prev_state_df.empty:
            cols = ['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']
            combined_df = pd.concat([prev_state_df[cols], daily_data_df[cols]], ignore_index=True)
        else:
            combined_df = daily_data_df.copy()

        # 5. Change detection: a row counts when any tracked column differs
        #    from the same stock's previous row.
        combined_df.sort_values(['ts_code', 'trade_date'], inplace=True)
        cols_to_check = ['industry', 'is_st', 'index_series']

        shifted = combined_df.groupby('ts_code')[cols_to_check].shift(1)
        is_changed = (combined_df[cols_to_check].ne(shifted)).any(axis=1)

        # 6. Keep only changes that fall inside the current batch window.
        current_batch_start_dt = pd.to_datetime(start_date)
        records_to_insert = combined_df[is_changed & (combined_df['trade_date'] >= current_batch_start_dt)].copy()

        # 7. Re-attach symbol and name (dropped during the concat above).
        if not records_to_insert.empty:
            records_to_insert = pd.merge(
                records_to_insert[['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']],
                daily_data_df[['ts_code', 'trade_date', 'symbol', 'name']],
                on=['ts_code', 'trade_date'],
                how='left'
            )

        print(f" 筛选出 {len(records_to_insert)} 条真实变更记录。")

        # 8. Persist: delete the batch window, then append the change rows.
        if not records_to_insert.empty:
            records_to_insert['ts_code'] = records_to_insert['ts_code'].apply(self._format_code_to_prefix_style)
            # Belt and braces: regenerate symbol if the merge left gaps.
            if 'symbol' not in records_to_insert.columns or records_to_insert['symbol'].isnull().any():
                records_to_insert['symbol'] = records_to_insert['ts_code'].apply(lambda x: x[2:] if len(x) > 2 else x)

            final_df = records_to_insert[
                ['ts_code', 'symbol', 'name', 'trade_date', 'industry', 'is_st', 'index_series']]

            with self.engine.connect() as connection:
                # begin() commits on clean exit and rolls back on error.
                with connection.begin():
                    del_start = pd.to_datetime(start_date).strftime('%Y-%m-%d')
                    del_end = pd.to_datetime(end_date).strftime('%Y-%m-%d')
                    connection.execute(
                        text("DELETE FROM gp_stock_category WHERE trade_date BETWEEN :s AND :e"),
                        {'s': del_start, 'e': del_end}
                    )
                    final_df.to_sql('gp_stock_category', con=connection, if_exists='append', index=False)
            print(" ✅ 写入成功。")

        # 9. Free the large frames before the next batch (OOM mitigation).
        del daily_data_df
        del combined_df
        del records_to_insert
        if not prev_state_df.empty:
            del prev_state_df
        gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
# --- 运行入口 ---
|
||||||
|
def run_stock_grouper_etl(start_date: str = None, end_date: str = None, test_mode: bool = False,
                          is_full_run: bool = False):
    """ETL entry point for stock grouping.

    :param start_date: 'YYYYMMDD' inclusive start; inferred from the DB
        (day after the latest stored trade_date) when omitted.
    :param end_date: 'YYYYMMDD' inclusive end; defaults to today when omitted.
    :param test_mode: forwarded to the manager (restricts the stock universe).
    :param is_full_run: when True, reprocess year-by-year from 2016.
    :return: dict with a status message.
    :raises ValueError: when required environment variables are missing.
    """
    if not PROXY_DB_URL or not TUSHARE_TOKEN:
        raise ValueError("Environment variables required.")

    engine = create_engine(PROXY_DB_URL, pool_pre_ping=True)
    try:
        manager = _StockGroupManager(TUSHARE_TOKEN, engine, test_mode)

        if is_full_run:
            print("🚀 [Full Run] 启动分批处理模式 (按年)...")
            current_year = 2016
            end_year = datetime.now().year

            while current_year <= end_year:
                batch_start = f"{current_year}0101"
                batch_end = f"{current_year}1231"

                # Clamp the final year to today.
                if current_year == end_year:
                    batch_end = datetime.now().strftime('%Y%m%d')
                    if batch_start > batch_end:
                        break

                print(f"\n>>> 处理批次: {batch_start} - {batch_end}")
                manager.update_stock_category_data(batch_start, batch_end)
                current_year += 1
                time.sleep(1)

            print("🎉 全量重跑完成。")

        else:
            today = datetime.now().strftime('%Y%m%d')
            if not start_date:
                with engine.connect() as conn:
                    res = conn.execute(text("SELECT MAX(trade_date) FROM gp_stock_category")).scalar()
                if res:
                    start_date = (res + timedelta(days=1)).strftime('%Y%m%d')
                else:
                    start_date = '20160101'
            # Fix: default end_date unconditionally. Previously it was only
            # set when start_date was inferred, so a caller-supplied
            # start_date with no end_date passed None downstream and broke
            # pd.date_range in the manager.
            if not end_date:
                end_date = today

            if start_date > today:
                return {"message": "No new data."}

            print(f"🚀 [Incremental] 增量更新: {start_date} - {end_date}")
            manager.update_stock_category_data(start_date, end_date)

        return {"message": "Success"}
    finally:
        # Fix: release pooled connections — the engine was previously
        # created per call and never disposed.
        engine.dispose()
|
||||||
Loading…
Reference in New Issue