From 16f3efe3ae801403bc7553d1a317ef8a20a65585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BB=A1=E8=84=B8=E5=B0=8F=E6=98=9F=E6=98=9F?= Date: Fri, 26 Dec 2025 15:32:42 +0800 Subject: [PATCH] commit; --- src/QMT/strategy.py | 4 +- src/tushare_scripts/market_classifier_etl.py | 423 +++++++++++++++++++ src/tushare_scripts/stock_grouper_etl.py | 342 +++++++++++++++ 3 files changed, 767 insertions(+), 2 deletions(-) create mode 100644 src/tushare_scripts/market_classifier_etl.py create mode 100644 src/tushare_scripts/stock_grouper_etl.py diff --git a/src/QMT/strategy.py b/src/QMT/strategy.py index cb2917b..3fd3624 100644 --- a/src/QMT/strategy.py +++ b/src/QMT/strategy.py @@ -743,9 +743,9 @@ def create_clearance_strategy_callback(xt_trader, acc, logger): sell_volume = current_position if sell_volume > 0: - # 清仓策略:使用限价单,价格设置为当前价的95%(向下取整到分),确保能成交 + # 清仓策略:使用限价单,价格设置为当前价的98%(向下取整到分),确保能成交 # 对于卖出,价格越低越容易成交 - aggressive_price = round(current_price * 0.95, 2) + aggressive_price = round(current_price * 0.98, 2) # 确保价格至少为0.01元 aggressive_price = max(aggressive_price, 0.01) diff --git a/src/tushare_scripts/market_classifier_etl.py b/src/tushare_scripts/market_classifier_etl.py new file mode 100644 index 0000000..a946a2d --- /dev/null +++ b/src/tushare_scripts/market_classifier_etl.py @@ -0,0 +1,423 @@ +# backend/app/data_collection/market_classifier_etl.py +""" +Created on 2025/11/12 +--------- +@summary: 市场阶段分片 ,多指数综合市场分类器 +--------- +@author: NBR +""" +import os +from datetime import datetime +from typing import List, Dict + +import pandas as pd +import tushare as ts +from sqlalchemy import create_engine, text +from sqlalchemy.exc import SQLAlchemyError + + +class _MultiIndexMarketClassifier: + """ + 多指数综合市场分类器 - 纯ETL版本 + 职责:从Tushare获取数据,计算市场阶段,并将结果写入数据库。 + """ + + # --- 配置 --- + PROXY_DB_URL = os.environ.get('PROXY_DB_URL') + if not PROXY_DB_URL: + raise ValueError("PROXY_DB_URL environment variable must be set.") + + # 采用更健壮的连接池配置 + engine = create_engine(PROXY_DB_URL, pool_pre_ping=True, pool_recycle=3600) + + MARKET_INDICES = { + '000001.SH': {'name': '上证指数', 'weight': 0.3}, + '000300.SH': {'name': '沪深300', 'weight': 0.3}, + '399006.SZ': {'name': '创业板指', 'weight': 0.2}, + '000905.SH': {'name': '中证500', 'weight': 0.2} + } + + def __init__(self, token: str, db_engine): + ts.set_token(token) + self.pro = ts.pro_api() + self.engine = db_engine + self._init_database() + + def _init_database(self): + """通过Proxy确保目标表存在。""" + create_table_sql = """ + CREATE TABLE IF NOT EXISTS gp_market_category ( + id INT AUTO_INCREMENT PRIMARY KEY, + market_type VARCHAR(20) NOT NULL, + start_date DATE NOT NULL, + end_date DATE NOT NULL, + duration_days INT NOT NULL, + analysis_date DATE, + created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_market_period (market_type, start_date, end_date) + ) ENGINE=InnoDB COMMENT='市场阶段分类表' + """ + try: + with self.engine.connect() as connection: + with connection.begin() as trans: + connection.execute(text(create_table_sql)) + trans.commit() + print("✅ 数据库表'gp_market_category'检查/创建完成。") + except SQLAlchemyError as e: + print(f"❌ 数据库表初始化失败: {e}") + raise + + def get_all_index_data(self, start_date: str, end_date: str) -> Dict[str, pd.DataFrame]: + """获取所有指数的日线数据""" + print("📈 获取多指数数据...") + + all_data = {} + for ts_code, info in self.MARKET_INDICES.items(): + try: + df = self.pro.index_daily(ts_code=ts_code, + start_date=start_date, + end_date=end_date) + if not df.empty: + df['trade_date'] = pd.to_datetime(df['trade_date']) + df = df.sort_values('trade_date').reset_index(drop=True) + + # 计算技术指标 + df['ma20'] = df['close'].rolling(window=20).mean() + df['ma60'] = df['close'].rolling(window=60).mean() + df['ma120'] = df['close'].rolling(window=120).mean() + df['high_60'] = df['close'].rolling(window=60).max() + df['low_60'] = df['close'].rolling(window=60).min() + df['volume_ma20'] = df['vol'].rolling(window=20).mean() + + all_data[ts_code] = df + print(f" ✅ 获取 {info['name']} 数据: {len(df)} 个交易日") + else: + print(f" ❌ 获取 {info['name']} 数据为空") + + except Exception as e: + print(f" ❌ 获取 {info['name']} 数据失败: {e}") + + return all_data + + def calculate_composite_indicators(self, all_data: Dict[str, pd.DataFrame]) -> pd.DataFrame: + """计算综合市场指标""" + print("📊 计算综合市场指标...") + + # 获取所有指数的共同交易日 + common_dates = None + for ts_code, df in all_data.items(): + if common_dates is None: + common_dates = set(df['trade_date']) + else: + common_dates = common_dates.intersection(set(df['trade_date'])) + + if not common_dates: + raise Exception("没有共同的交易日数据") + + common_dates = sorted(common_dates) + composite_data = [] + + for date in common_dates: + date_metrics = {'trade_date': date} + + # 计算每个指数在该日期的状态 + bull_signals = 0 + bear_signals = 0 + total_weight = 0 + + for ts_code, df in all_data.items(): + weight = self.MARKET_INDICES[ts_code]['weight'] + row = df[df['trade_date'] == date] + + if not row.empty: + close = row['close'].iloc[0] + ma20 = row['ma20'].iloc[0] + ma60 = row['ma60'].iloc[0] + ma120 = row['ma120'].iloc[0] + high_60 = row['high_60'].iloc[0] + low_60 = row['low_60'].iloc[0] + volume_ratio = row['vol'].iloc[0] / row['volume_ma20'].iloc[0] if row['volume_ma20'].iloc[ + 0] > 0 else 1 + + # 判断单个指数的状态 + index_bull_signals = 0 + index_bear_signals = 0 + + # 牛市信号 + if close > ma60 > ma120: # 多头排列 + index_bull_signals += 1 + if close > high_60 * 0.95: # 接近近期高点 + index_bull_signals += 1 + if volume_ratio > 1.2: # 放量 + index_bull_signals += 1 + if (close - low_60) / low_60 > 0.2: # 从低点上涨超过20% + index_bull_signals += 1 + + # 熊市信号 + if close < ma60 < ma120: # 空头排列 + index_bear_signals += 1 + if close < low_60 * 1.05: # 接近近期低点 + index_bear_signals += 1 + if volume_ratio > 1.1 and close < ma20: # 放量下跌 + index_bear_signals += 1 + if (high_60 - close) / high_60 > 0.15: # 从高点下跌超过15% + index_bear_signals += 1 + + # 加权累计信号 + if index_bull_signals >= 2: + bull_signals += weight + if index_bear_signals >= 2: + bear_signals += weight + + total_weight += weight + + # 计算综合市场状态 + if total_weight > 0: + bull_ratio = bull_signals / total_weight + bear_ratio = bear_signals / total_weight + + if bull_ratio > 0.6: # 60%以上的指数显示牛市信号 + market_type = 'bull' + elif bear_ratio > 0.6: # 60%以上的指数显示熊市信号 + market_type = 'bear' + else: + market_type = 'consolidation' + else: + market_type = 'consolidation' + + date_metrics.update({ + 'market_type': market_type, + 'bull_ratio': bull_ratio, + 'bear_ratio': bear_ratio + }) + + composite_data.append(date_metrics) + + return pd.DataFrame(composite_data) + + def detect_market_phases(self, composite_df: pd.DataFrame) -> List[Dict]: + """检测市场阶段""" + print("🔍 检测市场阶段...") + + phases = [] + current_phase = None + + for i, row in composite_df.iterrows(): + date = row['trade_date'].date() + market_type = row['market_type'] + + if current_phase is None: + current_phase = { + 'market_type': market_type, + 'start_date': date, + 'end_date': date, + 'dates': [date] + } + elif current_phase['market_type'] == market_type: + # 相同阶段,更新结束日期 + current_phase['end_date'] = date + current_phase['dates'].append(date) + else: + # 阶段转换,保存当前阶段 + phase_duration = (current_phase['end_date'] - current_phase['start_date']).days + if phase_duration >= 30: # 至少30天的阶段才保留 + phases.append(current_phase) + + # 开始新阶段 + current_phase = { + 'market_type': market_type, + 'start_date': date, + 'end_date': date, + 'dates': [date] + } + + # 处理最后一个阶段 + if current_phase: + phase_duration = (current_phase['end_date'] - current_phase['start_date']).days + if phase_duration >= 30: + phases.append(current_phase) + + return phases + + def merge_same_type_phases(self, phases: List[Dict]) -> List[Dict]: + """合并相同类型的相邻阶段""" + if len(phases) <= 1: + return phases + + merged_phases = [] + current_phase = phases[0] + + for i in range(1, len(phases)): + next_phase = phases[i] + + # 检查是否相邻且类型相同 + gap_days = (next_phase['start_date'] - current_phase['end_date']).days + if (current_phase['market_type'] == next_phase['market_type'] and + gap_days <= 60): # 间隔小于60天可以合并 + + # 合并阶段 + current_phase['end_date'] = next_phase['end_date'] + current_phase['dates'].extend(next_phase['dates']) + else: + merged_phases.append(current_phase) + current_phase = next_phase + + merged_phases.append(current_phase) + + print(f"🔄 阶段合并: {len(phases)} → {len(merged_phases)}") + return merged_phases + + def format_phases_for_storage(self, phases: List[Dict]) -> List[Dict]: + """格式化阶段数据用于存储""" + formatted_phases = [] + + for phase in phases: + duration_days = (phase['end_date'] - phase['start_date']).days + 1 + + formatted_phase = { + 'market_type': phase['market_type'], + 'start_date': phase['start_date'], + 'end_date': phase['end_date'], + 'duration_days': duration_days, + # 如果需要存储日期集合,可以在这里添加 + # 'date_set': json.dumps([d.strftime('%Y-%m-%d') for d in sorted(set(phase['dates']))], ensure_ascii=False) + } + + formatted_phases.append(formatted_phase) + + return formatted_phases + + def save_phases_to_mysql(self, phases: List[Dict]): + """将市场阶段通过Proxy以事务方式写入MySQL。""" + if not phases: + print("ℹ️ 没有市场阶段数据需要保存。") + return + + print("💾 保存市场阶段数据到MySQL...") + phases_df = pd.DataFrame(phases) + phases_df['analysis_date'] = datetime.now().date() + table_name = 'gp_market_category' + + with self.engine.connect() as connection: + with connection.begin() as transaction: + try: + connection.execute(text(f"TRUNCATE TABLE {table_name}")) + phases_df.to_sql( + name=table_name, + con=connection, + if_exists='append', + index=False, + chunksize=1000 + ) + transaction.commit() + print(f"✅ 保存完成: 成功写入 {len(phases_df)} 个市场阶段。") + except SQLAlchemyError as e: + print(f"❌ 保存数据到 {table_name} 失败: {e}") + transaction.rollback() + raise + + def run_comprehensive_analysis(self, start_date: str, end_date: str): + """执行完整的ETL流程。""" + print("🎯 开始多指数市场分析 (ETL)...") + print("=" * 60) + try: + all_data = self.get_all_index_data(start_date, end_date) + if not all_data: + print("❌ 无法获取指数数据,ETL终止。") + return + + composite_df = self.calculate_composite_indicators(all_data) + if composite_df.empty: + print("❌ 无法计算综合指标,ETL终止。") + return + + raw_phases = self.detect_market_phases(composite_df) + merged_phases = self.merge_same_type_phases(raw_phases) + final_phases = self.format_phases_for_storage(merged_phases) + + self.save_phases_to_mysql(final_phases) + self._print_analysis_results(final_phases) + + print(f"\n🎉 ETL分析完成: 共处理了 {len(final_phases)} 个市场阶段。") + except Exception as e: + print(f"❌ ETL分析过程中出错: {e}") + import traceback + traceback.print_exc() + raise + + def _print_analysis_results(self, phases: List[Dict]): + """打印分析结果""" + print("\n📋 市场阶段分析结果:") + print("=" * 70) + print(f"{'序号':<3} {'市场类型':<10} {'开始日期':<12} {'结束日期':<12} {'持续时间':<8}") + print("-" * 70) + + bull_count = bear_count = consolidation_count = 0 + + for i, phase in enumerate(phases, 1): + market_type = phase['market_type'] + if market_type == 'bull': + bull_count += 1 + type_display = '🐂 牛市' + elif market_type == 'bear': + bear_count += 1 + type_display = '🐻 熊市' + else: + consolidation_count += 1 + type_display = '➰ 震荡市' + + print(f"{i:<3} {type_display:<10} {phase['start_date']} {phase['end_date']} " + f"{phase['duration_days']:>6} 天") + + print("-" * 70) + print(f"📈 统计: 牛市{bull_count}个, 熊市{bear_count}个, 震荡市{consolidation_count}个") + + +# --- 唯一的公共入口函数 --- +def run_market_classifier_etl(start_date: str, end_date: str): + """ + ETL入口函数,供Celery任务调用。 + """ + print(f"🚀 [ETL START] Market Classifier for period {start_date} to {end_date}") + + # 1. 从环境变量获取配置 + tushare_token = os.environ.get('TUSHARE_TOKEN') + proxy_db_url = os.environ.get('PROXY_DB_URL') + + if not tushare_token or not proxy_db_url: + raise ValueError("TUSHARE_TOKEN and PROXY_DB_URL environment variables must be set.") + + # 2. 创建数据库引擎 + engine = create_engine(proxy_db_url, pool_pre_ping=True) + + # 3. 实例化并运行 + try: + classifier = _MultiIndexMarketClassifier(tushare_token, engine) + classifier.run_comprehensive_analysis(start_date, end_date) + print(f"✅ [ETL SUCCESS] Market Classifier finished.") + return {"message": f"Market classifier ETL completed for period {start_date} to {end_date}."} + except Exception as e: + print(f"❌ [ETL FAILED] Market Classifier failed: {e}") + # 在Celery任务中,重新抛出异常很重要,这样任务状态才会变为FAILURE + raise + + +# --- 用于独立测试的入口 --- +if __name__ == "__main__": + + # 为了方便测试,从.env文件加载环境变量 + try: + from dotenv import load_dotenv + + # 假设.env文件在项目根目录,即上一级的上一级 + dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env') + load_dotenv(dotenv_path=dotenv_path) + print(f"Loaded .env file from: {dotenv_path}") + except ImportError: + print("Warning: dotenv not installed. Please set environment variables manually for standalone testing.") + + # 定义测试运行的时间范围 + test_start_date = '20140101' + test_end_date = '20251130' + + run_market_classifier_etl(test_start_date, test_end_date) diff --git a/src/tushare_scripts/stock_grouper_etl.py b/src/tushare_scripts/stock_grouper_etl.py new file mode 100644 index 0000000..dc7d059 --- /dev/null +++ b/src/tushare_scripts/stock_grouper_etl.py @@ -0,0 +1,342 @@ +# /backend/app/data_collection/stock_grouper_etl.py + +""" +Created on 2025/11/12 +--------- +@summary: 股票分组脚本 (终极修复版 V5) + 1. 解决 OOM 问题 (按年分批) + 2. 解决 跨批次重复问题 (前序状态合并) + 3. 解决 KeyError: symbol 问题 (手动生成 symbol) +--------- +@author: NBR +""" + +import os +import time +import gc +from datetime import datetime, timedelta + +import pandas as pd +import tushare as ts +from sqlalchemy import create_engine, text + +# --- 配置 --- +PROXY_DB_URL = os.environ.get('PROXY_DB_URL') +TUSHARE_TOKEN = os.environ.get('TUSHARE_TOKEN') + + +class _StockGroupManager: + """ + 股票分组管理器 + """ + REQUEST_INTERVAL = 0.05 + + INDEX_MAPPING = { + '沪深300': '000300.SH', + '中证500': '000905.SH', + '中证1000': '000852.SH', + '创业板': '399006.SZ', + '科创板': '000688.SH', + '上证指数': '000001.SH', + '深证成指': '399001.SZ' + } + + def __init__(self, token: str, db_engine, test_mode: bool = False): + self.test_mode = test_mode + self.pro = ts.pro_api(token) + if self.pro is None: + raise Exception("Tushare pro_api initialization failed.") + + self.engine = db_engine + self.index_constituents_cache = {} + self.TEST_STOCKS = ['300750.SZ', '688981.SH'] + self._init_database() + + def _init_database(self): + create_table_sql = """ + CREATE TABLE IF NOT EXISTS gp_stock_category ( + id INT AUTO_INCREMENT PRIMARY KEY, + ts_code VARCHAR(20) NOT NULL, + symbol VARCHAR(20) NOT NULL, + name VARCHAR(100) NOT NULL, + trade_date DATE NOT NULL, + industry VARCHAR(50), + is_st BOOLEAN DEFAULT FALSE, + index_series TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY unique_stock_date (ts_code, trade_date) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """ + with self.engine.connect() as connection: + with connection.begin() as trans: + connection.execute(text(create_table_sql)) + trans.commit() + + # --- 辅助方法 --- + def _get_month_start_end_dates(self, year_month: str): + try: + year = int(year_month[:4]) + month = int(year_month[4:6]) + month_start = datetime(year, month, 1) + if month == 12: + next_month = datetime(year + 1, 1, 1) + else: + next_month = datetime(year, month + 1, 1) + month_end = next_month - timedelta(days=1) + return month_start.strftime('%Y%m%d'), month_end.strftime('%Y%m%d') + except: + return year_month + '01', year_month + '31' + + def _get_previous_month_key(self, month_key: str) -> str: + try: + year = int(month_key[:4]) + month = int(month_key[4:]) + if month == 1: + return f"{year - 1:04d}12" + else: + return f"{year:04d}{month - 1:02d}" + except: + return month_key + + def _cache_month_index_constituents(self, month_key: str): + if month_key in self.index_constituents_cache: + return + + self.index_constituents_cache[month_key] = {} + prev_month_key = self._get_previous_month_key(month_key) + prev_data = self.index_constituents_cache.get(prev_month_key, {}) + month_start, month_end = self._get_month_start_end_dates(month_key) + + for index_name, index_code in self.INDEX_MAPPING.items(): + current_stocks = set() + try: + df = self.pro.index_weight(index_code=index_code, start_date=month_start, end_date=month_end) + if not df.empty: + current_stocks = set(df['con_code'].tolist()) + else: + # 前向填充 + if index_name in prev_data: + current_stocks = prev_data[index_name] + + self.index_constituents_cache[month_key][index_name] = current_stocks + time.sleep(self.REQUEST_INTERVAL) + except Exception as e: + print(f" 缓存{index_name}成分股失败: {e}") + if index_name in prev_data: + self.index_constituents_cache[month_key][index_name] = prev_data[index_name] + + def _get_index_series_for_stock(self, ts_code: str, trade_date: str) -> str: + try: + index_series = [] + month_key = trade_date[:6] + if month_key not in self.index_constituents_cache: + self._cache_month_index_constituents(month_key) + + month_data = self.index_constituents_cache.get(month_key, {}) + for index_name, stock_set in month_data.items(): + if ts_code in stock_set: + index_series.append(index_name) + return ','.join(index_series) if index_series else '' + except: + return '' + + def _is_st_stock(self, stock_name: str) -> bool: + if pd.isna(stock_name): return False + return 'ST' in stock_name or '*ST' in stock_name + + def _format_code_to_prefix_style(self, ts_code: str) -> str: + if '.' in ts_code: + parts = ts_code.split('.') + if len(parts) == 2: + return f"{parts[1].upper()}{parts[0]}" + return ts_code + + def get_latest_states_before(self, before_date_str: str) -> pd.DataFrame: + """获取前序基准状态""" + print(f" 🔍 查询 {before_date_str} 之前的基准状态 (用于跨年去重)...") + query = text(""" + SELECT t1.ts_code, t1.industry, t1.is_st, t1.index_series, t1.trade_date, t1.name + FROM gp_stock_category t1 + JOIN ( + SELECT ts_code, MAX(trade_date) as max_dt + FROM gp_stock_category + WHERE trade_date < :before_date + GROUP BY ts_code + ) t2 ON t1.ts_code = t2.ts_code AND t1.trade_date = t2.max_dt + """) + + try: + df = pd.read_sql(query, self.engine, params={'before_date': before_date_str}) + if not df.empty: + df['trade_date'] = pd.to_datetime(df['trade_date']) + + def _revert_code(c): + if len(c) > 2 and c[:2] in ['SZ', 'SH']: + return f"{c[2:]}.{c[:2]}" + return c + + df['ts_code'] = df['ts_code'].apply(_revert_code) + + print(f" ✅ 获取到 {len(df)} 条基准记录。") + return df + except Exception as e: + print(f" ⚠️ 获取基准状态失败 (可能是第一次运行): {e}") + return pd.DataFrame() + + def get_all_bak_basic_in_range(self, start_date: str, end_date: str) -> pd.DataFrame: + print(f" 正在获取 {start_date} 到 {end_date} 的日线数据...") + all_dates = pd.date_range(start=start_date, end=end_date, freq='B') + all_dfs = [] + + for i, date in enumerate(all_dates): + date_str = date.strftime('%Y%m%d') + try: + df = self.pro.bak_basic(trade_date=date_str) + if df is not None and not df.empty: + all_dfs.append(df) + time.sleep(self.REQUEST_INTERVAL) + except Exception as e: + print(f" 获取 {date_str} 数据失败: {e}") + + if not all_dfs: + return pd.DataFrame() + return pd.concat(all_dfs, ignore_index=True) + + def update_stock_category_data(self, start_date: str, end_date: str): + """核心处理逻辑""" + # 1. 基础日线 + daily_data_df = self.get_all_bak_basic_in_range(start_date, end_date) + if daily_data_df.empty: + print(" 无数据,本批次结束。") + return + + # =================【关键修复】================= + # 确保 symbol 列存在,因为后续 merge 需要用到它 + if 'symbol' not in daily_data_df.columns: + # 从 ts_code 生成 symbol (例如 000001.SZ -> 000001) + daily_data_df['symbol'] = daily_data_df['ts_code'].apply(lambda x: str(x).split('.')[0]) + # ============================================== + + print(f" 批量计算属性 ({len(daily_data_df)} 行)...") + daily_data_df['trade_date'] = pd.to_datetime(daily_data_df['trade_date'], format='%Y%m%d') + daily_data_df['is_st'] = daily_data_df['name'].apply(self._is_st_stock) + + # 2. 预加载指数缓存 + unique_months = sorted(daily_data_df['trade_date'].dt.strftime('%Y%m').unique()) + for ym in unique_months: + self._cache_month_index_constituents(ym) + + # 3. 计算指数归属 + daily_data_df['index_series'] = daily_data_df.apply( + lambda row: self._get_index_series_for_stock(row['ts_code'], row['trade_date'].strftime('%Y%m%d')), + axis=1 + ) + + # 4. 合并前序状态 + prev_state_df = self.get_latest_states_before(pd.to_datetime(start_date).strftime('%Y-%m-%d')) + + if not prev_state_df.empty: + cols = ['ts_code', 'trade_date', 'industry', 'is_st', 'index_series'] + combined_df = pd.concat([prev_state_df[cols], daily_data_df[cols]], ignore_index=True) + else: + combined_df = daily_data_df.copy() + + # 5. 变化检测 + combined_df.sort_values(['ts_code', 'trade_date'], inplace=True) + cols_to_check = ['industry', 'is_st', 'index_series'] + + shifted = combined_df.groupby('ts_code')[cols_to_check].shift(1) + is_changed = (combined_df[cols_to_check].ne(shifted)).any(axis=1) + + # 6. 筛选结果 + current_batch_start_dt = pd.to_datetime(start_date) + records_to_insert = combined_df[is_changed & (combined_df['trade_date'] >= current_batch_start_dt)].copy() + + # 7. 补回 symbol 和 name + if not records_to_insert.empty: + records_to_insert = pd.merge( + records_to_insert[['ts_code', 'trade_date', 'industry', 'is_st', 'index_series']], + daily_data_df[['ts_code', 'trade_date', 'symbol', 'name']], # 这里现在肯定有 symbol 了 + on=['ts_code', 'trade_date'], + how='left' + ) + + print(f" 筛选出 {len(records_to_insert)} 条真实变更记录。") + + # 8. 写入 + if not records_to_insert.empty: + records_to_insert['ts_code'] = records_to_insert['ts_code'].apply(self._format_code_to_prefix_style) + # 双重保险:如果 merge 后 symbol 仍有空值(极少情况),再次生成 + if 'symbol' not in records_to_insert.columns or records_to_insert['symbol'].isnull().any(): + records_to_insert['symbol'] = records_to_insert['ts_code'].apply(lambda x: x[2:] if len(x) > 2 else x) + + final_df = records_to_insert[ + ['ts_code', 'symbol', 'name', 'trade_date', 'industry', 'is_st', 'index_series']] + + with self.engine.connect() as connection: + with connection.begin() as trans: + del_start = pd.to_datetime(start_date).strftime('%Y-%m-%d') + del_end = pd.to_datetime(end_date).strftime('%Y-%m-%d') + connection.execute( + text("DELETE FROM gp_stock_category WHERE trade_date BETWEEN :s AND :e"), + {'s': del_start, 'e': del_end} + ) + final_df.to_sql('gp_stock_category', con=connection, if_exists='append', index=False) + trans.commit() + print(f" ✅ 写入成功。") + + # 9. 清理 + del daily_data_df + del combined_df + del records_to_insert + if not prev_state_df.empty: del prev_state_df + gc.collect() + + +# --- 运行入口 --- +def run_stock_grouper_etl(start_date: str = None, end_date: str = None, test_mode: bool = False, + is_full_run: bool = False): + if not PROXY_DB_URL or not TUSHARE_TOKEN: + raise ValueError("Environment variables required.") + + engine = create_engine(PROXY_DB_URL, pool_pre_ping=True) + manager = _StockGroupManager(TUSHARE_TOKEN, engine, test_mode) + + if is_full_run: + print("🚀 [Full Run] 启动分批处理模式 (按年)...") + current_year = 2016 + end_year = datetime.now().year + + while current_year <= end_year: + batch_start = f"{current_year}0101" + batch_end = f"{current_year}1231" + + if current_year == end_year: + batch_end = datetime.now().strftime('%Y%m%d') + if batch_start > batch_end: break + + print(f"\n>>> 处理批次: {batch_start} - {batch_end}") + manager.update_stock_category_data(batch_start, batch_end) + current_year += 1 + time.sleep(1) + + print("🎉 全量重跑完成。") + + else: + if not start_date: + with engine.connect() as conn: + res = conn.execute(text("SELECT MAX(trade_date) FROM gp_stock_category")).scalar() + if res: + start_date = (res + timedelta(days=1)).strftime('%Y%m%d') + else: + start_date = '20160101' + end_date = datetime.now().strftime('%Y%m%d') + + if start_date > datetime.now().strftime('%Y%m%d'): + return {"message": "No new data."} + + print(f"🚀 [Incremental] 增量更新: {start_date} - {end_date}") + manager.update_stock_category_data(start_date, end_date) + + return {"message": "Success"}