stock_fundamentals/src/quantitative_analysis/company_lifecycle_factor.py

326 lines
12 KiB
Python
Raw Normal View History

2025-07-03 15:57:04 +08:00
# coding:utf-8
# 判断企业生命周期
import logging
import math
import os
import sys
from typing import Dict, List, Optional

import pandas as pd
import pymongo
# Make the project root importable so `valuation_analysis` and `tools` resolve.
# Expected layout: <root>/src/quantitative_analysis/<this file>, e.g.
#   /app/src/quantitative_analysis/company_lifecycle_factor.py -> root /app
current_file_dir = os.path.dirname(os.path.abspath(__file__))  # .../src/quantitative_analysis
src_dir = os.path.dirname(current_file_dir)  # .../src
project_root = os.path.dirname(src_dir)  # project root
# Only append once: re-executing this module (reload, or running it as a script
# after it was imported) would otherwise pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)
# Import the MongoDB configuration. If `valuation_analysis` is not importable as a
# package, load src/valuation_analysis/config.py directly from its file path.
try:
    from valuation_analysis.config import MONGO_CONFIG2
except ImportError:
    import importlib.util
    config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
    # Fail with a clear ImportError (same as the formatter fallback) instead of
    # an opaque FileNotFoundError from exec_module when the file is absent.
    if not os.path.exists(config_path):
        raise ImportError(f"无法找到 config.py路径: {config_path}")
    spec = importlib.util.spec_from_file_location("config", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    MONGO_CONFIG2 = config_module.MONGO_CONFIG2
# Import the stock-code formatting helper. When `tools` cannot be imported as a
# package, load <project_root>/tools/stock_code_formatter.py straight from disk,
# raising ImportError if that file does not exist either.
try:
    from tools.stock_code_formatter import StockCodeFormatter
except ImportError:
    import importlib.util
    formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
    if os.path.exists(formatter_path):
        spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
        formatter_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(formatter_module)
        StockCodeFormatter = formatter_module.StockCodeFormatter
    else:
        raise ImportError(f"无法找到 stock_code_formatter.py路径: {formatter_path}")
# Module logger, plus INFO-level root logging configuration.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class CompanyLifecycleFactor:
    """Company lifecycle-stage factor calculator.

    Classifies a company into a lifecycle stage from the signs of the three
    net cash-flow totals (operating, investing, financing) in its annual
    report, read from the ``eastmoney_financial_data_v2`` MongoDB collection.

    Can be used as a context manager (``with CompanyLifecycleFactor() as c:``)
    to close the MongoDB client deterministically; otherwise the client is
    closed from ``__del__``.
    """

    # Stage id -> stage name. Id 0 is reserved for "unknown / not computable".
    LIFECYCLE_STAGES = {
        1: "引入期",  # introduction
        2: "成长期",  # growth
        3: "成熟期",  # maturity
        4: "震荡期",  # shake-out
        5: "衰退期",  # decline
    }

    # (operating, investing, financing) sign pattern -> stage id, covering all
    # 8 sign combinations. Sign labels are the ones produced by
    # classify_cashflow_pattern(): "正" (>= 0) and "负" (< 0).
    # NOTE(review): in the reviewed copy of this file the sign labels were empty
    # strings, so all 8 keys collapsed into a single ('', '', '') entry and every
    # company was labelled stage 2. This table follows the standard Dickinson
    # (2011) cash-flow-pattern scheme -- confirm it matches the intended mapping
    # (the old comments listed a "growth variant" pattern that Dickinson assigns
    # to a different stage).
    CASHFLOW_PATTERN_MAPPING = {
        ('负', '负', '正'): 1,  # introduction
        ('正', '负', '正'): 2,  # growth
        ('正', '负', '负'): 3,  # maturity
        ('负', '负', '负'): 4,  # shake-out
        ('正', '正', '正'): 4,  # shake-out
        ('正', '正', '负'): 4,  # shake-out
        ('负', '正', '正'): 5,  # decline
        ('负', '正', '负'): 5,  # decline
    }

    # Column order of every result row and of the batch DataFrame.
    RESULT_COLUMNS = ('stock_code', 'year', 'stage_id', 'stage_name')

    # Cash-flow statement fields, in (operating, investing, financing) order.
    _CASHFLOW_FIELDS = ('NETCASH_OPERATE', 'NETCASH_INVEST', 'NETCASH_FINANCE')

    def __init__(self):
        """Connect to MongoDB and set up the code formatter and lookup tables.

        Raises:
            Exception: whatever connect_mongodb() re-raises if the connection fails.
        """
        self.mongo_client = None
        self.db = None
        self.collection = None
        self.connect_mongodb()
        # Normalizes the several accepted stock-code spellings.
        self.stock_formatter = StockCodeFormatter()
        # Per-instance copies, so editing one calculator's tables does not
        # affect other instances.
        self.lifecycle_stages = dict(self.LIFECYCLE_STAGES)
        self.cashflow_pattern_mapping = dict(self.CASHFLOW_PATTERN_MAPPING)

    def connect_mongodb(self):
        """Connect to MongoDB and bind ``self.db`` / ``self.collection``.

        Raises:
            Exception: any error from creating the client or from the ping is
                logged and re-raised; a partially created client is closed first.
        """
        try:
            self.mongo_client = pymongo.MongoClient(
                host=MONGO_CONFIG2['host'],
                port=MONGO_CONFIG2['port'],
                username=MONGO_CONFIG2['username'],
                password=MONGO_CONFIG2['password']
            )
            self.db = self.mongo_client[MONGO_CONFIG2['db']]
            self.collection = self.db['eastmoney_financial_data_v2']
            # Round-trip to the server so a bad config fails here, not later.
            self.mongo_client.admin.command('ping')
            logger.info("MongoDB连接成功")
        except Exception as e:
            logger.error(f"MongoDB连接失败: {str(e)}")
            # Don't leave a half-initialised client behind.
            self.close()
            raise

    def close(self):
        """Close the MongoDB client if one is open. Safe to call repeatedly."""
        client = getattr(self, 'mongo_client', None)
        if client is not None:
            client.close()
        self.mongo_client = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions from the with-body

    def get_annual_financial_data(self, stock_code: str, year: int) -> Optional[Dict]:
        """Fetch one stock's annual-report document for ``year``.

        Args:
            stock_code: Stock code in any supported format (300661.SZ, 300661, SZ300661).
            year: Report year, e.g. 2024.

        Returns:
            The matching document, or None when it is missing or the lookup
            fails (failures are logged, not raised).
        """
        try:
            normalized_code = self.stock_formatter.to_dot_format(stock_code)
            # Annual reports are keyed by the year-end date YYYY-12-31.
            report_date = f"{year}-12-31"
            annual_data = self.collection.find_one(
                {"stock_code": normalized_code, "report_date": report_date}
            )
            if annual_data:
                return annual_data
            logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
            return None
        except Exception as e:
            logger.error(f"获取年报数据失败: {stock_code} - {year} - {str(e)}")
            return None

    @staticmethod
    def _to_float(value) -> Optional[float]:
        """Convert a raw field value to float; None if missing, blank, unparseable or NaN."""
        if value is None or value == '':
            return None
        try:
            number = float(value)
        except (ValueError, TypeError, OverflowError):
            return None
        # NaN has no sign, so it cannot be classified as 正/负 -- treat it as missing.
        return None if math.isnan(number) else number

    def extract_cashflow_values(self, financial_data: Dict) -> tuple:
        """Extract the three net cash-flow totals from an annual-report document.

        Args:
            financial_data: Document as returned by get_annual_financial_data();
                values are read from its ``cash_flow_statement`` sub-document.

        Returns:
            (operating, investing, financing) net cash flow as floats; each value
            that is missing, blank, non-numeric or NaN is None.
        """
        try:
            # `or {}` also tolerates a cash_flow_statement stored as null.
            statement = financial_data.get('cash_flow_statement') or {}
            return tuple(self._to_float(statement.get(field)) for field in self._CASHFLOW_FIELDS)
        except Exception as e:
            logger.error(f"提取现金流数据失败: {str(e)}")
            return None, None, None

    def classify_cashflow_pattern(self, operating_cf: Optional[float], investing_cf: Optional[float],
                                  financing_cf: Optional[float]) -> tuple:
        """Map each cash-flow value to a sign label.

        Returns:
            A 3-tuple in (operating, investing, financing) order whose items are
            "正" (value >= 0, so zero counts as positive), "负" (value < 0) or
            "未知" (value is None or NaN).
        """
        def classify_value(value):
            if value is None or math.isnan(value):
                return "未知"
            return "正" if value >= 0 else "负"

        return tuple(classify_value(v) for v in (operating_cf, investing_cf, financing_cf))

    def determine_lifecycle_stage(self, cashflow_pattern: tuple) -> int:
        """Look up the lifecycle stage for a sign pattern.

        Returns:
            int: stage id 1-5, or 0 when the pattern is not in the mapping
            (e.g. it contains "未知").
        """
        return self.cashflow_pattern_mapping.get(cashflow_pattern, 0)

    @staticmethod
    def _result(stock_code: str, year: int, stage_id: int, stage_name: str) -> Dict:
        """Build one result row in the shape returned by calculate_lifecycle_factor()."""
        return {'stock_code': stock_code, 'year': year, 'stage_id': stage_id, 'stage_name': stage_name}

    def calculate_lifecycle_factor(self, stock_code: str, year: int) -> Dict:
        """Compute the lifecycle stage of one stock for one report year.

        Args:
            stock_code: Stock code in any supported format (300661.SZ, 300661, SZ300661).
            year: Report year.

        Returns:
            Dict with keys stock_code, year, stage_id, stage_name. stage_id is 0
            with stage_name '数据缺失' (no report), '数据不完整' (a cash-flow value
            missing) or '计算失败' (unexpected error, logged).
        """
        try:
            financial_data = self.get_annual_financial_data(stock_code, year)
            if not financial_data:
                return self._result(stock_code, year, 0, '数据缺失')

            cashflows = self.extract_cashflow_values(financial_data)
            if any(value is None for value in cashflows):
                return self._result(stock_code, year, 0, '数据不完整')

            stage_id = self.determine_lifecycle_stage(self.classify_cashflow_pattern(*cashflows))
            stage_name = self.lifecycle_stages.get(stage_id, '未知阶段')
            return self._result(stock_code, year, stage_id, stage_name)
        except Exception as e:
            logger.error(f"计算生命周期因子失败: {stock_code} - {year} - {str(e)}")
            return self._result(stock_code, year, 0, '计算失败')

    def batch_calculate_lifecycle_factors(self, stock_codes: List[str], year: int) -> pd.DataFrame:
        """Compute lifecycle stages for many stocks for one report year.

        Args:
            stock_codes: Stock codes to process.
            year: Report year.

        Returns:
            pd.DataFrame with columns stock_code, year, stage_id, stage_name --
            one row per input code (empty, but with those columns, for no codes).
        """
        total_stocks = len(stock_codes)
        logger.info(f"开始批量计算 {total_stocks} 只股票 {year} 年的企业生命周期因子")

        results = []
        for i, stock_code in enumerate(stock_codes, 1):
            # Progress every 100 stocks and on the last one.
            if i % 100 == 0 or i == total_stocks:
                progress = (i / total_stocks) * 100
                logger.info(f"进度: [{i}/{total_stocks}] ({progress:.1f}%)")
            results.append(self.calculate_lifecycle_factor(stock_code, year))

        # Explicit columns: pd.DataFrame([]) has none, which made the
        # df['stage_name'] lookup below raise KeyError for an empty input list.
        df = pd.DataFrame(results, columns=list(self.RESULT_COLUMNS))

        stage_distribution = df['stage_name'].value_counts()
        logger.info(f"{year}年企业生命周期阶段分布:")
        for stage, count in stage_distribution.items():
            percentage = (count / len(df)) * 100
            logger.info(f"  {stage}: {count} 只 ({percentage:.1f}%)")
        return df

    def __del__(self):
        """Best-effort cleanup: close the MongoDB client when the object is collected."""
        self.close()
def main(year: int = 2024, output_csv: Optional[str] = None):
    """Demo entry point: classify one stock, then a small batch, for one report year.

    Args:
        year: Annual-report year to analyse; defaults to 2024 as in the original demo.
        output_csv: Optional path. When given, the batch results are also written
            there as CSV (utf-8-sig encoding, no index column).

    Any error is logged together with its traceback instead of being raised.
    """
    try:
        lifecycle_calculator = CompanyLifecycleFactor()

        # Example 1: a single stock.
        print("=== 单只股票分析示例 ===")
        result = lifecycle_calculator.calculate_lifecycle_factor('600519.SH', year)
        print(f"股票: {result['stock_code']}")
        print(f"年份: {result['year']}")
        print(f"生命周期阶段: {result['stage_name']}")

        # Example 2: a small batch.
        print("\n=== 批量分析示例 ===")
        test_stocks = ['300879.SZ', '301123.SZ', '300884.SZ', '300918.SZ', '600908.SH']
        df_results = lifecycle_calculator.batch_calculate_lifecycle_factors(test_stocks, year)
        print(f"\n{year}年生命周期阶段结果:")
        print(df_results[['stock_code', 'stage_name']].to_string(index=False))

        if output_csv:
            df_results.to_csv(output_csv, index=False, encoding='utf-8-sig')
            print(f"\n结果已保存到: {output_csv}")
    except Exception as e:
        # logger.exception keeps the traceback, which a plain logger.error drops.
        logger.exception(f"程序执行失败: {str(e)}")


if __name__ == "__main__":
    main()