# coding:utf-8
"""Classify a company's lifecycle stage from its annual cash-flow statement.

The stage is derived from the sign pattern of the three net cash-flow totals
(operating, investing, financing) in the year-end (``YYYY-12-31``) report
stored in MongoDB.
"""
import importlib.util
import logging
import math
import os
import sys
from typing import Dict, List, Optional, Tuple

import pandas as pd
import pymongo

# Make the project root importable.
# __file__ is e.g. /app/src/quantitative_analysis/company_lifecycle_factor.py;
# the project root is two levels above this file's directory: /app
current_file_dir = os.path.dirname(os.path.abspath(__file__))  # /app/src/quantitative_analysis
src_dir = os.path.dirname(current_file_dir)  # /app/src
project_root = os.path.dirname(src_dir)  # /app (project root)
sys.path.append(project_root)

# Load the MongoDB configuration; fall back to loading the file by path when
# the package is not importable from sys.path.
try:
    from valuation_analysis.config import MONGO_CONFIG2
except ImportError:
    config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
    if not os.path.exists(config_path):
        # Mirror the formatter fallback below: fail with a clear ImportError
        # instead of an opaque error from exec_module.
        raise ImportError(f"无法找到 config.py,路径: {config_path}")
    spec = importlib.util.spec_from_file_location("config", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    MONGO_CONFIG2 = config_module.MONGO_CONFIG2

# Load the stock-code formatting helper, with the same by-path fallback.
try:
    from tools.stock_code_formatter import StockCodeFormatter
except ImportError:
    # project_root is already the project root; join the tools directory onto it.
    formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
    if not os.path.exists(formatter_path):
        raise ImportError(f"无法找到 stock_code_formatter.py,路径: {formatter_path}")
    spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
    formatter_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(formatter_module)
    StockCodeFormatter = formatter_module.StockCodeFormatter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CompanyLifecycleFactor:
    """Company lifecycle-stage factor calculator.

    Reads year-end statements from the MongoDB collection
    ``eastmoney_financial_data_v2``. It maps the sign pattern of the
    (operating, investing, financing) net cash flows to a stage id in
    ``self.lifecycle_stages``.

    Can be used as a context manager so the MongoDB client is closed
    deterministically::

        with CompanyLifecycleFactor() as calc:
            calc.calculate_lifecycle_factor('600519.SH', 2024)
    """

    # Column order of the per-stock result dicts / batch DataFrame.
    _RESULT_COLUMNS = ['stock_code', 'year', 'stage_id', 'stage_name']

    def __init__(self):
        """Connect to MongoDB and build the stage lookup tables.

        Raises:
            Exception: re-raised from :meth:`connect_mongodb` when the
                connection or its ping check fails.
        """
        self.mongo_client = None
        self.db = None
        self.collection = None
        self.connect_mongodb()

        # Used to normalize incoming codes (via to_dot_format) before querying.
        self.stock_formatter = StockCodeFormatter()

        # Stage id -> display name:
        # 1 introduction, 2 growth, 3 maturity, 4 shake-out, 5 decline.
        self.lifecycle_stages = {
            1: "引入期",
            2: "成长期",
            3: "成熟期",
            4: "震荡期",
            5: "衰退期"
        }

        # (operating, investing, financing) sign pattern -> stage id.
        # '正' means non-negative and '负' means negative; see classify_cashflow_pattern.
        self.cashflow_pattern_mapping = {
            ('负', '负', '正'): 1,  # introduction
            ('正', '负', '正'): 2,  # growth
            ('正', '负', '负'): 3,  # maturity
            ('负', '正', '正'): 4,  # shake-out
            ('正', '正', '负'): 5,  # decline
            ('负', '正', '负'): 4,  # shake-out (variant)
            ('负', '负', '负'): 4,  # shake-out (distressed)
            ('正', '正', '正'): 2,  # growth (variant, cash-rich)
        }

    def connect_mongodb(self):
        """Open the MongoDB client and select the financial-data collection.

        Raises:
            Exception: any error from client construction or the ``ping``
                check is logged and re-raised; the client is closed first so
                it is not leaked.
        """
        try:
            self.mongo_client = pymongo.MongoClient(
                host=MONGO_CONFIG2['host'],
                port=MONGO_CONFIG2['port'],
                username=MONGO_CONFIG2['username'],
                password=MONGO_CONFIG2['password']
            )
            self.db = self.mongo_client[MONGO_CONFIG2['db']]
            self.collection = self.db['eastmoney_financial_data_v2']
            # pymongo connects in the background; ping forces a round-trip so
            # bad host/credentials fail here rather than on the first query.
            self.mongo_client.admin.command('ping')
            logger.info("MongoDB连接成功")
        except Exception as e:
            logger.error(f"MongoDB连接失败: {str(e)}")
            try:
                self.close()
            except Exception:
                # Best-effort cleanup; the original error is re-raised below.
                pass
            raise

    def close(self):
        """Close the MongoDB client if one is open. Safe to call repeatedly."""
        client = getattr(self, 'mongo_client', None)
        if client is not None:
            self.mongo_client = None
            client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def get_annual_financial_data(self, stock_code: str, year: int) -> Optional[Dict]:
        """Fetch the year-end report document for one stock and year.

        Args:
            stock_code: stock code in any supported format
                (300661.SZ, 300661, SZ300661).
            year: report year, e.g. 2024.

        Returns:
            The matching document (``report_date == f"{year}-12-31"``), or
            None if it is not found or the query fails (the failure is logged).
        """
        try:
            # Normalize to the dotted format used in the collection.
            normalized_code = self.stock_formatter.to_dot_format(stock_code)
            # Annual reports are keyed by the 12-31 report date.
            report_date = f"{year}-12-31"

            query = {
                "stock_code": normalized_code,
                "report_date": report_date
            }
            annual_data = self.collection.find_one(query)

            if annual_data:
                return annual_data
            logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
            return None

        except Exception as e:
            logger.error(f"获取年报数据失败: {stock_code} - {year} - {str(e)}")
            return None

    @staticmethod
    def _to_float(value) -> Optional[float]:
        """Convert a raw statement value to float.

        Returns None for missing, blank, unparseable, or NaN values.
        """
        if value is None or value == '':
            return None
        try:
            number = float(value)
        except (ValueError, TypeError):
            return None
        # NaN fails `>= 0` and would otherwise be misread as negative ('负').
        return None if math.isnan(number) else number

    def extract_cashflow_values(
        self, financial_data: Dict
    ) -> Tuple[Optional[float], Optional[float], Optional[float]]:
        """Extract the three key cash-flow totals from a report document.

        Reads ``financial_data['cash_flow_statement']`` fields
        NETCASH_OPERATE, NETCASH_INVEST, and NETCASH_FINANCE.

        Returns:
            (operating, investing, financing) net cash flow. Each element is
            None when missing or unparseable; all three are None on an
            unexpected error (which is logged).
        """
        try:
            # `or {}` also covers a statement stored as null, not just absent.
            cash_flow_statement = financial_data.get('cash_flow_statement') or {}
            operating_cashflow = self._to_float(cash_flow_statement.get('NETCASH_OPERATE'))
            investing_cashflow = self._to_float(cash_flow_statement.get('NETCASH_INVEST'))
            financing_cashflow = self._to_float(cash_flow_statement.get('NETCASH_FINANCE'))
            return operating_cashflow, investing_cashflow, financing_cashflow

        except Exception as e:
            logger.error(f"提取现金流数据失败: {str(e)}")
            return None, None, None

    def classify_cashflow_pattern(self, operating_cf: Optional[float], investing_cf: Optional[float],
                                  financing_cf: Optional[float]) -> Tuple[str, str, str]:
        """Classify each cash-flow value by sign.

        Returns:
            A 3-tuple of '正' (value >= 0, so zero counts as positive),
            '负' (value < 0), or '未知' (value is None).
        """
        def classify_value(value):
            if value is None:
                return "未知"
            return "正" if value >= 0 else "负"

        return (
            classify_value(operating_cf),
            classify_value(investing_cf),
            classify_value(financing_cf),
        )

    def determine_lifecycle_stage(self, cashflow_pattern: tuple) -> int:
        """Map a sign pattern to a lifecycle stage id.

        Returns:
            Stage id 1-5, or 0 when the pattern is not in
            ``cashflow_pattern_mapping`` (e.g. it contains '未知').
        """
        return self.cashflow_pattern_mapping.get(cashflow_pattern, 0)

    @staticmethod
    def _make_result(stock_code: str, year: int, stage_id: int, stage_name: str) -> Dict:
        """Build one result record in the shape shared by all code paths."""
        return {
            'stock_code': stock_code,
            'year': year,
            'stage_id': stage_id,
            'stage_name': stage_name
        }

    def calculate_lifecycle_factor(self, stock_code: str, year: int) -> Dict:
        """Compute the lifecycle factor for one stock and year.

        Args:
            stock_code: stock code in any supported format
                (300661.SZ, 300661, SZ300661).
            year: report year.

        Returns:
            A dict with keys stock_code, year, stage_id, and stage_name.
            ``stage_id`` is 0 with a status name ('数据缺失', '数据不完整',
            '计算失败') when no stage can be determined.
        """
        try:
            financial_data = self.get_annual_financial_data(stock_code, year)
            if not financial_data:
                return self._make_result(stock_code, year, 0, '数据缺失')

            cashflows = self.extract_cashflow_values(financial_data)
            if any(value is None for value in cashflows):
                return self._make_result(stock_code, year, 0, '数据不完整')

            cashflow_pattern = self.classify_cashflow_pattern(*cashflows)
            stage_id = self.determine_lifecycle_stage(cashflow_pattern)
            stage_name = self.lifecycle_stages.get(stage_id, '未知阶段')
            return self._make_result(stock_code, year, stage_id, stage_name)

        except Exception as e:
            logger.error(f"计算生命周期因子失败: {stock_code} - {year} - {str(e)}")
            return self._make_result(stock_code, year, 0, '计算失败')

    def batch_calculate_lifecycle_factors(self, stock_codes: List[str], year: int) -> pd.DataFrame:
        """Compute lifecycle factors for many stocks for one year.

        Args:
            stock_codes: list of stock codes.
            year: report year.

        Returns:
            A DataFrame with columns stock_code, year, stage_id, and
            stage_name, one row per input code. If ``stock_codes`` is empty,
            it is an empty frame with those columns.
        """
        total_stocks = len(stock_codes)
        logger.info(f"开始批量计算 {total_stocks} 只股票 {year} 年的企业生命周期因子")

        results = []
        for i, stock_code in enumerate(stock_codes, 1):
            # Log progress every 100 stocks and on the last one.
            if i % 100 == 0 or i == total_stocks:
                progress = (i / total_stocks) * 100
                logger.info(f"进度: [{i}/{total_stocks}] ({progress:.1f}%)")
            results.append(self.calculate_lifecycle_factor(stock_code, year))

        # Explicit columns keep the schema (and df['stage_name']) valid even
        # when there are no results.
        df = pd.DataFrame(results, columns=self._RESULT_COLUMNS)

        # Log the distribution of stages.
        row_count = len(df)
        stage_distribution = df['stage_name'].value_counts()
        logger.info(f"{year}年企业生命周期阶段分布:")
        for stage, count in stage_distribution.items():
            percentage = (count / row_count) * 100
            logger.info(f" {stage}: {count} 只 ({percentage:.1f}%)")

        return df

    def __del__(self):
        """Best-effort close of the MongoDB client on garbage collection."""
        try:
            self.close()
        except Exception:
            # Finalizers must not raise; at interpreter shutdown the client
            # or module globals may already be torn down.
            pass


def main():
    """Example usage: one single-stock query and one small batch."""
    try:
        with CompanyLifecycleFactor() as lifecycle_calculator:
            # Example 1: lifecycle stage of a single stock for 2024
            print("=== 单只股票分析示例 ===")
            result = lifecycle_calculator.calculate_lifecycle_factor('600519.SH', 2024)
            print(f"股票: {result['stock_code']}")
            print(f"年份: {result['year']}")
            print(f"生命周期阶段: {result['stage_name']}")

            # Example 2: batch analysis
            print("\n=== 批量分析示例 ===")
            test_stocks = ['300879.SZ', '301123.SZ', '300884.SZ', '300918.SZ', '600908.SH']
            df_results = lifecycle_calculator.batch_calculate_lifecycle_factors(test_stocks, 2024)

            print("\n2024年生命周期阶段结果:")
            print(df_results[['stock_code', 'stage_name']].to_string(index=False))

            # To save the results:
            # df_results.to_csv("company_lifecycle_2024.csv", index=False, encoding='utf-8-sig')

    except Exception as e:
        logger.exception(f"程序执行失败: {str(e)}")


if __name__ == "__main__":
    main()