diff --git a/src/quantitative_analysis/tech_fundamental_factor_strategy_v3.py b/src/quantitative_analysis/tech_fundamental_factor_strategy_v3.py new file mode 100644 index 0000000..6f16dcb --- /dev/null +++ b/src/quantitative_analysis/tech_fundamental_factor_strategy_v3.py @@ -0,0 +1,831 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +科技主题基本面因子选股策略--这里就是入口--请执行这个文件! +整合企业生命周期、财务指标和平均距离因子分析 +""" + +import sys +import pandas as pd +import numpy as np +import logging +from typing import Dict, List, Tuple +from pathlib import Path +from sqlalchemy import create_engine, text +from datetime import datetime, timedelta +import math + +# 添加项目根路径到Python路径 +project_root = Path(__file__).parent.parent.parent +sys.path.append(str(project_root)) + +# 导入依赖的工具类 +from src.quantitative_analysis.company_lifecycle_factor import CompanyLifecycleFactor +from src.quantitative_analysis.financial_indicator_analyzer import FinancialIndicatorAnalyzer +from src.quantitative_analysis.average_distance_factor import AverageDistanceFactor +# from src.valuation_analysis.config import DB_URL + + +class DateHelper: + """日期处理工具类,根据输入日期计算应该使用的财务数据日期""" + + @staticmethod + def get_annual_date(target_date: str) -> str: + """ + 获取应该使用的年度数据日期 + + 规则:5月1日之前用上上年数据,5月1日之后用上年数据 + 例如:2025-11-03 -> 2024-12-31 (上年年报) + 2025-04-30 -> 2023-12-31 (上上年年报) + + Args: + target_date: 目标日期,格式 YYYY-MM-DD + + Returns: + 年度数据日期,格式 YYYY-MM-DD + """ + date_obj = datetime.strptime(target_date, '%Y-%m-%d') + year = date_obj.year + month = date_obj.month + day = date_obj.day + + # 5月1日之前用上上年数据 + if month < 5 or (month == 5 and day < 1): + annual_year = year - 2 + else: + annual_year = year - 1 + + return f"{annual_year}-12-31" + + @staticmethod + def get_quarter_date(target_date: str) -> str: + """ + 获取应该使用的季度数据日期 + + 规则: + - 5月1日发去年的年报和今年一季报 -> 使用一季报 (YYYY-03-31) + - 8月31日发半年报 -> 使用半年报 (YYYY-06-30) + - 10月31日发三季报 -> 使用三季报 (YYYY-09-30) + - 其他时间使用最新可用季度 + + Args: + target_date: 目标日期,格式 YYYY-MM-DD + + Returns: + 季度数据日期,格式 YYYY-MM-DD + """ + date_obj = datetime.strptime(target_date, '%Y-%m-%d') + year = date_obj.year + month = date_obj.month + day = date_obj.day + + # 5月1日之前:使用上上年的年报或上年的三季报 + if month < 5 or (month == 5 and day < 1): + # 使用上年的三季报 + return f"{year - 1}-09-30" + # 5月1日到8月31日:使用今年一季报 + elif month < 8 or (month == 8 and day <= 31): + return f"{year}-03-31" + # 8月31日到10月31日:使用半年报 + elif month < 10 or (month == 10 and day <= 31): + return f"{year}-06-30" + # 10月31日之后:使用三季报 + else: + return f"{year}-09-30" + + @staticmethod + def get_lifecycle_year(target_date: str) -> int: + """ + 获取生命周期分析应该使用的年份 + + Args: + target_date: 目标日期,格式 YYYY-MM-DD + + Returns: + 分析年份 + """ + date_obj = datetime.strptime(target_date, '%Y-%m-%d') + year = date_obj.year + month = date_obj.month + day = date_obj.day + + # 5月1日之前用上上年,之后用上年 + if month < 5 or (month == 5 and day < 1): + return year - 2 + else: + return year - 1 + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +class TechFundamentalFactorStrategy: + """科技主题基本面因子选股策略""" + + def __init__(self, target_date: str = None): + """ + 初始化策略 + + Args: + target_date: 目标日期,格式 YYYY-MM-DD,默认为今天 + """ + self.target_date = target_date or datetime.now().strftime('%Y-%m-%d') + self.date_helper = DateHelper() + + # 计算应该使用的数据日期 + self.annual_date = self.date_helper.get_annual_date(self.target_date) + self.quarter_date = self.date_helper.get_quarter_date(self.target_date) + self.lifecycle_year = self.date_helper.get_lifecycle_year(self.target_date) + + self.lifecycle_calculator = CompanyLifecycleFactor() + self.financial_analyzer = FinancialIndicatorAnalyzer() + DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db" + self.distance_calculator = AverageDistanceFactor(DB_URL) + + # MySQL连接 + self.mysql_engine = create_engine( + DB_URL, + pool_size=5, + max_overflow=10, + pool_recycle=3600 + ) + + # 结果表名 + self.result_table = 'tech_fundamental_factor_scores' + + # 科技概念板块列表 + self.tech_concepts = [ + "5G概念", "物联网", "云计算", "边缘计算", "信息安全", "国产软件", + "大数据", "数据中心", "芯片", "MCU芯片", "汽车芯片", "存储芯片", + "人工智能", "AIGC概念", "ChatGPT概念", "CPO概念", "华为鸿蒙", + "华为海思", "华为算力", "量子科技", "区块链", "数字货币", "工业互联", + "操作系统", "光刻机", "第三代半导体", "元宇宙概念", "云游戏", "信创", + "东数西算", "PCB概念", "先进封装", "EDA概念", "Web3概念", "数据确权", + "数据要素", "数字水印", "工业软件", "6G概念", "时空大数据", "算力租赁", + "光通信", "英伟达概念", "星闪概念", "液冷服务器", "多模态AI", "Sora概念", + "AI手机PC", "铜缆高速连接", "车联网", "财税数字化", "智谱AI", "AI智能体", + "DeepSeek概念", "AI医疗概念" + ] + + def get_tech_stocks(self) -> pd.DataFrame: + """ + 获取科技概念板块的股票列表 + + Returns: + pd.DataFrame: 包含股票代码和名称的DataFrame + """ + try: + # 构建查询条件 + concepts_str = "', '".join(self.tech_concepts) + query = text(f""" + SELECT DISTINCT gp_code as stock_code, gp_name as stock_name, bk_name as concept_name + FROM gp_gnbk + WHERE bk_name IN ('{concepts_str}') + ORDER BY gp_code + """) + + with self.mysql_engine.connect() as conn: + df = pd.read_sql(query, conn) + + return df + + except Exception as e: + logger.error(f"获取科技概念股票失败: {str(e)}") + return pd.DataFrame() + + def filter_by_lifecycle(self, stock_codes: List[str]) -> Dict[str, List[str]]: + """ + 根据企业生命周期筛选股票 + + Args: + stock_codes: 股票代码列表 + + Returns: + Dict: 包含成长期和成熟期股票的字典 + """ + try: + # 批量计算生命周期 + lifecycle_df = self.lifecycle_calculator.batch_calculate_lifecycle_factors(stock_codes, self.lifecycle_year) + + # 筛选目标阶段的股票 + # 引入期(1)和成长期(2)合并为成长期,成熟期(3)保持不变 + growth_stage_stocks = lifecycle_df[ + lifecycle_df['stage_id'].isin([1, 2]) + ]['stock_code'].tolist() + + mature_stage_stocks = lifecycle_df[ + lifecycle_df['stage_id'] == 3 + ]['stock_code'].tolist() + + result = { + 'growth': growth_stage_stocks, + 'mature': mature_stage_stocks + } + + return result + + except Exception as e: + logger.error(f"生命周期筛选失败: {str(e)}") + return {'growth': [], 'mature': []} + + def calculate_distance_factors(self, growth_stocks: List[str], mature_stocks: List[str], days=20) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + 分别计算成长期和成熟期股票的平均距离因子 + 使用目标日期前N天的数据计算 + + Args: + growth_stocks: 成长期股票列表 + mature_stocks: 成熟期股票列表 + days: 计算距离因子使用的天数(默认20天) + + Returns: + Tuple: (成长期距离因子DataFrame, 成熟期距离因子DataFrame) + """ + try: + # 将目标日期转换为datetime对象 + target_datetime = datetime.strptime(self.target_date, '%Y-%m-%d') + + growth_distance_df = pd.DataFrame() + mature_distance_df = pd.DataFrame() + + # 计算成长期股票距离因子 + if growth_stocks: + growth_data = self.distance_calculator.get_stock_data(growth_stocks, days=days, end_date=target_datetime) + if not growth_data.empty: + growth_indicators = self.distance_calculator.calculate_technical_indicators(growth_data, days=days) + growth_distance_df = self.distance_calculator.calculate_distance_factor(growth_indicators) + + # 计算成熟期股票距离因子 + if mature_stocks: + mature_data = self.distance_calculator.get_stock_data(mature_stocks, days=days, end_date=target_datetime) + if not mature_data.empty: + mature_indicators = self.distance_calculator.calculate_technical_indicators(mature_data, days=days) + mature_distance_df = self.distance_calculator.calculate_distance_factor(mature_indicators) + + return growth_distance_df, mature_distance_df + + except Exception as e: + logger.error(f"计算距离因子失败: {str(e)}") + import traceback + traceback.print_exc() + return pd.DataFrame(), pd.DataFrame() + + def calculate_common_factors(self, stock_codes: List[str]) -> pd.DataFrame: + """ + 计算通用因子 + + Args: + stock_codes: 股票代码列表 + + Returns: + pd.DataFrame: 包含通用因子的DataFrame + """ + try: + results = [] + + for stock_code in stock_codes: + try: + + factor_data = {'stock_code': stock_code} + + # 1. 毛利率(使用季度数据) + gross_margin = self.financial_analyzer.analyze_gross_profit_margin(stock_code, self.quarter_date) + factor_data['gross_profit_margin'] = gross_margin + + # 2. 成长能力指标 + growth_capability = self.financial_analyzer.analyze_growth_capability(stock_code) + if growth_capability is not None: + # 成长能力越高越好,使用sigmoid函数映射到0-1 + growth_score = 1 / (1 + math.exp(-growth_capability)) + else: + growth_score = 0.5 # 默认中性评分 + factor_data['growth_score'] = growth_score + + # 3. 前五大供应商占比(使用年报数据) + supplier_conc = self.financial_analyzer.analyze_supplier_concentration(stock_code, self.annual_date) + factor_data['supplier_concentration'] = supplier_conc + + # 4. 前五大客户占比(使用年报数据) + customer_conc = self.financial_analyzer.analyze_customer_concentration(stock_code, self.annual_date) + factor_data['customer_concentration'] = customer_conc + + results.append(factor_data) + + except Exception as e: + logger.warning(f"计算股票 {stock_code} 通用因子失败: {str(e)}") + continue + + df = pd.DataFrame(results) + return df + + except Exception as e: + logger.error(f"计算通用因子失败: {str(e)}") + return pd.DataFrame() + + def calculate_growth_specific_factors(self, stock_codes: List[str]) -> pd.DataFrame: + """ + 计算成长期特色因子 + + Args: + stock_codes: 成长期股票代码列表 + + Returns: + pd.DataFrame: 包含成长期特色因子的DataFrame + """ + try: + results = [] + + for stock_code in stock_codes: + try: + + factor_data = {'stock_code': stock_code} + + # 1. 管理费用率(使用季度数据) + admin_ratio = self.financial_analyzer.analyze_admin_expense_ratio(stock_code, self.quarter_date) + factor_data['admin_expense_ratio'] = admin_ratio + + # 2. 研发费用折旧摊销占比(使用年度数据) + financial_data = self.financial_analyzer.get_financial_data(stock_code, self.annual_date) + if financial_data: + intangible_amortize = financial_data.get('cash_flow_statement', {}).get('IA_AMORTIZE', 0) + rd_expense = financial_data.get('profit_statement', {}).get('RESEARCH_EXPENSE', 0) + + if rd_expense and rd_expense != 0: + rd_amortize_ratio = intangible_amortize / rd_expense if intangible_amortize else 0 + else: + rd_amortize_ratio = None # 使用None而不是0,避免这些股票获得最高分 + + factor_data['rd_amortize_ratio'] = rd_amortize_ratio + else: + factor_data['rd_amortize_ratio'] = None + + # 3. 资产负债率(使用季度数据) + asset_liability_ratio = self.financial_analyzer.analyze_asset_liability_ratio(stock_code, self.quarter_date) + factor_data['asset_liability_ratio'] = asset_liability_ratio + + results.append(factor_data) + + except Exception as e: + logger.warning(f"计算股票 {stock_code} 成长期特色因子失败: {str(e)}") + continue + + df = pd.DataFrame(results) + return df + + except Exception as e: + logger.error(f"计算成长期特色因子失败: {str(e)}") + return pd.DataFrame() + + def calculate_mature_specific_factors(self, stock_codes: List[str]) -> pd.DataFrame: + """ + 计算成熟期特色因子 + + Args: + stock_codes: 成熟期股票代码列表 + + Returns: + pd.DataFrame: 包含成熟期特色因子的DataFrame + """ + try: + # 在循环外获取全A股PB和ROE数据,避免重复查询 + all_pb_data = self.financial_analyzer.get_all_stocks_pb_data() + all_roe_data = self.financial_analyzer.get_all_stocks_roe_data(self.quarter_date) + + results = [] + + for stock_code in stock_codes: + try: + factor_data = {'stock_code': stock_code} + + # 1. 应收账款周转率(使用季度数据) + formatted_stock_code = self.financial_analyzer.code_formatter.to_dot_format(stock_code) + financial_data = self.financial_analyzer.get_financial_data(formatted_stock_code, self.quarter_date) + if financial_data: + revenue = financial_data.get('profit_statement', {}).get('OPERATE_INCOME', 0) + accounts_rece = financial_data.get('balance_sheet', {}).get('ACCOUNTS_RECE', 0) + + if accounts_rece and accounts_rece != 0: + turnover_ratio = revenue / accounts_rece if revenue else 0 + else: + turnover_ratio = None # 使用None而不是0 + + factor_data['accounts_receivable_turnover'] = turnover_ratio + else: + factor_data['accounts_receivable_turnover'] = None + + # 2. 研发强度(使用季度数据) + rd_intensity = self.financial_analyzer.analyze_rd_expense_ratio(stock_code, self.quarter_date) + factor_data['rd_intensity'] = rd_intensity + + # 3. PB-ROE排名因子:使用预获取的全A股数据 + if all_pb_data and all_roe_data: + pb_roe_rank_factor = self.financial_analyzer.calculate_pb_roe_rank_factor( + stock_code, all_pb_data, all_roe_data + ) + factor_data['pb_roe_rank_factor'] = pb_roe_rank_factor + else: + factor_data['pb_roe_rank_factor'] = None + + results.append(factor_data) + + except Exception as e: + logger.warning(f"计算股票 {stock_code} 成熟期特色因子失败: {str(e)}") + continue + + df = pd.DataFrame(results) + return df + + except Exception as e: + logger.error(f"计算成熟期特色因子失败: {str(e)}") + return pd.DataFrame() + + def run_strategy(self) -> Dict[str, pd.DataFrame]: + """ + 运行完整的选股策略 + + Returns: + Dict: 包含成长期和成熟期股票分析结果的字典 + """ + try: + # 1. 获取科技概念股票 + tech_stocks_df = self.get_tech_stocks() + if tech_stocks_df.empty: + logger.error("未获取到科技概念股票") + return {} + + stock_codes = tech_stocks_df['stock_code'].unique().tolist() + + # 2. 按企业生命周期筛选 + lifecycle_result = self.filter_by_lifecycle(stock_codes) + growth_stocks = lifecycle_result['growth'] + mature_stocks = lifecycle_result['mature'] + + if not growth_stocks and not mature_stocks: + logger.warning("未找到符合条件的成长期或成熟期股票") + return {} + + # 3. 计算平均距离因子 + growth_distance_df, mature_distance_df = self.calculate_distance_factors(growth_stocks, mature_stocks) + + # 4. 计算通用因子 + all_qualified_stocks = growth_stocks + mature_stocks + common_factors_df = self.calculate_common_factors(all_qualified_stocks) + + # 5. 计算特色因子 + growth_specific_df = self.calculate_growth_specific_factors(growth_stocks) if growth_stocks else pd.DataFrame() + mature_specific_df = self.calculate_mature_specific_factors(mature_stocks) if mature_stocks else pd.DataFrame() + + # 6. 合并结果并计算分数 + result = {} + + # 处理成长期股票 + if not growth_specific_df.empty: + # 成长期结果合并 + growth_result = growth_specific_df.copy() + + # 合并距离因子 + if not growth_distance_df.empty: + growth_result = growth_result.merge( + growth_distance_df[['symbol', 'avg_distance_factor']], + left_on='stock_code', right_on='symbol', how='left' + ).drop('symbol', axis=1) + + # 合并通用因子 + if not common_factors_df.empty: + growth_result = growth_result.merge( + common_factors_df, on='stock_code', how='left' + ) + + # 计算因子分数 + growth_result = self.calculate_factor_scores(growth_result, 'growth') + + # 计算总分并排序 + growth_result = self.calculate_total_score(growth_result, 'growth') + + result['growth'] = growth_result + + # 处理成熟期股票 + if not mature_specific_df.empty: + # 成熟期结果合并 + mature_result = mature_specific_df.copy() + + # 合并距离因子 + if not mature_distance_df.empty: + mature_result = mature_result.merge( + mature_distance_df[['symbol', 'avg_distance_factor']], + left_on='stock_code', right_on='symbol', how='left' + ).drop('symbol', axis=1) + + # 合并通用因子 + if not common_factors_df.empty: + mature_result = mature_result.merge( + common_factors_df, on='stock_code', how='left' + ) + + # 计算因子分数 + mature_result = self.calculate_factor_scores(mature_result, 'mature') + + # 计算总分并排序 + mature_result = self.calculate_total_score(mature_result, 'mature') + + result['mature'] = mature_result + + # 7. 保存结果到数据库 + if result: + self.save_results_to_db(result) + + return result + + except Exception as e: + logger.error(f"策略运行失败: {str(e)}") + return {} + + def calculate_factor_scores(self, df: pd.DataFrame, stage: str) -> pd.DataFrame: + """ + 计算单因子打分(0-100分位数) + + Args: + df: 包含因子数据的DataFrame + stage: 阶段类型 ('growth' 或 'mature') + + Returns: + pd.DataFrame: 包含因子分数的DataFrame + """ + try: + if df.empty: + return df + + df_scored = df.copy() + + # 定义因子方向(正向为True,负向为False) + factor_directions = { + # 通用因子 + 'gross_profit_margin': True, # 毛利率_环比增量 - 正向 + 'growth_score': True, # 成长能力 - 正向 + 'supplier_concentration': False, # 前5大供应商金额占比合计 - 负向 + 'customer_concentration': False, # 前5大客户收入金额占比合计 - 负向 + 'avg_distance_factor': False, # 平均距离因子 - 负向 + + # 成长期特色因子 + 'admin_expense_ratio': False, # 管理费用/营业总收入_环比增量 - 负向 + 'rd_amortize_ratio': False, # 研发费用折旧摊销占比_环比增量 - 负向 + 'asset_liability_ratio': True, # 资产负债率 - 正向 + + # 成熟期特色因子 + 'accounts_receivable_turnover': True, # 应收账款周转率 - 正向 + 'rd_intensity': True, # 研发费用直接投入占比_环比增量 - 正向 + 'pb_roe_rank_factor': False # PB-ROE排名因子 - 负向(越小越好) + } + + # 为每个因子计算分位数分数 + for column in df.columns: + if column == 'stock_code': + continue + + # 只对有效值进行排名计算 + values = df_scored[column].dropna() + if len(values) <= 1: + # 如果只有一个值或没有值,所有股票都得50分或0分 + if len(values) == 1: + df_scored[f'{column}_score'] = df_scored[column].apply(lambda x: 50 if pd.notna(x) else 0) + else: + df_scored[f'{column}_score'] = 0 + continue + + # 根据因子方向确定排序方式 + is_positive = factor_directions.get(column, True) + + # 计算排名分数 + if is_positive: + # 正向因子:值越大分数越高 + ranked_values = values.rank(pct=True) * 100 + else: + # 负向因子:值越小分数越高 + ranked_values = (1 - values.rank(pct=True)) * 100 + + # 初始化分数列 + df_scored[f'{column}_score'] = 0.0 + + # 将分数赋值给对应的行 + for idx in ranked_values.index: + df_scored.loc[idx, f'{column}_score'] = ranked_values[idx] + + return df_scored + + except Exception as e: + logger.error(f"计算因子分数失败: {str(e)}") + import traceback + traceback.print_exc() + return df + + def calculate_total_score(self, df: pd.DataFrame, stage: str) -> pd.DataFrame: + """ + 计算总分 + 使用公式:总分 = 1/8 * Mean(Si) + Mean(Si)/Std(Si) + + Args: + df: 包含因子分数的DataFrame + stage: 阶段类型 ('growth' 或 'mature') + + Returns: + pd.DataFrame: 包含总分的DataFrame + """ + try: + if df.empty: + return df + + df_result = df.copy() + + # 定义因子权重(注意:这里是factor_score而不是factor) + if stage == 'growth': + factor_weights = { + # 通用因子 + 'gross_profit_margin_score': 1/8, + 'growth_score_score': 1/8, # 注意这里是growth_score_score + 'supplier_concentration_score': 1/8, + 'customer_concentration_score': 1/8, + 'avg_distance_factor_score': 1/8, + + # 成长期特色因子 + 'admin_expense_ratio_score': 1/8, + 'rd_amortize_ratio_score': 1/8, + 'asset_liability_ratio_score': 1/8 + } + else: # mature + factor_weights = { + # 通用因子 + 'gross_profit_margin_score': 1/8, + 'growth_score_score': 1/8, # 注意这里是growth_score_score + 'supplier_concentration_score': 1/8, + 'customer_concentration_score': 1/8, + 'avg_distance_factor_score': 1/8, + + # 成熟期特色因子 + 'accounts_receivable_turnover_score': 1/8, + 'rd_intensity_score': 1/8, + 'pb_roe_rank_factor_score': 1/8 + } + + # 计算每只股票的总分 + total_scores = [] + + for index, row in df_result.iterrows(): + # 获取该股票的所有因子分数 + factor_scores = [] + valid_weights = [] + + for factor, weight in factor_weights.items(): + if factor in row and pd.notna(row[factor]) and row[factor] > 0: + factor_scores.append(row[factor]) + valid_weights.append(weight) + + if len(factor_scores) == 0: + total_scores.append(0) + continue + + factor_scores = np.array(factor_scores) + valid_weights = np.array(valid_weights) + + # 重新标准化权重 + valid_weights = valid_weights / valid_weights.sum() + + # 计算加权平均分数 + mean_score = np.average(factor_scores, weights=valid_weights) + + # 计算调整项 Mean(Si)/Std(Si) + if len(factor_scores) > 1 and np.std(factor_scores) > 0: + adjustment = np.mean(factor_scores) / np.std(factor_scores) + else: + adjustment = 0 + + # 计算总分:1/8 * Mean(Si) + Mean(Si)/Std(Si) + total_score = (1/8) * mean_score + adjustment + total_scores.append(total_score) + + df_result['total_score'] = total_scores + + # 按总分降序排列 + df_result = df_result.sort_values('total_score', ascending=False).reset_index(drop=True) + df_result['rank'] = range(1, len(df_result) + 1) + + return df_result + + except Exception as e: + logger.error(f"计算总分失败: {str(e)}") + import traceback + traceback.print_exc() + return df + + def save_results_to_db(self, results: Dict[str, pd.DataFrame]) -> bool: + """ + 将策略结果保存到数据库 + + Args: + results: 策略结果字典,包含 'growth' 和 'mature' 两个DataFrame + + Returns: + bool: 是否保存成功 + """ + try: + if not results: + return False + + # 准备入库数据 + records = [] + for stage, df in results.items(): + if df.empty or 'stock_code' not in df.columns or 'total_score' not in df.columns: + continue + + for _, row in df.iterrows(): + stock_code = row['stock_code'] + total_score = row.get('total_score', 0) + + # 跳过总分为空或0的记录 + if pd.isna(total_score) or total_score == 0: + continue + + records.append({ + 'trade_date': self.target_date, + 'stock_code': stock_code, + 'stage': stage, # 'growth' 或 'mature' + 'total_score': float(total_score), + 'created_at': datetime.now(), + 'updated_at': datetime.now() + }) + + if not records: + return False + + # 删除该日期的旧数据 + with self.mysql_engine.begin() as conn: + delete_sql = text(f""" + DELETE FROM {self.result_table} + WHERE trade_date = :trade_date + """) + conn.execute(delete_sql, {"trade_date": self.target_date}) + + # 批量插入新数据 + if records: + df_to_save = pd.DataFrame(records) + df_to_save.to_sql( + self.result_table, + self.mysql_engine, + if_exists='append', + index=False + ) + + return True + + except Exception as e: + logger.error(f"保存策略结果到数据库失败: {str(e)}") + import traceback + traceback.print_exc() + return False + + def close_connections(self): + """关闭所有数据库连接""" + try: + if hasattr(self, 'lifecycle_calculator'): + del self.lifecycle_calculator + if hasattr(self, 'financial_analyzer'): + self.financial_analyzer.close_connection() + if hasattr(self, 'distance_calculator'): + del self.distance_calculator + if hasattr(self, 'mysql_engine'): + self.mysql_engine.dispose() + except Exception as e: + logger.error(f"关闭连接失败: {str(e)}") + + +def main(): + """主函数 - 科技主题基本面因子选股策略""" + strategy = None + try: + # 创建策略实例(可以指定日期,默认为今天) + import sys + if len(sys.argv) > 1: + target_date = sys.argv[1] # 从命令行参数获取日期 + else: + target_date = "2025-09-01" # 默认日期,可以修改 + + strategy = TechFundamentalFactorStrategy(target_date=target_date) + + # 运行策略(结果会自动保存到数据库) + results = strategy.run_strategy() + + except Exception as e: + logger.error(f"程序执行失败: {str(e)}") + import traceback + traceback.print_exc() + finally: + if strategy: + strategy.close_connections() + + +if __name__ == "__main__": + main() \ No newline at end of file