326 lines
12 KiB
Python
326 lines
12 KiB
Python
# coding:utf-8
|
||
# 判断企业生命周期
|
||
import pandas as pd
|
||
import pymongo
|
||
import logging
|
||
from typing import Dict, List, Optional
|
||
import sys
|
||
import os
|
||
|
||
# Make the project root importable.
# Expected layout (per the original example paths): this file lives at
#   <project_root>/src/quantitative_analysis/<this file>
# so src/ is one level above this file's directory and the project root is two.
current_file_dir = os.path.dirname(os.path.abspath(__file__))
src_dir, project_root = (
    os.path.dirname(current_file_dir),                   # <project_root>/src
    os.path.dirname(os.path.dirname(current_file_dir)),  # <project_root>
)
sys.path.append(project_root)
|
||
|
||
def _load_module_from_path(module_name, file_path):
    """Load and execute the Python source file at ``file_path`` as a module.

    Used as a fallback when the regular package import fails.

    Args:
        module_name: name given to the loaded module object.
        file_path: path to a ``.py`` file.

    Returns:
        The executed module object.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Import the MongoDB configuration.
try:
    from valuation_analysis.config import MONGO_CONFIG2
except ImportError:
    # Fall back to loading <src>/valuation_analysis/config.py directly.
    MONGO_CONFIG2 = _load_module_from_path(
        "config", os.path.join(src_dir, 'valuation_analysis', 'config.py')
    ).MONGO_CONFIG2

# Import the stock-code formatting helper.
try:
    from src.tools.stock_code_formatter import StockCodeFormatter
except ImportError:
    # The package import above implies <project_root>/src/tools/, while this
    # fallback originally only looked in <project_root>/tools/. Check the
    # original location first, then the one the package import implies.
    _formatter_candidates = [
        os.path.join(project_root, 'tools', 'stock_code_formatter.py'),
        os.path.join(src_dir, 'tools', 'stock_code_formatter.py'),
    ]
    formatter_path = next(
        (path for path in _formatter_candidates if os.path.exists(path)), None
    )
    if formatter_path is None:
        raise ImportError(
            f"无法找到 stock_code_formatter.py,路径: {', '.join(_formatter_candidates)}"
        )
    StockCodeFormatter = _load_module_from_path(
        "stock_code_formatter", formatter_path
    ).StockCodeFormatter
|
||
|
||
# Logging: configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
||
|
||
class CompanyLifecycleFactor:
    """Company lifecycle-stage factor calculator.

    Classifies a company into a lifecycle stage from the signs of the three
    net cash flows (operating, investing, financing) in its annual report,
    read from the ``eastmoney_financial_data_v2`` MongoDB collection.
    """

    def __init__(self):
        """Connect to MongoDB and set up the stage lookup tables.

        Raises:
            Exception: re-raised from ``connect_mongodb`` if creating the client
                or the ``ping`` check fails.
        """
        self.mongo_client = None
        self.db = None
        self.collection = None
        self.connect_mongodb()

        # Normalizes stock codes (via ``to_dot_format``) before they are used in queries.
        self.stock_formatter = StockCodeFormatter()

        # Lifecycle stage ID -> display name.
        self.lifecycle_stages = {
            1: "引入期",  # introduction
            2: "成长期",  # growth
            3: "成熟期",  # maturity
            4: "震荡期",  # shake-out
            5: "衰退期"   # decline
        }

        # (operating, investing, financing) sign pattern -> stage ID.
        # "正" = value >= 0, "负" = value < 0 (see classify_cashflow_pattern).
        # All eight sign combinations are listed, so stage 0 (unknown) only
        # results from a pattern containing "未知".
        # NOTE(review): several entries differ from Dickinson (2011)'s scheme
        # (e.g. ('正', '正', '负') -> decline, ('正', '正', '正') -> growth);
        # confirm this variant is intentional.
        self.cashflow_pattern_mapping = {
            ('负', '负', '正'): 1,  # introduction
            ('正', '负', '正'): 2,  # growth
            ('正', '负', '负'): 3,  # maturity
            ('负', '正', '正'): 4,  # shake-out
            ('正', '正', '负'): 5,  # decline
            ('负', '正', '负'): 4,  # shake-out (variant)
            ('负', '负', '负'): 4,  # shake-out (distressed)
            ('正', '正', '正'): 2,  # growth (variant, cash-rich)
        }

    def connect_mongodb(self):
        """Open the MongoDB client and bind ``mongo_client``, ``db`` and ``collection``.

        Connection settings come from ``MONGO_CONFIG2``. A ``ping`` is issued so
        that an unreachable server fails here rather than on the first query.
        The instance attributes are only updated once the ping succeeds.

        Raises:
            Exception: any error from client creation, database lookup or the
                ping is logged and re-raised after the new client is closed.
        """
        client = None
        try:
            client = pymongo.MongoClient(
                host=MONGO_CONFIG2['host'],
                port=MONGO_CONFIG2['port'],
                username=MONGO_CONFIG2['username'],
                password=MONGO_CONFIG2['password']
            )
            db = client[MONGO_CONFIG2['db']]
            collection = db['eastmoney_financial_data_v2']

            # Verify connectivity before exposing the client.
            client.admin.command('ping')

        except Exception as e:
            logger.error(f"MongoDB连接失败: {str(e)}")
            # Don't leak a client whose connection check failed.
            if client is not None:
                client.close()
            raise

        self.mongo_client = client
        self.db = db
        self.collection = collection
        logger.info("MongoDB连接成功")

    def get_annual_financial_data(self, stock_code: str, year: int) -> Optional[Dict]:
        """Fetch the annual-report document for one stock and year.

        Args:
            stock_code: stock code in any supported format (300661.SZ, 300661, SZ300661).
            year: fiscal year, e.g. 2024.

        Returns:
            The matching document (``report_date`` == "<year>-12-31"), or None
            if it is not found or the query fails (both cases are logged).
        """
        try:
            normalized_code = self.stock_formatter.to_dot_format(stock_code)
            # Annual reports are keyed by the year-end date.
            report_date = f"{year}-12-31"

            query = {
                "stock_code": normalized_code,
                "report_date": report_date
            }
            annual_data = self.collection.find_one(query)

            if annual_data:
                return annual_data

            logger.warning(f"未找到年报数据: {stock_code} (标准化后: {normalized_code}) - {report_date}")
            return None

        except Exception as e:
            logger.error(f"获取年报数据失败: {stock_code} - {year} - {str(e)}")
            return None

    def extract_cashflow_values(self, financial_data: Dict) -> tuple:
        """Extract the three net cash-flow figures from a financial-data document.

        Reads ``NETCASH_OPERATE``, ``NETCASH_INVEST`` and ``NETCASH_FINANCE``
        from ``financial_data['cash_flow_statement']``.

        Returns:
            tuple: (operating, investing, financing) net cash flow as floats.
            An entry is None when it is missing, empty, unparseable or not a
            finite number (NaN / inf).
        """
        import math

        def safe_float_convert(value):
            """Return ``value`` as a finite float, or None if it cannot be used."""
            if value is None or value == '':
                return None
            try:
                result = float(value)
            except (ValueError, TypeError):
                return None
            # NaN compares False against 0 and would otherwise be classified as
            # negative; treat any non-finite value as missing data.
            return result if math.isfinite(result) else None

        try:
            # `or {}` also covers a stored null, not just a missing key.
            cash_flow_statement = financial_data.get('cash_flow_statement') or {}

            operating_cashflow = safe_float_convert(cash_flow_statement.get('NETCASH_OPERATE'))
            investing_cashflow = safe_float_convert(cash_flow_statement.get('NETCASH_INVEST'))
            financing_cashflow = safe_float_convert(cash_flow_statement.get('NETCASH_FINANCE'))

            return operating_cashflow, investing_cashflow, financing_cashflow

        except Exception as e:
            logger.error(f"提取现金流数据失败: {str(e)}")
            return None, None, None

    def classify_cashflow_pattern(self, operating_cf: float, investing_cf: float, financing_cf: float) -> tuple:
        """Map each cash-flow value to its sign label.

        Returns:
            tuple of three labels: "正" for value >= 0 (zero counts as
            non-negative), "负" for value < 0, "未知" for None.
        """
        def classify_value(value):
            if value is None:
                return "未知"
            return "正" if value >= 0 else "负"

        operating_pattern = classify_value(operating_cf)
        investing_pattern = classify_value(investing_cf)
        financing_pattern = classify_value(financing_cf)

        return operating_pattern, investing_pattern, financing_pattern

    def determine_lifecycle_stage(self, cashflow_pattern: tuple) -> int:
        """Look up the lifecycle stage for a sign pattern.

        Returns:
            int: stage ID (1-5), or 0 if the pattern is not in the mapping.
        """
        return self.cashflow_pattern_mapping.get(cashflow_pattern, 0)

    def calculate_lifecycle_factor(self, stock_code: str, year: int) -> Dict:
        """Compute the lifecycle-stage factor for one stock and year.

        Args:
            stock_code: stock code in any supported format (300661.SZ, 300661, SZ300661).
            year: fiscal year.

        Returns:
            Dict with keys ``stock_code``, ``year``, ``stage_id`` and ``stage_name``.
            ``stage_id`` is 0 when the report is missing ('数据缺失'), a cash-flow
            figure is unusable ('数据不完整'), or an error occurs ('计算失败').
        """
        try:
            financial_data = self.get_annual_financial_data(stock_code, year)
            if not financial_data:
                return {
                    'stock_code': stock_code,
                    'year': year,
                    'stage_id': 0,
                    'stage_name': '数据缺失'
                }

            operating_cf, investing_cf, financing_cf = self.extract_cashflow_values(financial_data)

            if any(value is None for value in (operating_cf, investing_cf, financing_cf)):
                return {
                    'stock_code': stock_code,
                    'year': year,
                    'stage_id': 0,
                    'stage_name': '数据不完整'
                }

            cashflow_pattern = self.classify_cashflow_pattern(operating_cf, investing_cf, financing_cf)
            stage_id = self.determine_lifecycle_stage(cashflow_pattern)
            stage_name = self.lifecycle_stages.get(stage_id, '未知阶段')

            return {
                'stock_code': stock_code,
                'year': year,
                'stage_id': stage_id,
                'stage_name': stage_name
            }

        except Exception as e:
            logger.error(f"计算生命周期因子失败: {stock_code} - {year} - {str(e)}")
            return {
                'stock_code': stock_code,
                'year': year,
                'stage_id': 0,
                'stage_name': '计算失败'
            }

    def batch_calculate_lifecycle_factors(self, stock_codes: List[str], year: int) -> pd.DataFrame:
        """Compute the lifecycle-stage factor for many stocks for one year.

        Logs progress every 100 stocks (and at the end), then logs the
        distribution of stage names.

        Args:
            stock_codes: list of stock codes.
            year: fiscal year.

        Returns:
            pd.DataFrame with columns stock_code, year, stage_id, stage_name
            (one row per input code; empty with those columns if no codes given).
        """
        results = []
        total_stocks = len(stock_codes)

        logger.info(f"开始批量计算 {total_stocks} 只股票 {year} 年的企业生命周期因子")

        for i, stock_code in enumerate(stock_codes, 1):
            if i % 100 == 0 or i == total_stocks:
                progress = (i / total_stocks) * 100
                logger.info(f"进度: [{i}/{total_stocks}] ({progress:.1f}%)")

            results.append(self.calculate_lifecycle_factor(stock_code, year))

        # Explicit columns so an empty input still yields a frame callers can
        # select columns from (pd.DataFrame([]) has no 'stage_name' column).
        df = pd.DataFrame(results, columns=['stock_code', 'year', 'stage_id', 'stage_name'])
        if df.empty:
            return df

        stage_distribution = df['stage_name'].value_counts()
        logger.info(f"{year}年企业生命周期阶段分布:")
        for stage, count in stage_distribution.items():
            percentage = (count / len(df)) * 100
            logger.info(f" {stage}: {count} 只 ({percentage:.1f}%)")

        return df

    def __del__(self):
        """Close the MongoDB client, if one was created."""
        # hasattr guards against instances whose __init__ never ran.
        if hasattr(self, 'mongo_client') and self.mongo_client:
            self.mongo_client.close()
|
||
|
||
def main():
    """Example entry point: analyse one stock, then a small batch, for one year.

    Any exception is logged together with its traceback and not re-raised.
    """
    year = 2024
    try:
        # Creating the calculator connects to MongoDB.
        lifecycle_calculator = CompanyLifecycleFactor()

        # Example 1: lifecycle stage of a single stock.
        print("=== 单只股票分析示例 ===")
        result = lifecycle_calculator.calculate_lifecycle_factor('600519.SH', year)
        print(f"股票: {result['stock_code']}")
        print(f"年份: {result['year']}")
        print(f"生命周期阶段: {result['stage_name']}")

        # Example 2: batch analysis.
        print("\n=== 批量分析示例 ===")
        test_stocks = ['300879.SZ', '301123.SZ', '300884.SZ', '300918.SZ', '600908.SH']
        df_results = lifecycle_calculator.batch_calculate_lifecycle_factors(test_stocks, year)

        print(f"\n{year}年生命周期阶段结果:")
        print(df_results[['stock_code', 'stage_name']].to_string(index=False))

        # To persist the results:
        # df_results.to_csv(f"company_lifecycle_{year}.csv", index=False, encoding='utf-8-sig')

    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception(f"程序执行失败: {str(e)}")


if __name__ == "__main__":
    main()