stock_fundamentals/src/quantitative_analysis/company_lifecycle_factor.py

326 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding:utf-8
# 判断企业生命周期
import pandas as pd
import pymongo
import logging
from typing import Dict, List, Optional
import sys
import os
# Add the project root directory to the import search path.
# __file__ is the path of this file, e.g. /app/src/quantitative_analysis/company_lifecycle_factor.py,
# and the project root we need to derive from it is /app.
current_file_dir = os.path.dirname(os.path.abspath(__file__)) # /app/src/quantitative_analysis
src_dir = os.path.dirname(current_file_dir) # /app/src
project_root = os.path.dirname(src_dir) # /app (project root)
sys.path.append(project_root)
# Import the MongoDB configuration.
try:
    from valuation_analysis.config import MONGO_CONFIG2
except ImportError:
    # Fallback when the package import fails: load
    # <src_dir>/valuation_analysis/config.py directly from its file path and
    # read MONGO_CONFIG2 from the resulting module object.
    import importlib.util
    config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
    spec = importlib.util.spec_from_file_location("config", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    MONGO_CONFIG2 = config_module.MONGO_CONFIG2
# Import the stock-code formatting helper.
try:
    from src.tools.stock_code_formatter import StockCodeFormatter
except ImportError:
    import importlib.util
    # project_root is already the project root, so the tools directory is
    # joined onto it directly.
    # NOTE(review): the primary import above refers to src.tools, while this
    # fallback looks in <project_root>/tools (no "src" component) - confirm
    # which location is the intended one.
    formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
    # Fail with an ImportError that names the path that was tried.
    if not os.path.exists(formatter_path):
        raise ImportError(f"无法找到 stock_code_formatter.py路径: {formatter_path}")
    spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
    formatter_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(formatter_module)
    StockCodeFormatter = formatter_module.StockCodeFormatter
# Configure logging: INFO level on the root logger, plus a module-level logger
# named after this module that the code below uses for all messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CompanyLifecycleFactor:
    """Company lifecycle-stage factor calculator.

    For a (stock, year) pair this class reads the annual report dated
    ``<year>-12-31`` from the MongoDB collection
    ``eastmoney_financial_data_v2``, takes the signs of the operating,
    investing and financing net cash flows, and maps that sign triple to a
    lifecycle stage id (1-5). Stage id 0 means the stage could not be
    determined (missing report, incomplete data, or a calculation error).
    """

    # Symbols emitted by classify_cashflow_pattern().
    # NOTE(review): the sign symbols in the version of this file under review
    # had been stripped down to empty strings, which made "positive" and
    # "negative" indistinguishable. ASCII "+" / "-" are used here; confirm no
    # external consumer relies on a different spelling.
    _SIGN_POSITIVE = "+"
    _SIGN_NEGATIVE = "-"
    _SIGN_UNKNOWN = "未知"

    # Stage id -> display name.
    _STAGE_NAMES = {
        1: "引入期",
        2: "成长期",
        3: "成熟期",
        4: "震荡期",
        5: "衰退期",
    }

    # (operating, investing, financing) sign triple -> stage id.
    # The entry order and the stage id on each row are kept from the original
    # table (1, 2, 3, 4, 5, 4, 4, 2). The first three rows follow the standard
    # cash-flow-pattern lifecycle model.
    # NOTE(review): the sign triples themselves were lost together with the
    # sign symbols, so the assignment of the remaining five rows is a
    # reconstruction - verify against the intended table.
    _PATTERN_TO_STAGE = {
        ("-", "-", "+"): 1,  # introduction
        ("+", "-", "+"): 2,  # growth
        ("+", "-", "-"): 3,  # maturity
        ("+", "+", "-"): 4,  # shake-out
        ("-", "+", "-"): 5,  # decline
        ("-", "+", "+"): 4,  # shake-out (variant)
        ("-", "-", "-"): 4,  # shake-out (distressed)
        ("+", "+", "+"): 2,  # growth (variant, cash-rich)
    }

    # Column order of every result record / DataFrame produced by this class.
    _RESULT_COLUMNS = ['stock_code', 'year', 'stage_id', 'stage_name']

    def __init__(self):
        """Connect to MongoDB and set up the formatter and lookup tables.

        Raises:
            Exception: whatever ``connect_mongodb`` raises when the database
                cannot be reached.
        """
        self.mongo_client = None
        self.db = None
        self.collection = None
        self.connect_mongodb()
        # Helper that normalises the accepted stock-code spellings.
        self.stock_formatter = StockCodeFormatter()
        # Per-instance copies, so mutating one instance's tables never leaks
        # into another instance.
        self.lifecycle_stages = dict(self._STAGE_NAMES)
        self.cashflow_pattern_mapping = dict(self._PATTERN_TO_STAGE)

    def connect_mongodb(self):
        """Open the MongoDB connection and select the financial-data collection.

        Raises:
            Exception: the connection error is logged and re-raised unchanged.
        """
        try:
            self.mongo_client = pymongo.MongoClient(
                host=MONGO_CONFIG2['host'],
                port=MONGO_CONFIG2['port'],
                username=MONGO_CONFIG2['username'],
                password=MONGO_CONFIG2['password'],
            )
            self.db = self.mongo_client[MONGO_CONFIG2['db']]
            self.collection = self.db['eastmoney_financial_data_v2']
            # Issue a ping so an unreachable server fails here rather than on
            # the first query.
            self.mongo_client.admin.command('ping')
            logger.info("MongoDB连接成功")
        except Exception as e:
            logger.error(f"MongoDB连接失败: {e}")
            raise

    def close(self):
        """Close the MongoDB connection. Safe to call more than once."""
        client = getattr(self, 'mongo_client', None)
        # Compare with None explicitly instead of relying on the client
        # object's truth value.
        if client is not None:
            client.close()
            self.mongo_client = None

    def get_annual_financial_data(self, stock_code: str, year: int) -> Optional[Dict]:
        """Fetch the annual report of one stock for one year.

        Args:
            stock_code: Stock code in any supported spelling
                (300661.SZ, 300661, SZ300661).
            year: Report year, e.g. 2024.

        Returns:
            The report document, or None when no document matches or the
            lookup fails (the failure is logged).
        """
        try:
            # Normalise the code to the dotted form used as the query key.
            normalized_code = self.stock_formatter.to_dot_format(stock_code)
            # Annual reports are keyed by a report date ending in 12-31.
            report_date = f"{year}-12-31"
            annual_data = self.collection.find_one({
                "stock_code": normalized_code,
                "report_date": report_date,
            })
            if annual_data:
                return annual_data
            logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
            return None
        except Exception as e:
            logger.error(f"获取年报数据失败: {stock_code} - {year} - {e}")
            return None

    @staticmethod
    def _to_float(value) -> Optional[float]:
        """Convert a raw field to float; None for missing, unparsable or NaN."""
        if value is None or value == '':
            return None
        try:
            number = float(value)
        except (ValueError, TypeError):
            return None
        # NaN carries no sign, so it is treated as missing rather than being
        # classified as negative further down the pipeline.
        return None if pd.isna(number) else number

    def extract_cashflow_values(self, financial_data: Dict) -> tuple:
        """Extract the three net cash-flow figures from a report document.

        Args:
            financial_data: Report document; the figures are read from its
                ``cash_flow_statement`` sub-document (keys NETCASH_OPERATE,
                NETCASH_INVEST, NETCASH_FINANCE).

        Returns:
            Tuple ``(operating, investing, financing)`` of floats; an element
            is None when the field is missing, empty, NaN or not numeric.
            ``(None, None, None)`` when extraction fails altogether.
        """
        try:
            # "or {}" also covers a statement that is present but stored as None.
            statement = financial_data.get('cash_flow_statement') or {}
            return (
                self._to_float(statement.get('NETCASH_OPERATE')),
                self._to_float(statement.get('NETCASH_INVEST')),
                self._to_float(statement.get('NETCASH_FINANCE')),
            )
        except Exception as e:
            logger.error(f"提取现金流数据失败: {e}")
            return None, None, None

    def classify_cashflow_pattern(self, operating_cf: Optional[float], investing_cf: Optional[float],
                                  financing_cf: Optional[float]) -> tuple:
        """Reduce the three cash-flow figures to their sign symbols.

        Zero counts as positive; None yields the "unknown" symbol.

        Returns:
            Tuple ``(operating, investing, financing)`` of sign symbols.
        """
        def sign_of(value):
            if value is None:
                return self._SIGN_UNKNOWN
            return self._SIGN_POSITIVE if value >= 0 else self._SIGN_NEGATIVE

        return sign_of(operating_cf), sign_of(investing_cf), sign_of(financing_cf)

    def determine_lifecycle_stage(self, cashflow_pattern: tuple) -> int:
        """Look up the lifecycle stage for a sign triple.

        Returns:
            Stage id (1-5); 0 when the pattern is not in the mapping.
        """
        return self.cashflow_pattern_mapping.get(cashflow_pattern, 0)

    @staticmethod
    def _build_result(stock_code: str, year: int, stage_id: int, stage_name: str) -> Dict:
        """Assemble one result record in the fixed column order."""
        return {
            'stock_code': stock_code,
            'year': year,
            'stage_id': stage_id,
            'stage_name': stage_name,
        }

    def calculate_lifecycle_factor(self, stock_code: str, year: int) -> Dict:
        """Compute the lifecycle factor of one stock for one year.

        Args:
            stock_code: Stock code in any supported spelling
                (300661.SZ, 300661, SZ300661).
            year: Report year.

        Returns:
            Dict with keys ``stock_code``, ``year``, ``stage_id`` and
            ``stage_name``. ``stage_id`` is 0 when the report is missing, a
            cash-flow figure is unavailable, or the calculation fails; the
            reason is given in ``stage_name``.
        """
        try:
            financial_data = self.get_annual_financial_data(stock_code, year)
            if not financial_data:
                return self._build_result(stock_code, year, 0, '数据缺失')

            flows = self.extract_cashflow_values(financial_data)
            if any(flow is None for flow in flows):
                return self._build_result(stock_code, year, 0, '数据不完整')

            stage_id = self.determine_lifecycle_stage(self.classify_cashflow_pattern(*flows))
            stage_name = self.lifecycle_stages.get(stage_id, '未知阶段')
            return self._build_result(stock_code, year, stage_id, stage_name)
        except Exception as e:
            logger.error(f"计算生命周期因子失败: {stock_code} - {year} - {e}")
            return self._build_result(stock_code, year, 0, '计算失败')

    def batch_calculate_lifecycle_factors(self, stock_codes: List[str], year: int) -> pd.DataFrame:
        """Compute the lifecycle factor of many stocks for one year.

        Args:
            stock_codes: Stock codes to process; may be empty.
            year: Report year.

        Returns:
            DataFrame with one row per stock and the columns ``stock_code``,
            ``year``, ``stage_id``, ``stage_name`` (present even when
            ``stock_codes`` is empty).
        """
        stock_codes = list(stock_codes)
        total_stocks = len(stock_codes)
        logger.info(f"开始批量计算 {total_stocks} 只股票 {year} 年的企业生命周期因子")

        results = []
        for i, stock_code in enumerate(stock_codes, 1):
            # Report progress every 100 stocks and on the last one.
            if i % 100 == 0 or i == total_stocks:
                progress = (i / total_stocks) * 100
                logger.info(f"进度: [{i}/{total_stocks}] ({progress:.1f}%)")
            results.append(self.calculate_lifecycle_factor(stock_code, year))

        # Explicit columns keep the frame well-formed for an empty input, so
        # the 'stage_name' lookup below cannot raise KeyError.
        df = pd.DataFrame(results, columns=self._RESULT_COLUMNS)

        # Log how the stocks are distributed over the stages.
        stage_distribution = df['stage_name'].value_counts()
        logger.info(f"{year}年企业生命周期阶段分布:")
        for stage, count in stage_distribution.items():
            percentage = (count / len(df)) * 100
            logger.info(f" {stage}: {count} 只 ({percentage:.1f}%)")
        return df

    def __del__(self):
        """Close the database connection when the object is collected."""
        self.close()
def main():
    """Demo entry point: classify one stock, then a small batch, for 2024.

    Any exception raised along the way is logged and not re-raised.
    """
    demo_year = 2024
    try:
        calculator = CompanyLifecycleFactor()

        # Example 1: lifecycle stage of a single stock.
        print("=== 单只股票分析示例 ===")
        single = calculator.calculate_lifecycle_factor('600519.SH', demo_year)
        print(f"股票: {single['stock_code']}")
        print(f"年份: {single['year']}")
        print(f"生命周期阶段: {single['stage_name']}")

        # Example 2: the same calculation over a handful of stocks.
        print("\n=== 批量分析示例 ===")
        sample_codes = ['300879.SZ', '301123.SZ', '300884.SZ', '300918.SZ', '600908.SH']
        batch_frame = calculator.batch_calculate_lifecycle_factors(sample_codes, demo_year)
        print("\n2024年生命周期阶段结果:")
        summary = batch_frame[['stock_code', 'stage_name']]
        print(summary.to_string(index=False))

        # To persist the batch result, write batch_frame to CSV with
        # index=False and encoding='utf-8-sig' (left disabled in this demo).
    except Exception as e:
        logger.error(f"程序执行失败: {str(e)}")


if __name__ == "__main__":
    main()