# stock_fundamentals/src/valuation_analysis/industry_analysis.py
"""
行业估值分析模块
提供行业历史PE、PB、PS分位数分析功能以及行业拥挤度指标包括
1. 行业历史PE、PB、PS数据获取
2. 分位数计算
3. 行业交易拥挤度计算
"""
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text
import datetime
import logging
import json
import redis
import time
from typing import Tuple, Dict, List, Optional, Union
from .config import DB_URL, OUTPUT_DIR, LOG_FILE
# Configure logging: write to LOG_FILE and also echo to the console.
# NOTE(review): calling logging.basicConfig at import time is a module-level
# side effect; it is a no-op if the host application already configured the
# root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("industry_analysis")
# Redis client used to cache computed industry crowding data.
# NOTE(review): host and password are hardcoded here — consider moving them
# into .config alongside DB_URL.
redis_client = redis.Redis(
    host='192.168.18.208',  # Redis server address; adjust for the deployment
    port=6379,
    password='wlkj2018',
    db=13,
    socket_timeout=5,       # seconds; fail fast so a cache outage cannot hang analysis
    decode_responses=True   # return str instead of bytes from GET
)
class IndustryAnalyzer:
    """Industry valuation analyzer.

    Provides:
      * industry / concept board listings,
      * daily industry-average PE/PB/PS history with percentile statistics,
      * an industry trading-crowding index (industry turnover as a share of
        total market turnover), cached in Redis for one day.
    """

    # Whitelist of supported valuation metrics. This also guards the
    # f-string SQL in get_industry_valuation_data against injection, since
    # `metric` is interpolated into the query text.
    _VALID_METRICS = ('pe', 'pb', 'ps')

    def __init__(self, db_url: str = DB_URL):
        """Create the analyzer with a pooled SQLAlchemy engine.

        Args:
            db_url: Database connection URL.
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600  # recycle hourly to avoid server-side idle timeouts
        )
        logger.info("行业估值分析器初始化完成")

    @staticmethod
    def _series_stats(series: pd.Series) -> Dict[str, float]:
        """Return min/max/mean/median/q1/q3 of a numeric series as plain floats."""
        return {
            'min': float(series.min()),
            'max': float(series.max()),
            'mean': float(series.mean()),
            'median': float(series.median()),
            'q1': float(series.quantile(0.25)),
            'q3': float(series.quantile(0.75)),
        }

    def _fetch_board_list(self, table: str, empty_msg: str, error_prefix: str) -> List[Dict]:
        """Fetch distinct (bk_code, bk_name) rows from a board-membership table.

        Args:
            table: Board table name; an internal constant, never user input.
            empty_msg: Warning logged when the table yields no rows.
            error_prefix: Prefix of the error log message on failure.

        Returns:
            A list of {"code": str, "name": str} dicts, or [] on error / no data.
        """
        try:
            query = text(f"""
                SELECT DISTINCT bk_code, bk_name
                FROM {table}
                ORDER BY bk_name
            """)
            with self.engine.connect() as conn:
                result = conn.execute(query).fetchall()
            if result:
                return [{"code": str(row[0]), "name": row[1]} for row in result]
            logger.warning(empty_msg)
            return []
        except Exception as e:
            logger.error(f"{error_prefix}: {e}")
            return []

    def get_industry_list(self) -> List[Dict]:
        """Return all industry boards as [{"code", "name"}, ...]."""
        return self._fetch_board_list("gp_hybk", "未找到行业数据", "获取行业列表失败")

    def get_concept_list(self) -> List[Dict]:
        """Return all concept boards as [{"code", "name"}, ...]."""
        return self._fetch_board_list("gp_gnbk", "未找到概念板块数据", "获取概念板块列表失败")

    def get_industry_stocks(self, industry_name: str) -> List[str]:
        """Return all stock codes belonging to the given industry.

        Args:
            industry_name: Industry board name.

        Returns:
            List of stock codes; [] when the industry is unknown or on error.
        """
        try:
            query = text("""
                SELECT DISTINCT gp_code
                FROM gp_hybk
                WHERE bk_name = :industry_name
            """)
            with self.engine.connect() as conn:
                result = conn.execute(query, {"industry_name": industry_name}).fetchall()
            if result:
                return [row[0] for row in result]
            logger.warning(f"未找到行业 {industry_name} 的股票")
            return []
        except Exception as e:
            logger.error(f"获取行业股票失败: {e}")
            return []

    def get_industry_valuation_data(self, industry_name: str, start_date: str, metric: str = 'pe') -> pd.DataFrame:
        """Return the daily industry-average PE/PB/PS since `start_date`.

        Non-positive values and extreme outliers (metric >= 1000) are excluded
        before averaging.

        Args:
            industry_name: Industry board name.
            start_date: Earliest trade date to include ('YYYY-MM-DD').
            metric: One of 'pe', 'pb' or 'ps'.

        Returns:
            DataFrame with columns:
              - timestamp: trade date (datetime64)
              - avg_{metric}: industry average for that day
              - stock_count: number of stocks used in the average
            Empty DataFrame on error or when no data exists.
        """
        try:
            if metric not in self._VALID_METRICS:
                logger.error(f"不支持的估值指标: {metric}")
                return pd.DataFrame()
            stock_codes = self.get_industry_stocks(industry_name)
            if not stock_codes:
                return pd.DataFrame()
            # `metric` is interpolated into the SQL, but it is restricted to
            # _VALID_METRICS above, so no arbitrary SQL can be injected.
            query = text(f"""
                WITH valid_data AS (
                    SELECT
                        `timestamp`,
                        symbol,
                        {metric}
                    FROM
                        gp_day_data
                    WHERE
                        symbol IN :stock_codes AND
                        `timestamp` >= :start_date AND
                        {metric} > 0 AND
                        {metric} < 1000 -- 过滤掉极端异常值
                )
                SELECT
                    `timestamp`,
                    AVG({metric}) as avg_{metric},
                    COUNT(*) as stock_count
                FROM
                    valid_data
                GROUP BY
                    `timestamp`
                ORDER BY
                    `timestamp`
            """)
            with self.engine.connect() as conn:
                df = pd.read_sql(
                    query,
                    conn,
                    params={"stock_codes": tuple(stock_codes), "start_date": start_date}
                )
            if df.empty:
                logger.warning(f"未找到行业 {industry_name} 的估值数据")
                return pd.DataFrame()
            # Callers use .dt accessors / strftime on this column, so make
            # sure it really is datetime (some drivers return strings).
            df['timestamp'] = pd.to_datetime(df['timestamp'])
            logger.info(f"成功获取行业 {industry_name}{metric.upper()}数据,共 {len(df)} 条记录")
            return df
        except Exception as e:
            logger.error(f"获取行业估值数据失败: {e}")
            return pd.DataFrame()

    def calculate_industry_percentiles(self, data: pd.DataFrame, metric: str = 'pe') -> Dict:
        """Compute the current percentile and history stats of the industry average.

        Args:
            data: DataFrame as produced by get_industry_valuation_data().
            metric: One of 'pe', 'pb' or 'ps'.

        Returns:
            Dict with date, current value, min/max/mean/median/q1/q3 of the
            history, the percentile of the latest value, and stock_count;
            {} on empty input or an unsupported metric.
        """
        if data.empty:
            logger.warning(f"数据为空,无法计算行业{metric}分位数")
            return {}
        if metric not in self._VALID_METRICS:
            logger.error(f"不支持的估值指标: {metric}")
            return {}
        avg_col = f'avg_{metric}'
        latest = data.iloc[-1]
        # Empirical CDF of the latest value: the share of history that is
        # <= the current industry average, expressed as a percentage.
        percentile = (data[avg_col] <= latest[avg_col]).mean() * 100
        result = {
            'date': latest['timestamp'].strftime('%Y-%m-%d'),
            'current': float(latest[avg_col]),
            **self._series_stats(data[avg_col]),
            'percentile': float(percentile),
            'stock_count': int(latest['stock_count'])
        }
        logger.info(f"计算行业 {metric} 分位数完成: 当前{metric}={result['current']:.2f}, 百分位={result['percentile']:.2f}%")
        return result

    def get_industry_crowding_index(self, industry_name: str, start_date: str = None, end_date: str = None, use_cache: bool = True) -> pd.DataFrame:
        """Compute the industry trading-crowding index, cached in Redis for 1 day.

        The index is the industry's share of total market turnover; its
        rank-percentile over a fixed three-year window is bucketed into five
        crowding levels.

        Args:
            industry_name: Industry board name.
            start_date: Unused; kept for backward compatibility (the window
                is always the last three years).
            end_date: Last trade date to include; defaults to today.
            use_cache: Read/write the Redis cache when True.

        Returns:
            DataFrame with trade_date, total_market_amount, industry_amount,
            industry_amount_ratio, percentile and crowding_level columns;
            empty DataFrame on error or missing data.
        """
        try:
            # The window is fixed at three years regardless of start_date.
            three_years_ago = (datetime.datetime.now() - datetime.timedelta(days=3*365)).strftime('%Y-%m-%d')
            if end_date is None:
                end_date = datetime.datetime.now().strftime('%Y-%m-%d')
            # BUGFIX: the key now includes end_date. The old key ignored it,
            # so a request with a custom end_date could be served cached data
            # computed for a different date.
            cache_key = f"industry_crowding:{industry_name}:{end_date}"
            if use_cache:
                cached_data = redis_client.get(cache_key)
                if cached_data:
                    try:
                        cached_df_dict = json.loads(cached_data)
                        logger.info(f"从缓存获取行业 {industry_name} 的拥挤度数据")
                        df = pd.DataFrame(cached_df_dict)
                        # Restore the date dtype lost in JSON round-tripping.
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                        return df
                    except Exception as cache_error:
                        logger.warning(f"解析缓存的拥挤度数据失败,将重新查询: {cache_error}")
            stock_codes = self.get_industry_stocks(industry_name)
            if not stock_codes:
                return pd.DataFrame()
            # Two independent aggregate queries (whole market vs industry);
            # the ratio is computed in Python to keep the SQL simple.
            query_total = text("""
                SELECT
                    `timestamp` AS trade_date,
                    SUM(amount) AS total_market_amount
                FROM
                    gp_day_data
                WHERE
                    `timestamp` BETWEEN :start_date AND :end_date
                GROUP BY
                    `timestamp`
                ORDER BY
                    `timestamp`
            """)
            query_industry = text("""
                SELECT
                    `timestamp` AS trade_date,
                    SUM(amount) AS industry_amount
                FROM
                    gp_day_data
                WHERE
                    symbol IN :stock_codes AND
                    `timestamp` BETWEEN :start_date AND :end_date
                GROUP BY
                    `timestamp`
                ORDER BY
                    `timestamp`
            """)
            with self.engine.connect() as conn:
                df_total = pd.read_sql(
                    query_total,
                    conn,
                    params={"start_date": three_years_ago, "end_date": end_date}
                )
                df_industry = pd.read_sql(
                    query_industry,
                    conn,
                    params={
                        "stock_codes": tuple(stock_codes),
                        "start_date": three_years_ago,
                        "end_date": end_date
                    }
                )
            if df_total.empty or df_industry.empty:
                logger.warning(f"未找到行业 {industry_name} 的交易数据")
                return pd.DataFrame()
            df = pd.merge(df_total, df_industry, on='trade_date', how='inner')
            # Guarantee .dt works downstream regardless of driver dtype.
            df['trade_date'] = pd.to_datetime(df['trade_date'])
            df['industry_amount_ratio'] = (df['industry_amount'] / df['total_market_amount']) * 100
            # Rank-percentile of each day's share within the whole window.
            df['percentile'] = df['industry_amount_ratio'].rank(pct=True) * 100
            df['crowding_level'] = pd.cut(
                df['percentile'],
                bins=[0, 20, 40, 60, 80, 100],
                labels=['不拥挤', '较不拥挤', '中性', '较为拥挤', '极度拥挤']
            )
            if use_cache:
                try:
                    redis_client.set(
                        cache_key,
                        # default=str stringifies dates and Categorical labels.
                        json.dumps(df.to_dict(orient='records'), default=str),
                        ex=86400  # one day
                    )
                    logger.info(f"已缓存行业 {industry_name} 的拥挤度数据有效期为1天")
                except Exception as cache_error:
                    logger.warning(f"缓存行业拥挤度数据失败: {cache_error}")
            logger.info(f"成功计算行业 {industry_name} 的拥挤度指标,共 {len(df)} 条记录")
            return df
        except Exception as e:
            logger.error(f"计算行业拥挤度指标失败: {e}")
            return pd.DataFrame()

    def get_industry_analysis(self, industry_name: str, metric: str = 'pe', start_date: str = None) -> Dict:
        """Build the combined valuation + crowding analysis payload.

        Args:
            industry_name: Industry board name.
            metric: One of 'pe', 'pb' or 'ps'.
            start_date: Start of the valuation window; defaults to 3 years ago.

        Returns:
            On success: {"success": True, industry/metric metadata,
            "valuation" (dates, averages, constant reference lines, stats)
            and, when available, "crowding" (dates, ratios, percentiles,
            current reading)}. On failure: {"success": False, "message": ...}.
        """
        try:
            if start_date is None:
                start_date = (datetime.datetime.now() - datetime.timedelta(days=3*365)).strftime('%Y-%m-%d')
            valuation_data = self.get_industry_valuation_data(industry_name, start_date, metric)
            if valuation_data.empty:
                return {"success": False, "message": f"无法获取行业 {industry_name} 的估值数据"}
            percentiles = self.calculate_industry_percentiles(valuation_data, metric)
            if not percentiles:
                return {"success": False, "message": f"无法计算行业 {industry_name} 的估值分位数"}
            # Crowding always uses its fixed 3-year window, independent of start_date.
            crowding_data = self.get_industry_crowding_index(industry_name)
            avg_values = valuation_data[f'avg_{metric}'].tolist()
            n = len(avg_values)
            result = {
                "success": True,
                "industry_name": industry_name,
                "metric": metric.upper(),
                "analysis_date": datetime.datetime.now().strftime('%Y-%m-%d'),
                "valuation": {
                    "dates": valuation_data['timestamp'].dt.strftime('%Y-%m-%d').tolist(),
                    "avg_values": avg_values,
                    # Constant historical reference lines for the front end.
                    "min_values": [percentiles['min']] * n,
                    "max_values": [percentiles['max']] * n,
                    "q1_values": [percentiles['q1']] * n,
                    "q3_values": [percentiles['q3']] * n,
                    "median_values": [percentiles['median']] * n,
                    "stock_counts": valuation_data['stock_count'].tolist(),
                    "percentiles": percentiles
                }
            }
            if not crowding_data.empty:
                current_crowding = crowding_data.iloc[-1]
                ratio_series = crowding_data['industry_amount_ratio']
                result["crowding"] = {
                    "dates": crowding_data['trade_date'].dt.strftime('%Y-%m-%d').tolist(),
                    "ratios": ratio_series.tolist(),
                    "percentiles": crowding_data['percentile'].tolist(),
                    "current": {
                        "date": current_crowding['trade_date'].strftime('%Y-%m-%d'),
                        "ratio": float(current_crowding['industry_amount_ratio']),
                        "percentile": float(current_crowding['percentile']),
                        # BUGFIX: cast to str — pd.cut yields a Categorical
                        # value that json.dumps cannot serialize directly.
                        "level": str(current_crowding['crowding_level']),
                        # Historical distribution of the turnover share.
                        "ratio_stats": self._series_stats(ratio_series)
                    }
                }
            return result
        except Exception as e:
            logger.error(f"获取行业综合分析失败: {e}")
            return {"success": False, "message": f"获取行业综合分析失败: {e}"}

    def get_stock_concepts(self, stock_code: str) -> List[str]:
        """Return the concept boards that contain the given stock.

        Args:
            stock_code: Stock code such as "600519.SH".

        Returns:
            List of concept board names; [] when none found or on error.
        """
        try:
            # gp_gnbk stores codes in "SH600519" form; convert first.
            formatted_code = self._convert_stock_code_format(stock_code)
            query = text("""
                SELECT DISTINCT bk_name
                FROM gp_gnbk
                WHERE gp_code = :stock_code
            """)
            with self.engine.connect() as conn:
                result = conn.execute(query, {"stock_code": formatted_code}).fetchall()
            if result:
                return [row[0] for row in result]
            logger.warning(f"未找到股票 {stock_code} 的概念板块数据")
            return []
        except Exception as e:
            logger.error(f"获取股票概念板块失败: {e}")
            return []

    def _convert_stock_code_format(self, stock_code: str) -> str:
        """Convert a "600519.SH"-style code to the "SH600519" style.

        Falls back to returning the input unchanged when it does not match
        the expected "code.market" shape.
        """
        try:
            code, market = stock_code.split('.')
            return f"{market}{code}"
        except Exception as e:
            logger.error(f"转换股票代码格式失败: {str(e)}")
            return stock_code