commit 16f3efe3ae (parent 28a1d20529)

@@ -743,9 +743,9 @@ def create_clearance_strategy_callback(xt_trader, acc, logger):
     sell_volume = current_position

     if sell_volume > 0:
-        # Clearance strategy: use a limit order at 95% of the current price (rounded down to the cent) to make sure it fills
+        # Clearance strategy: use a limit order at 98% of the current price (rounded down to the cent) to make sure it fills
         # For a sell, the lower the price, the more easily it fills
-        aggressive_price = round(current_price * 0.95, 2)
+        aggressive_price = round(current_price * 0.98, 2)
         # Make sure the price is at least 0.01 yuan
         aggressive_price = max(aggressive_price, 0.01)
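
A quick sanity check of what the multiplier change does to the limit price; the 10.00-yuan quote below is purely illustrative, not from the commit:

# Illustrative only: old vs. new clearance limit price at a hypothetical quote.
current_price = 10.00
old_price = max(round(current_price * 0.95, 2), 0.01)  # 9.50 under the old 95% rule
new_price = max(round(current_price * 0.98, 2), 0.01)  # 9.80 under the new 98% rule
# Both limits sit below the quote, so the sell should still fill; the 98%
# limit simply gives up 0.30 yuan per share less than the 95% one.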

backend/app/data_collection/market_classifier_etl.py (new file)
@@ -0,0 +1,423 @@
# backend/app/data_collection/market_classifier_etl.py
"""
Created on 2025/11/12
---------
@summary: Market phase segmentation; multi-index composite market classifier
---------
@author: NBR
"""
import os
from datetime import datetime
from typing import List, Dict

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError


class _MultiIndexMarketClassifier:
    """
    Multi-index composite market classifier, pure ETL version.
    Responsibility: fetch data from Tushare, compute market phases, and write the results to the database.
    """

    # --- Configuration ---
    PROXY_DB_URL = os.environ.get('PROXY_DB_URL')
    if not PROXY_DB_URL:
        raise ValueError("PROXY_DB_URL environment variable must be set.")

    # Use a more robust connection-pool configuration.
    # Class-level default engine; instances receive their own engine via __init__.
    engine = create_engine(PROXY_DB_URL, pool_pre_ping=True, pool_recycle=3600)

    MARKET_INDICES = {
        '000001.SH': {'name': 'SSE Composite', 'weight': 0.3},
        '000300.SH': {'name': 'CSI 300', 'weight': 0.3},
        '399006.SZ': {'name': 'ChiNext', 'weight': 0.2},
        '000905.SH': {'name': 'CSI 500', 'weight': 0.2}
    }

    def __init__(self, token: str, db_engine):
        ts.set_token(token)
        self.pro = ts.pro_api()
        self.engine = db_engine
        self._init_database()

    def _init_database(self):
        """Ensure the target table exists (via the proxy)."""
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS gp_market_category (
            id INT AUTO_INCREMENT PRIMARY KEY,
            market_type VARCHAR(20) NOT NULL,
            start_date DATE NOT NULL,
            end_date DATE NOT NULL,
            duration_days INT NOT NULL,
            analysis_date DATE,
            created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_market_period (market_type, start_date, end_date)
        ) ENGINE=InnoDB COMMENT='Market phase classification table'
        """
        try:
            # engine.begin() commits on success and rolls back on error,
            # so no explicit commit is needed.
            with self.engine.begin() as connection:
                connection.execute(text(create_table_sql))
            print("✅ Database table 'gp_market_category' checked/created.")
        except SQLAlchemyError as e:
            print(f"❌ Database table initialization failed: {e}")
            raise

    def get_all_index_data(self, start_date: str, end_date: str) -> Dict[str, pd.DataFrame]:
        """Fetch daily bars for all configured indices."""
        print("📈 Fetching multi-index data...")

        all_data = {}
        for ts_code, info in self.MARKET_INDICES.items():
            try:
                df = self.pro.index_daily(ts_code=ts_code,
                                          start_date=start_date,
                                          end_date=end_date)
                if not df.empty:
                    df['trade_date'] = pd.to_datetime(df['trade_date'])
                    df = df.sort_values('trade_date').reset_index(drop=True)

                    # Compute technical indicators
                    df['ma20'] = df['close'].rolling(window=20).mean()
                    df['ma60'] = df['close'].rolling(window=60).mean()
                    df['ma120'] = df['close'].rolling(window=120).mean()
                    df['high_60'] = df['close'].rolling(window=60).max()
                    df['low_60'] = df['close'].rolling(window=60).min()
                    df['volume_ma20'] = df['vol'].rolling(window=20).mean()

                    all_data[ts_code] = df
                    print(f" ✅ Fetched {info['name']}: {len(df)} trading days")
                else:
                    print(f" ❌ Fetched empty data for {info['name']}")

            except Exception as e:
                print(f" ❌ Failed to fetch {info['name']}: {e}")

        return all_data

    def calculate_composite_indicators(self, all_data: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """Compute composite market indicators."""
        print("📊 Computing composite market indicators...")

        # Find the trading dates common to all indices
        common_dates = None
        for ts_code, df in all_data.items():
            if common_dates is None:
                common_dates = set(df['trade_date'])
            else:
                common_dates = common_dates.intersection(set(df['trade_date']))

        if not common_dates:
            raise Exception("No trading dates are common to all indices")

        common_dates = sorted(common_dates)
        composite_data = []

        for date in common_dates:
            date_metrics = {'trade_date': date}

            # Evaluate each index's state on this date
            bull_signals = 0
            bear_signals = 0
            total_weight = 0

            for ts_code, df in all_data.items():
                weight = self.MARKET_INDICES[ts_code]['weight']
                row = df[df['trade_date'] == date]

                if not row.empty:
                    close = row['close'].iloc[0]
                    ma20 = row['ma20'].iloc[0]
                    ma60 = row['ma60'].iloc[0]
                    ma120 = row['ma120'].iloc[0]
                    high_60 = row['high_60'].iloc[0]
                    low_60 = row['low_60'].iloc[0]
                    vol_ma20 = row['volume_ma20'].iloc[0]
                    volume_ratio = row['vol'].iloc[0] / vol_ma20 if vol_ma20 > 0 else 1

                    # Per-index signal counts
                    index_bull_signals = 0
                    index_bear_signals = 0

                    # Bull signals
                    if close > ma60 > ma120:  # bullish moving-average alignment
                        index_bull_signals += 1
                    if close > high_60 * 0.95:  # near the recent high
                        index_bull_signals += 1
                    if volume_ratio > 1.2:  # volume expansion
                        index_bull_signals += 1
                    if (close - low_60) / low_60 > 0.2:  # more than 20% above the recent low
                        index_bull_signals += 1

                    # Bear signals
                    if close < ma60 < ma120:  # bearish moving-average alignment
                        index_bear_signals += 1
                    if close < low_60 * 1.05:  # near the recent low
                        index_bear_signals += 1
                    if volume_ratio > 1.1 and close < ma20:  # decline on rising volume
                        index_bear_signals += 1
                    if (high_60 - close) / high_60 > 0.15:  # more than 15% below the recent high
                        index_bear_signals += 1

                    # Accumulate weighted votes
                    if index_bull_signals >= 2:
                        bull_signals += weight
                    if index_bear_signals >= 2:
                        bear_signals += weight

                    total_weight += weight

            # Composite market state
            if total_weight > 0:
                bull_ratio = bull_signals / total_weight
                bear_ratio = bear_signals / total_weight

                if bull_ratio > 0.6:  # more than 60% of the weight shows bull signals
                    market_type = 'bull'
                elif bear_ratio > 0.6:  # more than 60% of the weight shows bear signals
                    market_type = 'bear'
                else:
                    market_type = 'consolidation'
            else:
                # No index had data for this date; default the ratios so the
                # update below cannot raise a NameError.
                market_type = 'consolidation'
                bull_ratio = bear_ratio = 0.0

            date_metrics.update({
                'market_type': market_type,
                'bull_ratio': bull_ratio,
                'bear_ratio': bear_ratio
            })

            composite_data.append(date_metrics)

        return pd.DataFrame(composite_data)
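
    # A worked example of the weighted vote above (illustrative numbers): if the
    # SSE Composite (weight 0.3) and CSI 300 (0.3) each fire >= 2 bull signals
    # while ChiNext and CSI 500 fire none, then bull_signals = 0.6 and
    # total_weight = 1.0 (the weights sum to 0.3 + 0.3 + 0.2 + 0.2), so
    # bull_ratio = 0.6. That is not strictly greater than the 0.6 threshold,
    # so the day is still classified as 'consolidation'.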

    def detect_market_phases(self, composite_df: pd.DataFrame) -> List[Dict]:
        """Detect contiguous market phases."""
        print("🔍 Detecting market phases...")

        phases = []
        current_phase = None

        for i, row in composite_df.iterrows():
            date = row['trade_date'].date()
            market_type = row['market_type']

            if current_phase is None:
                current_phase = {
                    'market_type': market_type,
                    'start_date': date,
                    'end_date': date,
                    'dates': [date]
                }
            elif current_phase['market_type'] == market_type:
                # Same phase: extend the end date
                current_phase['end_date'] = date
                current_phase['dates'].append(date)
            else:
                # Phase transition: save the finished phase
                phase_duration = (current_phase['end_date'] - current_phase['start_date']).days
                if phase_duration >= 30:  # keep only phases lasting at least 30 days
                    phases.append(current_phase)

                # Start a new phase
                current_phase = {
                    'market_type': market_type,
                    'start_date': date,
                    'end_date': date,
                    'dates': [date]
                }

        # Handle the final phase
        if current_phase:
            phase_duration = (current_phase['end_date'] - current_phase['start_date']).days
            if phase_duration >= 30:
                phases.append(current_phase)

        return phases

    def merge_same_type_phases(self, phases: List[Dict]) -> List[Dict]:
        """Merge adjacent phases of the same type."""
        if len(phases) <= 1:
            return phases

        merged_phases = []
        current_phase = phases[0]

        for i in range(1, len(phases)):
            next_phase = phases[i]

            # Check adjacency and matching type
            gap_days = (next_phase['start_date'] - current_phase['end_date']).days
            if (current_phase['market_type'] == next_phase['market_type'] and
                    gap_days <= 60):  # gaps of up to 60 days may be merged

                # Merge the two phases
                current_phase['end_date'] = next_phase['end_date']
                current_phase['dates'].extend(next_phase['dates'])
            else:
                merged_phases.append(current_phase)
                current_phase = next_phase

        merged_phases.append(current_phase)

        print(f"🔄 Phase merge: {len(phases)} → {len(merged_phases)}")
        return merged_phases

    def format_phases_for_storage(self, phases: List[Dict]) -> List[Dict]:
        """Format phase records for storage."""
        formatted_phases = []

        for phase in phases:
            duration_days = (phase['end_date'] - phase['start_date']).days + 1

            formatted_phase = {
                'market_type': phase['market_type'],
                'start_date': phase['start_date'],
                'end_date': phase['end_date'],
                'duration_days': duration_days,
                # To persist the full date set, add it here, e.g.:
                # 'date_set': json.dumps([d.strftime('%Y-%m-%d') for d in sorted(set(phase['dates']))], ensure_ascii=False)
            }

            formatted_phases.append(formatted_phase)

        return formatted_phases

    def save_phases_to_mysql(self, phases: List[Dict]):
        """Write market phases to MySQL transactionally (via the proxy)."""
        if not phases:
            print("ℹ️ No market phase data to save.")
            return

        print("💾 Saving market phases to MySQL...")
        phases_df = pd.DataFrame(phases)
        phases_df['analysis_date'] = datetime.now().date()
        table_name = 'gp_market_category'

        try:
            # Note: TRUNCATE is DDL in MySQL and triggers an implicit commit,
            # so the truncate itself cannot be rolled back if the insert fails.
            with self.engine.begin() as connection:
                connection.execute(text(f"TRUNCATE TABLE {table_name}"))
                phases_df.to_sql(
                    name=table_name,
                    con=connection,
                    if_exists='append',
                    index=False,
                    chunksize=1000
                )
            print(f"✅ Save complete: wrote {len(phases_df)} market phases.")
        except SQLAlchemyError as e:
            print(f"❌ Failed to save data to {table_name}: {e}")
            raise

    def run_comprehensive_analysis(self, start_date: str, end_date: str):
        """Run the full ETL pipeline."""
        print("🎯 Starting multi-index market analysis (ETL)...")
        print("=" * 60)
        try:
            all_data = self.get_all_index_data(start_date, end_date)
            if not all_data:
                print("❌ Could not fetch index data; ETL aborted.")
                return

            composite_df = self.calculate_composite_indicators(all_data)
            if composite_df.empty:
                print("❌ Could not compute composite indicators; ETL aborted.")
                return

            raw_phases = self.detect_market_phases(composite_df)
            merged_phases = self.merge_same_type_phases(raw_phases)
            final_phases = self.format_phases_for_storage(merged_phases)

            self.save_phases_to_mysql(final_phases)
            self._print_analysis_results(final_phases)

            print(f"\n🎉 ETL analysis complete: processed {len(final_phases)} market phases.")
        except Exception as e:
            print(f"❌ Error during ETL analysis: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _print_analysis_results(self, phases: List[Dict]):
        """Print the analysis results."""
        print("\n📋 Market phase analysis results:")
        print("=" * 70)
        print(f"{'No.':<4} {'Market type':<16} {'Start date':<12} {'End date':<12} {'Duration':<8}")
        print("-" * 70)

        bull_count = bear_count = consolidation_count = 0

        for i, phase in enumerate(phases, 1):
            market_type = phase['market_type']
            if market_type == 'bull':
                bull_count += 1
                type_display = '🐂 Bull'
            elif market_type == 'bear':
                bear_count += 1
                type_display = '🐻 Bear'
            else:
                consolidation_count += 1
                type_display = '➰ Consolidation'

            print(f"{i:<4} {type_display:<16} {phase['start_date']} {phase['end_date']} "
                  f"{phase['duration_days']:>6} days")

        print("-" * 70)
        print(f"📈 Summary: {bull_count} bull, {bear_count} bear, {consolidation_count} consolidation phases")


# --- The single public entry point ---
def run_market_classifier_etl(start_date: str, end_date: str):
    """
    ETL entry point, intended to be called from a Celery task.
    """
    print(f"🚀 [ETL START] Market Classifier for period {start_date} to {end_date}")

    # 1. Read configuration from environment variables
    tushare_token = os.environ.get('TUSHARE_TOKEN')
    proxy_db_url = os.environ.get('PROXY_DB_URL')

    if not tushare_token or not proxy_db_url:
        raise ValueError("TUSHARE_TOKEN and PROXY_DB_URL environment variables must be set.")

    # 2. Create the database engine
    engine = create_engine(proxy_db_url, pool_pre_ping=True)

    # 3. Instantiate and run
    try:
        classifier = _MultiIndexMarketClassifier(tushare_token, engine)
        classifier.run_comprehensive_analysis(start_date, end_date)
        print("✅ [ETL SUCCESS] Market Classifier finished.")
        return {"message": f"Market classifier ETL completed for period {start_date} to {end_date}."}
    except Exception as e:
        print(f"❌ [ETL FAILED] Market Classifier failed: {e}")
        # In a Celery task, re-raising matters: it is what marks the task as FAILURE.
        raise


# --- Standalone test entry point ---
if __name__ == "__main__":

    # For convenient testing, load environment variables from a .env file
    try:
        from dotenv import load_dotenv

        # Assume the .env file sits at the project root, two directory levels up
        dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env')
        load_dotenv(dotenv_path=dotenv_path)
        print(f"Loaded .env file from: {dotenv_path}")
    except ImportError:
        print("Warning: dotenv not installed. Please set environment variables manually for standalone testing.")

    # Date range for the test run
    test_start_date = '20140101'
    test_end_date = '20251130'

    run_market_classifier_etl(test_start_date, test_end_date)
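
For reference, a minimal sketch of how this entry point might be wired into a Celery task, as the docstring suggests; the task module, name, and decorator options are assumptions, not part of this commit:

# Hypothetical Celery wiring (not part of this commit); only
# run_market_classifier_etl itself is real.
from celery import shared_task

from app.data_collection.market_classifier_etl import run_market_classifier_etl


@shared_task(bind=True)
def market_classifier_task(self, start_date: str, end_date: str):
    # run_market_classifier_etl re-raises on failure, which is what moves
    # the Celery task into the FAILURE state.
    return run_market_classifier_etl(start_date, end_date)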

backend/app/data_collection/stock_grouper_etl.py (new file)
@@ -0,0 +1,342 @@
# /backend/app/data_collection/stock_grouper_etl.py

"""
Created on 2025/11/12
---------
@summary: Stock grouping script (final fix, V5)
          1. Fixes the OOM problem (process in yearly batches)
          2. Fixes cross-batch duplicates (merge in the prior baseline state)
          3. Fixes KeyError: symbol (generate symbol manually)
---------
@author: NBR
"""

import os
import time
import gc
from datetime import datetime, timedelta

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

# --- Configuration ---
PROXY_DB_URL = os.environ.get('PROXY_DB_URL')
TUSHARE_TOKEN = os.environ.get('TUSHARE_TOKEN')


class _StockGroupManager:
    """
    Stock group manager.
    """
    REQUEST_INTERVAL = 0.05

    INDEX_MAPPING = {
        'CSI 300': '000300.SH',
        'CSI 500': '000905.SH',
        'CSI 1000': '000852.SH',
        'ChiNext': '399006.SZ',
        'STAR Market': '000688.SH',
        'SSE Composite': '000001.SH',
        'SZSE Component': '399001.SZ'
    }

    def __init__(self, token: str, db_engine, test_mode: bool = False):
        self.test_mode = test_mode
        self.pro = ts.pro_api(token)
        if self.pro is None:
            raise Exception("Tushare pro_api initialization failed.")

        self.engine = db_engine
        self.index_constituents_cache = {}
        self.TEST_STOCKS = ['300750.SZ', '688981.SH']
        self._init_database()

    def _init_database(self):
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS gp_stock_category (
            id INT AUTO_INCREMENT PRIMARY KEY,
            ts_code VARCHAR(20) NOT NULL,
            symbol VARCHAR(20) NOT NULL,
            name VARCHAR(100) NOT NULL,
            trade_date DATE NOT NULL,
            industry VARCHAR(50),
            is_st BOOLEAN DEFAULT FALSE,
            index_series TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
            UNIQUE KEY unique_stock_date (ts_code, trade_date)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
        """
        # engine.begin() commits on success, so no explicit commit is needed.
        with self.engine.begin() as connection:
            connection.execute(text(create_table_sql))

    # --- Helpers ---
    def _get_month_start_end_dates(self, year_month: str):
        try:
            year = int(year_month[:4])
            month = int(year_month[4:6])
            month_start = datetime(year, month, 1)
            if month == 12:
                next_month = datetime(year + 1, 1, 1)
            else:
                next_month = datetime(year, month + 1, 1)
            month_end = next_month - timedelta(days=1)
            return month_start.strftime('%Y%m%d'), month_end.strftime('%Y%m%d')
        except ValueError:
            return year_month + '01', year_month + '31'

    def _get_previous_month_key(self, month_key: str) -> str:
        try:
            year = int(month_key[:4])
            month = int(month_key[4:])
            if month == 1:
                return f"{year - 1:04d}12"
            else:
                return f"{year:04d}{month - 1:02d}"
        except ValueError:
            return month_key

    def _cache_month_index_constituents(self, month_key: str):
        if month_key in self.index_constituents_cache:
            return

        self.index_constituents_cache[month_key] = {}
        prev_month_key = self._get_previous_month_key(month_key)
        prev_data = self.index_constituents_cache.get(prev_month_key, {})
        month_start, month_end = self._get_month_start_end_dates(month_key)

        for index_name, index_code in self.INDEX_MAPPING.items():
            current_stocks = set()
            try:
                df = self.pro.index_weight(index_code=index_code, start_date=month_start, end_date=month_end)
                if not df.empty:
                    current_stocks = set(df['con_code'].tolist())
                else:
                    # Forward-fill from the previous month
                    if index_name in prev_data:
                        current_stocks = prev_data[index_name]

                self.index_constituents_cache[month_key][index_name] = current_stocks
                time.sleep(self.REQUEST_INTERVAL)
            except Exception as e:
                print(f" Failed to cache {index_name} constituents: {e}")
                if index_name in prev_data:
                    self.index_constituents_cache[month_key][index_name] = prev_data[index_name]

    def _get_index_series_for_stock(self, ts_code: str, trade_date: str) -> str:
        try:
            index_series = []
            month_key = trade_date[:6]
            if month_key not in self.index_constituents_cache:
                self._cache_month_index_constituents(month_key)

            month_data = self.index_constituents_cache.get(month_key, {})
            for index_name, stock_set in month_data.items():
                if ts_code in stock_set:
                    index_series.append(index_name)
            return ','.join(index_series) if index_series else ''
        except Exception:
            return ''

    def _is_st_stock(self, stock_name: str) -> bool:
        if pd.isna(stock_name):
            return False
        # '*ST' contains 'ST', so a single check covers both markers
        return 'ST' in stock_name

    def _format_code_to_prefix_style(self, ts_code: str) -> str:
        if '.' in ts_code:
            parts = ts_code.split('.')
            if len(parts) == 2:
                return f"{parts[1].upper()}{parts[0]}"
        return ts_code

    def get_latest_states_before(self, before_date_str: str) -> pd.DataFrame:
        """Fetch the prior baseline state: the latest stored row per stock."""
        print(f" 🔍 Querying baseline state before {before_date_str} (for cross-year dedup)...")
        query = text("""
            SELECT t1.ts_code, t1.industry, t1.is_st, t1.index_series, t1.trade_date, t1.name
            FROM gp_stock_category t1
            JOIN (
                SELECT ts_code, MAX(trade_date) as max_dt
                FROM gp_stock_category
                WHERE trade_date < :before_date
                GROUP BY ts_code
            ) t2 ON t1.ts_code = t2.ts_code AND t1.trade_date = t2.max_dt
        """)

        try:
            df = pd.read_sql(query, self.engine, params={'before_date': before_date_str})
            if not df.empty:
                df['trade_date'] = pd.to_datetime(df['trade_date'])

                def _revert_code(c):
                    # Convert prefix style back to Tushare style, e.g. SZ000001 -> 000001.SZ
                    if len(c) > 2 and c[:2] in ['SZ', 'SH']:
                        return f"{c[2:]}.{c[:2]}"
                    return c

                df['ts_code'] = df['ts_code'].apply(_revert_code)

            print(f" ✅ Fetched {len(df)} baseline records.")
            return df
        except Exception as e:
            print(f" ⚠️ Failed to fetch baseline state (possibly the first run): {e}")
            return pd.DataFrame()

    def get_all_bak_basic_in_range(self, start_date: str, end_date: str) -> pd.DataFrame:
        print(f" Fetching daily data from {start_date} to {end_date}...")
        all_dates = pd.date_range(start=start_date, end=end_date, freq='B')
        all_dfs = []

        for i, date in enumerate(all_dates):
            date_str = date.strftime('%Y%m%d')
            try:
                df = self.pro.bak_basic(trade_date=date_str)
                if df is not None and not df.empty:
                    all_dfs.append(df)
                time.sleep(self.REQUEST_INTERVAL)
            except Exception as e:
                print(f" Failed to fetch data for {date_str}: {e}")

        if not all_dfs:
            return pd.DataFrame()
        return pd.concat(all_dfs, ignore_index=True)

    def update_stock_category_data(self, start_date: str, end_date: str):
        """Core processing logic."""
        # 1. Base daily data
        daily_data_df = self.get_all_bak_basic_in_range(start_date, end_date)
        if daily_data_df.empty:
            print(" No data; batch finished.")
            return

        # ================= [Key fix] =================
        # Make sure the symbol column exists; the merge below depends on it
        if 'symbol' not in daily_data_df.columns:
            # Derive symbol from ts_code (e.g. 000001.SZ -> 000001)
            daily_data_df['symbol'] = daily_data_df['ts_code'].apply(lambda x: str(x).split('.')[0])
        # ==============================================

        print(f" Computing attributes in bulk ({len(daily_data_df)} rows)...")
        daily_data_df['trade_date'] = pd.to_datetime(daily_data_df['trade_date'], format='%Y%m%d')
        daily_data_df['is_st'] = daily_data_df['name'].apply(self._is_st_stock)

        # 2. Pre-load the index-constituent cache
        unique_months = sorted(daily_data_df['trade_date'].dt.strftime('%Y%m').unique())
        for ym in unique_months:
            self._cache_month_index_constituents(ym)

        # 3. Compute index membership
        daily_data_df['index_series'] = daily_data_df.apply(
            lambda row: self._get_index_series_for_stock(row['ts_code'], row['trade_date'].strftime('%Y%m%d')),
            axis=1
        )

        # 4. Merge in the prior baseline state
        prev_state_df = self.get_latest_states_before(pd.to_datetime(start_date).strftime('%Y-%m-%d'))

        if not prev_state_df.empty:
            cols = ['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']
            combined_df = pd.concat([prev_state_df[cols], daily_data_df[cols]], ignore_index=True)
        else:
            combined_df = daily_data_df.copy()

        # 5. Change detection
        combined_df.sort_values(['ts_code', 'trade_date'], inplace=True)
        cols_to_check = ['industry', 'is_st', 'index_series']

        shifted = combined_df.groupby('ts_code')[cols_to_check].shift(1)
        is_changed = (combined_df[cols_to_check].ne(shifted)).any(axis=1)

        # 6. Filter the results
        current_batch_start_dt = pd.to_datetime(start_date)
        records_to_insert = combined_df[is_changed & (combined_df['trade_date'] >= current_batch_start_dt)].copy()

        # 7. Re-attach symbol and name
        if not records_to_insert.empty:
            records_to_insert = pd.merge(
                records_to_insert[['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']],
                daily_data_df[['ts_code', 'trade_date', 'symbol', 'name']],  # symbol is guaranteed to exist here now
                on=['ts_code', 'trade_date'],
                how='left'
            )

        print(f" Filtered down to {len(records_to_insert)} genuine change records.")

        # 8. Write
        if not records_to_insert.empty:
            records_to_insert['ts_code'] = records_to_insert['ts_code'].apply(self._format_code_to_prefix_style)
            # Belt and braces: if symbol is still missing after the merge (rare), regenerate it
            if 'symbol' not in records_to_insert.columns or records_to_insert['symbol'].isnull().any():
                records_to_insert['symbol'] = records_to_insert['ts_code'].apply(lambda x: x[2:] if len(x) > 2 else x)

            final_df = records_to_insert[
                ['ts_code', 'symbol', 'name', 'trade_date', 'industry', 'is_st', 'index_series']]

            # engine.begin() wraps the delete-then-insert in one transaction.
            with self.engine.begin() as connection:
                del_start = pd.to_datetime(start_date).strftime('%Y-%m-%d')
                del_end = pd.to_datetime(end_date).strftime('%Y-%m-%d')
                connection.execute(
                    text("DELETE FROM gp_stock_category WHERE trade_date BETWEEN :s AND :e"),
                    {'s': del_start, 'e': del_end}
                )
                final_df.to_sql('gp_stock_category', con=connection, if_exists='append', index=False)
            print(" ✅ Write succeeded.")

        # 9. Cleanup
        del daily_data_df
        del combined_df
        del records_to_insert
        if not prev_state_df.empty:
            del prev_state_df
        gc.collect()


# --- Entry point ---
def run_stock_grouper_etl(start_date: str = None, end_date: str = None, test_mode: bool = False,
                          is_full_run: bool = False):
    if not PROXY_DB_URL or not TUSHARE_TOKEN:
        raise ValueError("PROXY_DB_URL and TUSHARE_TOKEN environment variables must be set.")

    engine = create_engine(PROXY_DB_URL, pool_pre_ping=True)
    manager = _StockGroupManager(TUSHARE_TOKEN, engine, test_mode)

    if is_full_run:
        print("🚀 [Full Run] Starting batched processing (one year per batch)...")
        current_year = 2016
        end_year = datetime.now().year

        while current_year <= end_year:
            batch_start = f"{current_year}0101"
            batch_end = f"{current_year}1231"

            if current_year == end_year:
                batch_end = datetime.now().strftime('%Y%m%d')
            if batch_start > batch_end:
                break

            print(f"\n>>> Processing batch: {batch_start} - {batch_end}")
            manager.update_stock_category_data(batch_start, batch_end)
            current_year += 1
            time.sleep(1)

        print("🎉 Full rerun complete.")

    else:
        if not start_date:
            with engine.connect() as conn:
                res = conn.execute(text("SELECT MAX(trade_date) FROM gp_stock_category")).scalar()
            if res:
                start_date = (res + timedelta(days=1)).strftime('%Y%m%d')
            else:
                start_date = '20160101'
        if not end_date:
            # Default the end date even when the caller supplies start_date
            end_date = datetime.now().strftime('%Y%m%d')

        if start_date > datetime.now().strftime('%Y%m%d'):
            return {"message": "No new data."}

        print(f"🚀 [Incremental] Incremental update: {start_date} - {end_date}")
        manager.update_stock_category_data(start_date, end_date)

    return {"message": "Success"}