满脸小星星 2025-12-26 15:32:42 +08:00
parent 28a1d20529
commit 16f3efe3ae
3 changed files with 767 additions and 2 deletions


@@ -743,9 +743,9 @@ def create_clearance_strategy_callback(xt_trader, acc, logger):
     sell_volume = current_position
     if sell_volume > 0:
-        # Clearance strategy: set the limit price to 95% of the current price (rounded to the cent) to make sure the order fills
+        # Clearance strategy: set the limit price to 98% of the current price (rounded to the cent) to make sure the order fills
         # For a sell, a lower limit price fills more easily
-        aggressive_price = round(current_price * 0.95, 2)
+        aggressive_price = round(current_price * 0.98, 2)
         # Make sure the price is at least 0.01 yuan
         aggressive_price = max(aggressive_price, 0.01)
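
For context, a minimal sketch of the pricing rule this hunk tunes (the helper name is hypothetical, not this project's API): a sell limit placed below the current price behaves like a marketable order, so raising the factor from 0.95 to 0.98 keeps the order easy to fill while giving up less slippage.

# Illustrative sketch only; `clearance_sell_price` is a hypothetical helper,
# not a function from this repository.
def clearance_sell_price(current_price: float, factor: float = 0.98) -> float:
    """Price a marketable limit sell: slightly below market so it fills,
    floored at 0.01 yuan, the minimum valid price."""
    price = round(current_price * factor, 2)
    return max(price, 0.01)

assert clearance_sell_price(10.00) == 9.80
assert clearance_sell_price(0.01) == 0.01  # the floor keeps the price valid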


@@ -0,0 +1,423 @@
# backend/app/data_collection/market_classifier_etl.py
"""
Created on 2025/11/12
---------
@summary: 市场阶段分片 ,多指数综合市场分类器
---------
@author: NBR
"""
import os
from datetime import datetime
from typing import List, Dict
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
class _MultiIndexMarketClassifier:
    """
    Multi-index composite market classifier - pure ETL version.
    Responsibilities: fetch data from Tushare, compute market phases, and write the results to the database.
    """
    # --- Configuration ---
    # The database engine is injected through __init__ (see run_market_classifier_etl
    # below), so nothing is created at import time and the module stays importable
    # even when PROXY_DB_URL is not set in the environment.
MARKET_INDICES = {
'000001.SH': {'name': '上证指数', 'weight': 0.3},
'000300.SH': {'name': '沪深300', 'weight': 0.3},
'399006.SZ': {'name': '创业板指', 'weight': 0.2},
'000905.SH': {'name': '中证500', 'weight': 0.2}
}
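    # The index weights sum to 1.0; each index votes with its weight when the
    # composite bull/bear ratios are computed below.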
def __init__(self, token: str, db_engine):
ts.set_token(token)
self.pro = ts.pro_api()
self.engine = db_engine
self._init_database()
    def _init_database(self):
        """Ensure the target table exists (via the DB proxy)."""
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS gp_market_category (
            id INT AUTO_INCREMENT PRIMARY KEY,
            market_type VARCHAR(20) NOT NULL,
            start_date DATE NOT NULL,
            end_date DATE NOT NULL,
            duration_days INT NOT NULL,
            analysis_date DATE,
            created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_market_period (market_type, start_date, end_date)
        ) ENGINE=InnoDB COMMENT='Market phase classification table'
        """
        try:
            with self.engine.connect() as connection:
                # connection.begin() commits on clean exit and rolls back on
                # exception, so no explicit trans.commit() is needed.
                with connection.begin():
                    connection.execute(text(create_table_sql))
            print("✅ Table 'gp_market_category' checked/created.")
        except SQLAlchemyError as e:
            print(f"❌ Database table initialization failed: {e}")
            raise
    def get_all_index_data(self, start_date: str, end_date: str) -> Dict[str, pd.DataFrame]:
        """Fetch daily bars for all configured indices."""
        print("📈 Fetching multi-index data...")
        all_data = {}
        for ts_code, info in self.MARKET_INDICES.items():
            try:
                df = self.pro.index_daily(ts_code=ts_code,
                                          start_date=start_date,
                                          end_date=end_date)
                if not df.empty:
                    df['trade_date'] = pd.to_datetime(df['trade_date'])
                    df = df.sort_values('trade_date').reset_index(drop=True)
                    # Technical indicators
                    df['ma20'] = df['close'].rolling(window=20).mean()
                    df['ma60'] = df['close'].rolling(window=60).mean()
                    df['ma120'] = df['close'].rolling(window=120).mean()
                    df['high_60'] = df['close'].rolling(window=60).max()
                    df['low_60'] = df['close'].rolling(window=60).min()
                    df['volume_ma20'] = df['vol'].rolling(window=20).mean()
                    all_data[ts_code] = df
                    print(f"  ✅ Fetched {info['name']}: {len(df)} trading days")
                else:
                    print(f"  ❌ Empty data for {info['name']}")
            except Exception as e:
                print(f"  ❌ Failed to fetch {info['name']}: {e}")
        return all_data
    def calculate_composite_indicators(self, all_data: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """Compute composite market indicators across all indices."""
        print("📊 Computing composite market indicators...")
        # Trading dates common to all indices
        common_dates = None
        for ts_code, df in all_data.items():
            if common_dates is None:
                common_dates = set(df['trade_date'])
            else:
                common_dates = common_dates.intersection(set(df['trade_date']))
        if not common_dates:
            raise ValueError("No trading dates are common to all indices")
        common_dates = sorted(common_dates)
        composite_data = []
        for date in common_dates:
            date_metrics = {'trade_date': date}
            # Evaluate each index's state on this date
            bull_signals = 0
            bear_signals = 0
            total_weight = 0
            for ts_code, df in all_data.items():
                weight = self.MARKET_INDICES[ts_code]['weight']
                row = df[df['trade_date'] == date]
                if not row.empty:
                    close = row['close'].iloc[0]
                    ma20 = row['ma20'].iloc[0]
                    ma60 = row['ma60'].iloc[0]
                    ma120 = row['ma120'].iloc[0]
                    high_60 = row['high_60'].iloc[0]
                    low_60 = row['low_60'].iloc[0]
                    vol_ma20 = row['volume_ma20'].iloc[0]
                    volume_ratio = row['vol'].iloc[0] / vol_ma20 if vol_ma20 > 0 else 1
                    # Score this single index
                    index_bull_signals = 0
                    index_bear_signals = 0
                    # Bull signals
                    if close > ma60 > ma120:  # bullish moving-average alignment
                        index_bull_signals += 1
                    if close > high_60 * 0.95:  # near the recent high
                        index_bull_signals += 1
                    if volume_ratio > 1.2:  # volume expansion
                        index_bull_signals += 1
                    if (close - low_60) / low_60 > 0.2:  # more than 20% above the recent low
                        index_bull_signals += 1
                    # Bear signals
                    if close < ma60 < ma120:  # bearish moving-average alignment
                        index_bear_signals += 1
                    if close < low_60 * 1.05:  # near the recent low
                        index_bear_signals += 1
                    if volume_ratio > 1.1 and close < ma20:  # decline on rising volume
                        index_bear_signals += 1
                    if (high_60 - close) / high_60 > 0.15:  # more than 15% below the recent high
                        index_bear_signals += 1
                    # Accumulate weighted votes
                    if index_bull_signals >= 2:
                        bull_signals += weight
                    if index_bear_signals >= 2:
                        bear_signals += weight
                    total_weight += weight
            # Composite market state
            if total_weight > 0:
                bull_ratio = bull_signals / total_weight
                bear_ratio = bear_signals / total_weight
                if bull_ratio > 0.6:  # over 60% of index weight shows bull signals
                    market_type = 'bull'
                elif bear_ratio > 0.6:  # over 60% of index weight shows bear signals
                    market_type = 'bear'
                else:
                    market_type = 'consolidation'
            else:
                # No index had data for this date. Set the ratios explicitly to
                # avoid a NameError on the first pass (and stale values later).
                bull_ratio = bear_ratio = 0.0
                market_type = 'consolidation'
            date_metrics.update({
                'market_type': market_type,
                'bull_ratio': bull_ratio,
                'bear_ratio': bear_ratio
            })
            composite_data.append(date_metrics)
        return pd.DataFrame(composite_data)
    def detect_market_phases(self, composite_df: pd.DataFrame) -> List[Dict]:
        """Detect contiguous market phases from the daily classifications."""
        print("🔍 Detecting market phases...")
        phases = []
        current_phase = None
        for _, row in composite_df.iterrows():
            date = row['trade_date'].date()
            market_type = row['market_type']
            if current_phase is None:
                current_phase = {
                    'market_type': market_type,
                    'start_date': date,
                    'end_date': date,
                    'dates': [date]
                }
            elif current_phase['market_type'] == market_type:
                # Same phase: extend the end date
                current_phase['end_date'] = date
                current_phase['dates'].append(date)
            else:
                # Phase change: keep the finished phase only if it is long enough
                phase_duration = (current_phase['end_date'] - current_phase['start_date']).days
                if phase_duration >= 30:  # keep only phases lasting at least 30 days
                    phases.append(current_phase)
                # Start a new phase
                current_phase = {
                    'market_type': market_type,
                    'start_date': date,
                    'end_date': date,
                    'dates': [date]
                }
        # Handle the final phase
        if current_phase:
            phase_duration = (current_phase['end_date'] - current_phase['start_date']).days
            if phase_duration >= 30:
                phases.append(current_phase)
        return phases
    def merge_same_type_phases(self, phases: List[Dict]) -> List[Dict]:
        """Merge adjacent phases of the same type."""
        if len(phases) <= 1:
            return phases
        merged_phases = []
        current_phase = phases[0]
        for i in range(1, len(phases)):
            next_phase = phases[i]
            # Merge when the phases share a type and the gap between them is short
            gap_days = (next_phase['start_date'] - current_phase['end_date']).days
            if (current_phase['market_type'] == next_phase['market_type'] and
                    gap_days <= 60):  # gaps of up to 60 days can be merged
                current_phase['end_date'] = next_phase['end_date']
                current_phase['dates'].extend(next_phase['dates'])
            else:
                merged_phases.append(current_phase)
                current_phase = next_phase
        merged_phases.append(current_phase)
        print(f"🔄 Phase merging: {len(phases)} → {len(merged_phases)}")
        return merged_phases
    def format_phases_for_storage(self, phases: List[Dict]) -> List[Dict]:
        """Format phase records for storage."""
        formatted_phases = []
        for phase in phases:
            duration_days = (phase['end_date'] - phase['start_date']).days + 1
            formatted_phase = {
                'market_type': phase['market_type'],
                'start_date': phase['start_date'],
                'end_date': phase['end_date'],
                'duration_days': duration_days,
                # To also store the set of dates, add it here (requires `import json`), e.g.:
                # 'date_set': json.dumps([d.strftime('%Y-%m-%d') for d in sorted(set(phase['dates']))], ensure_ascii=False)
            }
            formatted_phases.append(formatted_phase)
        return formatted_phases
    def save_phases_to_mysql(self, phases: List[Dict]):
        """Write the market phases to MySQL transactionally (via the proxy)."""
        if not phases:
            print("  No market phase data to save.")
            return
        print("💾 Saving market phases to MySQL...")
        phases_df = pd.DataFrame(phases)
        phases_df['analysis_date'] = datetime.now().date()
        table_name = 'gp_market_category'
        with self.engine.connect() as connection:
            # begin() commits on clean exit and rolls back automatically when an
            # exception propagates, so no explicit commit()/rollback() is needed.
            with connection.begin():
                try:
                    # Use DELETE rather than TRUNCATE: in MySQL, TRUNCATE is DDL
                    # and triggers an implicit commit, which would defeat the
                    # enclosing transaction's rollback guarantee.
                    connection.execute(text(f"DELETE FROM {table_name}"))
                    phases_df.to_sql(
                        name=table_name,
                        con=connection,
                        if_exists='append',
                        index=False,
                        chunksize=1000
                    )
                    print(f"✅ Save complete: wrote {len(phases_df)} market phases.")
                except SQLAlchemyError as e:
                    print(f"❌ Failed to save data to {table_name}: {e}")
                    raise
    def run_comprehensive_analysis(self, start_date: str, end_date: str):
        """Run the complete ETL pipeline."""
        print("🎯 Starting multi-index market analysis (ETL)...")
        print("=" * 60)
        try:
            all_data = self.get_all_index_data(start_date, end_date)
            if not all_data:
                print("❌ Could not fetch index data; ETL aborted.")
                return
            composite_df = self.calculate_composite_indicators(all_data)
            if composite_df.empty:
                print("❌ Could not compute composite indicators; ETL aborted.")
                return
            raw_phases = self.detect_market_phases(composite_df)
            merged_phases = self.merge_same_type_phases(raw_phases)
            final_phases = self.format_phases_for_storage(merged_phases)
            self.save_phases_to_mysql(final_phases)
            self._print_analysis_results(final_phases)
            print(f"\n🎉 ETL analysis complete: processed {len(final_phases)} market phases.")
        except Exception as e:
            print(f"❌ Error during ETL analysis: {e}")
            import traceback
            traceback.print_exc()
            raise
    def _print_analysis_results(self, phases: List[Dict]):
        """Print the analysis results."""
        print("\n📋 Market phase analysis results:")
        print("=" * 70)
        print(f"{'No.':<3} {'Type':<10} {'Start':<12} {'End':<12} {'Days':<8}")
        print("-" * 70)
        bull_count = bear_count = consolidation_count = 0
        for i, phase in enumerate(phases, 1):
            market_type = phase['market_type']
            if market_type == 'bull':
                bull_count += 1
                type_display = '🐂 Bull'
            elif market_type == 'bear':
                bear_count += 1
                type_display = '🐻 Bear'
            else:
                consolidation_count += 1
                type_display = '➰ Consolidation'
            print(f"{i:<3} {type_display:<10} {phase['start_date']} {phase['end_date']} "
                  f"{phase['duration_days']:>6}")
        print("-" * 70)
        print(f"📈 Summary: {bull_count} bull, {bear_count} bear, {consolidation_count} consolidation")
# --- Single public entry point ---
def run_market_classifier_etl(start_date: str, end_date: str):
    """
    ETL entry function, intended to be called from a Celery task.
    """
    print(f"🚀 [ETL START] Market Classifier for period {start_date} to {end_date}")
    # 1. Read configuration from environment variables
    tushare_token = os.environ.get('TUSHARE_TOKEN')
    proxy_db_url = os.environ.get('PROXY_DB_URL')
    if not tushare_token or not proxy_db_url:
        raise ValueError("TUSHARE_TOKEN and PROXY_DB_URL environment variables must be set.")
    # 2. Create the database engine
    engine = create_engine(proxy_db_url, pool_pre_ping=True)
    # 3. Instantiate and run
    try:
        classifier = _MultiIndexMarketClassifier(tushare_token, engine)
        classifier.run_comprehensive_analysis(start_date, end_date)
        print("✅ [ETL SUCCESS] Market Classifier finished.")
        return {"message": f"Market classifier ETL completed for period {start_date} to {end_date}."}
    except Exception as e:
        print(f"❌ [ETL FAILED] Market Classifier failed: {e}")
        # Re-raising inside the Celery task matters: it is what moves the task state to FAILURE.
        raise
# --- Standalone test entry point ---
if __name__ == "__main__":
    # For convenience, load environment variables from a .env file
    try:
        from dotenv import load_dotenv
        # Assume the .env file sits at the project root, two levels up
        dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env')
        load_dotenv(dotenv_path=dotenv_path)
        print(f"Loaded .env file from: {dotenv_path}")
    except ImportError:
        print("Warning: dotenv not installed. Please set environment variables manually for standalone testing.")
    # Date range for the test run
    test_start_date = '20140101'
    test_end_date = '20251130'
    run_market_classifier_etl(test_start_date, test_end_date)
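
The docstring positions run_market_classifier_etl as a Celery entry point; a minimal sketch of that wiring follows. The app name, broker URL, and task name are assumptions for illustration, not configuration from this repository.

# Hypothetical Celery wiring; the broker URL and task name are assumptions.
from celery import Celery

celery_app = Celery('etl_tasks', broker='redis://localhost:6379/0')

@celery_app.task(bind=True)
def market_classifier_task(self, start_date: str, end_date: str):
    from app.data_collection.market_classifier_etl import run_market_classifier_etl
    # run_market_classifier_etl re-raises on failure, so the task state
    # becomes FAILURE exactly as the module comment intends.
    return run_market_classifier_etl(start_date, end_date)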


@@ -0,0 +1,342 @@
# /backend/app/data_collection/stock_grouper_etl.py
"""
Created on 2025/11/12
---------
@summary: 股票分组脚本 (终极修复版 V5)
1. 解决 OOM 问题 (按年分批)
2. 解决 跨批次重复问题 (前序状态合并)
3. 解决 KeyError: symbol 问题 (手动生成 symbol)
---------
@author: NBR
"""
import os
import time
import gc
from datetime import datetime, timedelta
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text
# --- Configuration ---
PROXY_DB_URL = os.environ.get('PROXY_DB_URL')
TUSHARE_TOKEN = os.environ.get('TUSHARE_TOKEN')
class _StockGroupManager:
    """
    Stock group manager.
    """
REQUEST_INTERVAL = 0.05
INDEX_MAPPING = {
'沪深300': '000300.SH',
'中证500': '000905.SH',
'中证1000': '000852.SH',
'创业板': '399006.SZ',
'科创板': '000688.SH',
'上证指数': '000001.SH',
'深证成指': '399001.SZ'
}
def __init__(self, token: str, db_engine, test_mode: bool = False):
self.test_mode = test_mode
self.pro = ts.pro_api(token)
if self.pro is None:
raise Exception("Tushare pro_api initialization failed.")
self.engine = db_engine
self.index_constituents_cache = {}
self.TEST_STOCKS = ['300750.SZ', '688981.SH']
self._init_database()
def _init_database(self):
create_table_sql = """
CREATE TABLE IF NOT EXISTS gp_stock_category (
id INT AUTO_INCREMENT PRIMARY KEY,
ts_code VARCHAR(20) NOT NULL,
symbol VARCHAR(20) NOT NULL,
name VARCHAR(100) NOT NULL,
trade_date DATE NOT NULL,
industry VARCHAR(50),
is_st BOOLEAN DEFAULT FALSE,
index_series TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_stock_date (ts_code, trade_date)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""
        with self.engine.connect() as connection:
            # begin() commits on clean exit; no explicit trans.commit() is needed.
            with connection.begin():
                connection.execute(text(create_table_sql))
    # --- Helpers ---
def _get_month_start_end_dates(self, year_month: str):
try:
year = int(year_month[:4])
month = int(year_month[4:6])
month_start = datetime(year, month, 1)
if month == 12:
next_month = datetime(year + 1, 1, 1)
else:
next_month = datetime(year, month + 1, 1)
month_end = next_month - timedelta(days=1)
return month_start.strftime('%Y%m%d'), month_end.strftime('%Y%m%d')
        except (ValueError, IndexError):
            # Fallback for malformed input; day '31' may not exist in the month,
            # but the value is only used as an inclusive query bound.
            return year_month + '01', year_month + '31'
def _get_previous_month_key(self, month_key: str) -> str:
try:
year = int(month_key[:4])
month = int(month_key[4:])
if month == 1:
return f"{year - 1:04d}12"
else:
return f"{year:04d}{month - 1:02d}"
        except (ValueError, IndexError):
            return month_key
def _cache_month_index_constituents(self, month_key: str):
if month_key in self.index_constituents_cache:
return
self.index_constituents_cache[month_key] = {}
prev_month_key = self._get_previous_month_key(month_key)
prev_data = self.index_constituents_cache.get(prev_month_key, {})
month_start, month_end = self._get_month_start_end_dates(month_key)
for index_name, index_code in self.INDEX_MAPPING.items():
current_stocks = set()
try:
df = self.pro.index_weight(index_code=index_code, start_date=month_start, end_date=month_end)
if not df.empty:
current_stocks = set(df['con_code'].tolist())
else:
                    # Forward-fill from the previous month's constituents
if index_name in prev_data:
current_stocks = prev_data[index_name]
self.index_constituents_cache[month_key][index_name] = current_stocks
time.sleep(self.REQUEST_INTERVAL)
except Exception as e:
print(f" 缓存{index_name}成分股失败: {e}")
if index_name in prev_data:
self.index_constituents_cache[month_key][index_name] = prev_data[index_name]
def _get_index_series_for_stock(self, ts_code: str, trade_date: str) -> str:
try:
index_series = []
month_key = trade_date[:6]
if month_key not in self.index_constituents_cache:
self._cache_month_index_constituents(month_key)
month_data = self.index_constituents_cache.get(month_key, {})
for index_name, stock_set in month_data.items():
if ts_code in stock_set:
index_series.append(index_name)
return ','.join(index_series) if index_series else ''
        except Exception:
            return ''
    def _is_st_stock(self, stock_name: str) -> bool:
        if pd.isna(stock_name):
            return False
        # 'ST' also matches '*ST', so a single check suffices
        return 'ST' in stock_name
def _format_code_to_prefix_style(self, ts_code: str) -> str:
if '.' in ts_code:
parts = ts_code.split('.')
if len(parts) == 2:
return f"{parts[1].upper()}{parts[0]}"
return ts_code
    def get_latest_states_before(self, before_date_str: str) -> pd.DataFrame:
        """Fetch the latest stored state per stock before the given date (baseline for cross-year dedup)."""
        print(f"  🔍 Querying baseline state before {before_date_str} (for cross-year dedup)...")
query = text("""
SELECT t1.ts_code, t1.industry, t1.is_st, t1.index_series, t1.trade_date, t1.name
FROM gp_stock_category t1
JOIN (
SELECT ts_code, MAX(trade_date) as max_dt
FROM gp_stock_category
WHERE trade_date < :before_date
GROUP BY ts_code
) t2 ON t1.ts_code = t2.ts_code AND t1.trade_date = t2.max_dt
""")
try:
df = pd.read_sql(query, self.engine, params={'before_date': before_date_str})
if not df.empty:
df['trade_date'] = pd.to_datetime(df['trade_date'])
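                # Stored ts_codes use the prefix style (e.g. SZ000001); convert
                # them back to Tushare's suffix style (000001.SZ) so they align
                # with freshly fetched rows during change detection.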
def _revert_code(c):
if len(c) > 2 and c[:2] in ['SZ', 'SH']:
return f"{c[2:]}.{c[:2]}"
return c
df['ts_code'] = df['ts_code'].apply(_revert_code)
print(f" ✅ 获取到 {len(df)} 条基准记录。")
return df
        except Exception as e:
            print(f"  ⚠️ Failed to fetch baseline state (possibly the first run): {e}")
return pd.DataFrame()
    def get_all_bak_basic_in_range(self, start_date: str, end_date: str) -> pd.DataFrame:
        print(f"  Fetching daily data from {start_date} to {end_date}...")
all_dates = pd.date_range(start=start_date, end=end_date, freq='B')
all_dfs = []
for i, date in enumerate(all_dates):
date_str = date.strftime('%Y%m%d')
try:
df = self.pro.bak_basic(trade_date=date_str)
if df is not None and not df.empty:
all_dfs.append(df)
time.sleep(self.REQUEST_INTERVAL)
except Exception as e:
print(f" 获取 {date_str} 数据失败: {e}")
if not all_dfs:
return pd.DataFrame()
return pd.concat(all_dfs, ignore_index=True)
    def update_stock_category_data(self, start_date: str, end_date: str):
        """Core processing logic."""
        # 1. Base daily data
        daily_data_df = self.get_all_bak_basic_in_range(start_date, end_date)
        if daily_data_df.empty:
            print("  No data; this batch is done.")
            return
        # ================= [Key fix] =================
        # Ensure the 'symbol' column exists; the merge in step 7 depends on it
        if 'symbol' not in daily_data_df.columns:
            # Derive symbol from ts_code (e.g. 000001.SZ -> 000001)
            daily_data_df['symbol'] = daily_data_df['ts_code'].apply(lambda x: str(x).split('.')[0])
        # ==============================================
print(f" 批量计算属性 ({len(daily_data_df)} 行)...")
daily_data_df['trade_date'] = pd.to_datetime(daily_data_df['trade_date'], format='%Y%m%d')
daily_data_df['is_st'] = daily_data_df['name'].apply(self._is_st_stock)
# 2. 预加载指数缓存
unique_months = sorted(daily_data_df['trade_date'].dt.strftime('%Y%m').unique())
for ym in unique_months:
self._cache_month_index_constituents(ym)
# 3. 计算指数归属
daily_data_df['index_series'] = daily_data_df.apply(
lambda row: self._get_index_series_for_stock(row['ts_code'], row['trade_date'].strftime('%Y%m%d')),
axis=1
)
        # 4. Merge the prior baseline state
prev_state_df = self.get_latest_states_before(pd.to_datetime(start_date).strftime('%Y-%m-%d'))
if not prev_state_df.empty:
cols = ['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']
combined_df = pd.concat([prev_state_df[cols], daily_data_df[cols]], ignore_index=True)
else:
combined_df = daily_data_df.copy()
        # 5. Change detection
combined_df.sort_values(['ts_code', 'trade_date'], inplace=True)
cols_to_check = ['industry', 'is_st', 'index_series']
shifted = combined_df.groupby('ts_code')[cols_to_check].shift(1)
is_changed = (combined_df[cols_to_check].ne(shifted)).any(axis=1)
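        # Note: groupby(...).shift(1) yields NaN for each stock's first row, and
        # ne(NaN) evaluates True, so every stock's first record in the combined
        # frame is treated as "changed". Merging the prior baseline state in
        # step 4 is what stops unchanged stocks from being re-inserted at the
        # start of every yearly batch.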
        # 6. Filter the results
current_batch_start_dt = pd.to_datetime(start_date)
records_to_insert = combined_df[is_changed & (combined_df['trade_date'] >= current_batch_start_dt)].copy()
        # 7. Re-attach symbol and name
if not records_to_insert.empty:
records_to_insert = pd.merge(
records_to_insert[['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']],
                daily_data_df[['ts_code', 'trade_date', 'symbol', 'name']],  # symbol is guaranteed to exist now
on=['ts_code', 'trade_date'],
how='left'
)
print(f" 筛选出 {len(records_to_insert)} 条真实变更记录。")
        # 8. Write to the database
if not records_to_insert.empty:
records_to_insert['ts_code'] = records_to_insert['ts_code'].apply(self._format_code_to_prefix_style)
            # Belt and braces: if symbol is still missing or null after the merge (rare), regenerate it
if 'symbol' not in records_to_insert.columns or records_to_insert['symbol'].isnull().any():
records_to_insert['symbol'] = records_to_insert['ts_code'].apply(lambda x: x[2:] if len(x) > 2 else x)
final_df = records_to_insert[
['ts_code', 'symbol', 'name', 'trade_date', 'industry', 'is_st', 'index_series']]
            with self.engine.connect() as connection:
                # begin() commits on clean exit and rolls back on error.
                with connection.begin():
                    del_start = pd.to_datetime(start_date).strftime('%Y-%m-%d')
                    del_end = pd.to_datetime(end_date).strftime('%Y-%m-%d')
                    connection.execute(
                        text("DELETE FROM gp_stock_category WHERE trade_date BETWEEN :s AND :e"),
                        {'s': del_start, 'e': del_end}
                    )
                    final_df.to_sql('gp_stock_category', con=connection, if_exists='append', index=False)
            print("  ✅ Write succeeded.")
        # 9. Cleanup
del daily_data_df
del combined_df
del records_to_insert
if not prev_state_df.empty: del prev_state_df
gc.collect()
# --- Run entry point ---
def run_stock_grouper_etl(start_date: str = None, end_date: str = None, test_mode: bool = False,
is_full_run: bool = False):
    if not PROXY_DB_URL or not TUSHARE_TOKEN:
        raise ValueError("PROXY_DB_URL and TUSHARE_TOKEN environment variables must be set.")
engine = create_engine(PROXY_DB_URL, pool_pre_ping=True)
manager = _StockGroupManager(TUSHARE_TOKEN, engine, test_mode)
    if is_full_run:
        print("🚀 [Full Run] Starting batched processing (one year per batch)...")
current_year = 2016
end_year = datetime.now().year
while current_year <= end_year:
batch_start = f"{current_year}0101"
batch_end = f"{current_year}1231"
if current_year == end_year:
batch_end = datetime.now().strftime('%Y%m%d')
if batch_start > batch_end: break
print(f"\n>>> 处理批次: {batch_start} - {batch_end}")
manager.update_stock_category_data(batch_start, batch_end)
current_year += 1
time.sleep(1)
print("🎉 全量重跑完成。")
else:
if not start_date:
with engine.connect() as conn:
res = conn.execute(text("SELECT MAX(trade_date) FROM gp_stock_category")).scalar()
if res:
start_date = (res + timedelta(days=1)).strftime('%Y%m%d')
else:
start_date = '20160101'
end_date = datetime.now().strftime('%Y%m%d')
if start_date > datetime.now().strftime('%Y%m%d'):
return {"message": "No new data."}
print(f"🚀 [Incremental] 增量更新: {start_date} - {end_date}")
manager.update_stock_category_data(start_date, end_date)
return {"message": "Success"}