"""
行业估值分析模块

提供行业历史PE、PB、PS分位数分析功能以及行业拥挤度指标,包括:
1. 行业历史PE、PB、PS数据获取
2. 分位数计算
3. 行业交易拥挤度计算
"""

import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text
import datetime
import logging
import json
import redis
import time
from typing import Tuple, Dict, List, Optional, Union

from .config import DB_URL, OUTPUT_DIR, LOG_FILE

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("industry_analysis")

# 添加Redis客户端
redis_client = redis.Redis(
    host='192.168.18.208',  # Redis服务器地址,根据实际情况调整
    port=6379,
    password='wlkj2018',
    db=13,
    socket_timeout=5,
    decode_responses=True
)


class IndustryAnalyzer:
    """行业估值分析器类"""

    def __init__(self, db_url: str = DB_URL):
        """
        初始化行业估值分析器
        
        Args:
            db_url: 数据库连接URL
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600
        )
        logger.info("行业估值分析器初始化完成")
    
    def get_industry_list(self) -> List[Dict]:
        """
        获取所有行业列表
        
        Returns:
            行业列表,每个行业为一个字典,包含code和name
        """
        try:
            query = text("""
                SELECT DISTINCT bk_code, bk_name
                FROM gp_hybk
                ORDER BY bk_name
            """)
            
            with self.engine.connect() as conn:
                result = conn.execute(query).fetchall()
                
            if result:
                return [{"code": str(row[0]), "name": row[1]} for row in result]
            else:
                logger.warning("未找到行业数据")
                return []
        except Exception as e:
            logger.error(f"获取行业列表失败: {e}")
            return []
    
    def get_concept_list(self) -> List[Dict]:
        """
        获取所有概念板块列表
        
        Returns:
            概念板块列表,每个概念板块为一个字典,包含code和name
        """
        try:
            query = text("""
                SELECT DISTINCT bk_code, bk_name
                FROM gp_gnbk
                ORDER BY bk_name
            """)
            
            with self.engine.connect() as conn:
                result = conn.execute(query).fetchall()
                
            if result:
                return [{"code": str(row[0]), "name": row[1]} for row in result]
            else:
                logger.warning("未找到概念板块数据")
                return []
        except Exception as e:
            logger.error(f"获取概念板块列表失败: {e}")
            return []
    
    def get_industry_stocks(self, industry_name: str) -> List[str]:
        """
        获取指定行业的所有股票代码
        
        Args:
            industry_name: 行业名称
            
        Returns:
            股票代码列表
        """
        try:
            query = text("""
                SELECT DISTINCT gp_code 
                FROM gp_hybk 
                WHERE bk_name = :industry_name
            """)
            
            with self.engine.connect() as conn:
                result = conn.execute(query, {"industry_name": industry_name}).fetchall()
                
            if result:
                return [row[0] for row in result]
            else:
                logger.warning(f"未找到行业 {industry_name} 的股票")
                return []
        except Exception as e:
            logger.error(f"获取行业股票失败: {e}")
            return []
    
    def get_industry_valuation_data(self, industry_name: str, start_date: str, metric: str = 'pe') -> pd.DataFrame:
        """
        获取行业估值数据,返回每日行业平均PE/PB/PS
        
        说明:
        - 行业估值数据是指行业内所有股票的平均PE/PB/PS的历史数据
        - 在计算过程中会剔除负值和极端值(如PE>1000)
        
        Args:
            industry_name: 行业名称
            start_date: 开始日期
            metric: 估值指标(pe、pb或ps)
            
        Returns:
            包含行业估值数据的DataFrame,主要包含以下列:
            - timestamp: 日期
            - avg_{metric}: 行业平均值
            - stock_count: 参与计算的股票数量
        """
        try:
            # 验证metric参数
            if metric not in ['pe', 'pb', 'ps']:
                logger.error(f"不支持的估值指标: {metric}")
                return pd.DataFrame()
            
            # 获取行业所有股票
            stock_codes = self.get_industry_stocks(industry_name)
            if not stock_codes:
                return pd.DataFrame()

            # 构建查询 - 只计算每天的行业平均值和参与计算的股票数量
            query = text(f"""
                WITH valid_data AS (
                    SELECT 
                        `timestamp`,
                        symbol,
                        {metric}
                    FROM 
                        gp_day_data 
                    WHERE 
                        symbol IN :stock_codes AND 
                        `timestamp` >= :start_date AND
                        {metric} > 0 AND
                        {metric} < 1000  -- 过滤掉极端异常值
                )
                SELECT 
                    `timestamp`,
                    AVG({metric}) as avg_{metric},
                    COUNT(*) as stock_count
                FROM 
                    valid_data
                GROUP BY 
                    `timestamp`
                ORDER BY 
                    `timestamp`
            """)
            
            with self.engine.connect() as conn:
                # 获取汇总数据
                df = pd.read_sql(
                    query, 
                    conn, 
                    params={"stock_codes": tuple(stock_codes), "start_date": start_date}
                )
            
            if df.empty:
                logger.warning(f"未找到行业 {industry_name} 的估值数据")
                return pd.DataFrame()
                
            logger.info(f"成功获取行业 {industry_name} 的{metric.upper()}数据,共 {len(df)} 条记录")
            return df
            
        except Exception as e:
            logger.error(f"获取行业估值数据失败: {e}")
            return pd.DataFrame()
    
    def calculate_industry_percentiles(self, data: pd.DataFrame, metric: str = 'pe') -> Dict:
        """
        计算行业估值指标的当前分位数及历史统计值
        
        计算的是行业平均PE/PB/PS在其历史分布中的百分位,以及历史最大值、最小值、四分位数等
        
        Args:
            data: 历史数据DataFrame
            metric: 估值指标,pe、pb或ps
            
        Returns:
            包含分位数信息的字典
        """
        if data.empty:
            logger.warning(f"数据为空,无法计算行业{metric}分位数")
            return {}
        
        # 验证metric参数
        if metric not in ['pe', 'pb', 'ps']:
            logger.error(f"不支持的估值指标: {metric}")
            return {}
            
        # 列名
        avg_col = f'avg_{metric}'
            
        # 获取最新值
        latest_data = data.iloc[-1]
        
        # 计算当前均值在历史分布中的百分位
        # 使用 <= 是为了计算当前值在历史数据中的累积分布函数值
        # 这样可以得到,有多少比例的历史数据小于等于当前值,即当前值的百分位
        percentile = (data[avg_col] <= latest_data[avg_col]).mean() * 100
        
        # 计算行业平均PE的历史最小值、最大值、四分位数等
        min_value = float(data[avg_col].min())
        max_value = float(data[avg_col].max())
        mean_value = float(data[avg_col].mean())
        median_value = float(data[avg_col].median())
        q1_value = float(data[avg_col].quantile(0.25))
        q3_value = float(data[avg_col].quantile(0.75))
            
        # 计算各种分位数
        result = {
            'date': latest_data['timestamp'].strftime('%Y-%m-%d'),
            'current': float(latest_data[avg_col]),
            'min': min_value,
            'max': max_value,
            'mean': mean_value,
            'median': median_value,
            'q1': q1_value,
            'q3': q3_value,
            'percentile': float(percentile),
            'stock_count': int(latest_data['stock_count'])
        }
        
        logger.info(f"计算行业 {metric} 分位数完成: 当前{metric}={result['current']:.2f}, 百分位={result['percentile']:.2f}%")
        return result
    
    def get_industry_crowding_index(self, industry_name: str, start_date: str = None, end_date: str = None, use_cache: bool = True) -> pd.DataFrame:
        """
        计算行业交易拥挤度指标,并使用Redis缓存结果
        
        对于拥挤度指标,固定使用3年数据,不受start_date影响
        缓存时间为1天
        
        Args:
            industry_name: 行业名称
            start_date: 不再使用此参数,保留是为了兼容性
            end_date: 结束日期(默认为当前日期)
            use_cache: 是否使用缓存,默认为True
            
        Returns:
            包含行业拥挤度指标的DataFrame
        """
        try:
            # 始终使用3年前作为开始日期
            three_years_ago = (datetime.datetime.now() - datetime.timedelta(days=3*365)).strftime('%Y-%m-%d')
            
            if end_date is None:
                end_date = datetime.datetime.now().strftime('%Y-%m-%d')
            
            # 检查缓存
            if use_cache:
                cache_key = f"industry_crowding:{industry_name}"
                cached_data = redis_client.get(cache_key)
                
                if cached_data:
                    try:
                        # 尝试解析缓存的JSON数据
                        cached_df_dict = json.loads(cached_data)
                        logger.info(f"从缓存获取行业 {industry_name} 的拥挤度数据")
                        
                        # 将缓存的字典转换回DataFrame
                        df = pd.DataFrame(cached_df_dict)
                        
                        # 确保trade_date列是日期类型
                        df['trade_date'] = pd.to_datetime(df['trade_date'])
                        
                        return df
                    except Exception as cache_error:
                        logger.warning(f"解析缓存的拥挤度数据失败,将重新查询: {cache_error}")
            
            # 获取行业所有股票
            stock_codes = self.get_industry_stocks(industry_name)
            if not stock_codes:
                return pd.DataFrame()
                
            # 优化方案:分别查询市场总成交额和行业成交额,然后在Python中计算比率
            
            # 查询1:获取每日总成交额
            query_total = text("""
                SELECT 
                    `timestamp` AS trade_date,
                    SUM(amount) AS total_market_amount
                FROM 
                    gp_day_data
                WHERE 
                    `timestamp` BETWEEN :start_date AND :end_date
                GROUP BY 
                    `timestamp`
                ORDER BY 
                    `timestamp`
            """)
            
            # 查询2:获取行业每日成交额
            query_industry = text("""
                SELECT 
                    `timestamp` AS trade_date,
                    SUM(amount) AS industry_amount
                FROM 
                    gp_day_data
                WHERE 
                    symbol IN :stock_codes AND
                    `timestamp` BETWEEN :start_date AND :end_date
                GROUP BY 
                    `timestamp`
                ORDER BY 
                    `timestamp`
            """)
            
            with self.engine.connect() as conn:
                # 执行两个独立的查询
                df_total = pd.read_sql(
                    query_total, 
                    conn, 
                    params={"start_date": three_years_ago, "end_date": end_date}
                )
                
                df_industry = pd.read_sql(
                    query_industry, 
                    conn, 
                    params={
                        "stock_codes": tuple(stock_codes), 
                        "start_date": three_years_ago, 
                        "end_date": end_date
                    }
                )
            
            # 检查查询结果
            if df_total.empty or df_industry.empty:
                logger.warning(f"未找到行业 {industry_name} 的交易数据")
                return pd.DataFrame()
                
            # 在Python中合并数据并计算比率
            df = pd.merge(df_total, df_industry, on='trade_date', how='inner')
            
            # 计算行业成交额占比
            df['industry_amount_ratio'] = (df['industry_amount'] / df['total_market_amount']) * 100
            
            # 在Python中计算百分位
            df['percentile'] = df['industry_amount_ratio'].rank(pct=True) * 100
                
            # 添加拥挤度评级
            df['crowding_level'] = pd.cut(
                df['percentile'],
                bins=[0, 20, 40, 60, 80, 100],
                labels=['不拥挤', '较不拥挤', '中性', '较为拥挤', '极度拥挤']
            )
            
            # 将DataFrame转换为字典,以便缓存
            df_dict = df.to_dict(orient='records')
            
            # 缓存结果,有效期1天(86400秒)
            if use_cache:
                try:
                    redis_client.set(
                        cache_key,
                        json.dumps(df_dict, default=str),  # 使用default=str处理日期等特殊类型
                        ex=86400  # 1天的秒数
                    )
                    logger.info(f"已缓存行业 {industry_name} 的拥挤度数据,有效期为1天")
                except Exception as cache_error:
                    logger.warning(f"缓存行业拥挤度数据失败: {cache_error}")
                
            logger.info(f"成功计算行业 {industry_name} 的拥挤度指标,共 {len(df)} 条记录")
            return df
            
        except Exception as e:
            logger.error(f"计算行业拥挤度指标失败: {e}")
            return pd.DataFrame()
    
    def get_industry_analysis(self, industry_name: str, metric: str = 'pe', start_date: str = None) -> Dict:
        """
        获取行业综合分析结果
        
        Args:
            industry_name: 行业名称
            metric: 估值指标(pe、pb或ps)
            start_date: 开始日期(默认为3年前)
            
        Returns:
            行业分析结果字典,包含以下内容:
            - success: 是否成功
            - industry_name: 行业名称
            - metric: 估值指标
            - analysis_date: 分析日期
            - valuation: 估值数据,包含:
                - dates: 日期列表
                - avg_values: 行业平均值列表
                - stock_counts: 参与计算的股票数量列表
                - percentiles: 分位数信息,包含行业平均值的历史最大值、最小值、四分位数等
            - crowding(如有): 拥挤度数据,包含:
                - dates: 日期列表
                - ratios: 拥挤度比例列表
                - percentiles: 拥挤度百分位列表
                - current: 当前拥挤度信息
        """
        try:
            # 默认查询近3年数据
            if start_date is None:
                start_date = (datetime.datetime.now() - datetime.timedelta(days=3*365)).strftime('%Y-%m-%d')
                
            # 获取估值数据
            valuation_data = self.get_industry_valuation_data(industry_name, start_date, metric)
            if valuation_data.empty:
                return {"success": False, "message": f"无法获取行业 {industry_name} 的估值数据"}
                
            # 计算估值分位数
            percentiles = self.calculate_industry_percentiles(valuation_data, metric)
            if not percentiles:
                return {"success": False, "message": f"无法计算行业 {industry_name} 的估值分位数"}
                
            # 获取拥挤度指标(始终使用3年数据,不受start_date影响)
            crowding_data = self.get_industry_crowding_index(industry_name)
            
            # 为了兼容前端,准备一些行业平均值的历史统计数据
            avg_values = valuation_data[f'avg_{metric}'].tolist()
            
            # 准备返回结果
            result = {
                "success": True,
                "industry_name": industry_name,
                "metric": metric.upper(),
                "analysis_date": datetime.datetime.now().strftime('%Y-%m-%d'),
                "valuation": {
                    "dates": valuation_data['timestamp'].dt.strftime('%Y-%m-%d').tolist(),
                    "avg_values": avg_values,
                    # 填充行业平均值的历史统计线
                    "min_values": [percentiles['min']] * len(avg_values),  # 行业平均PE历史最小值
                    "max_values": [percentiles['max']] * len(avg_values),  # 行业平均PE历史最大值
                    "q1_values": [percentiles['q1']] * len(avg_values),    # 行业平均PE历史第一四分位数
                    "q3_values": [percentiles['q3']] * len(avg_values),    # 行业平均PE历史第三四分位数
                    "median_values": [percentiles['median']] * len(avg_values),  # 行业平均PE历史中位数
                    "stock_counts": valuation_data['stock_count'].tolist(),
                    "percentiles": percentiles
                }
            }
            
            # 添加拥挤度数据(如果有)
            if not crowding_data.empty:
                current_crowding = crowding_data.iloc[-1]
                result["crowding"] = {
                    "dates": crowding_data['trade_date'].dt.strftime('%Y-%m-%d').tolist(),
                    "ratios": crowding_data['industry_amount_ratio'].tolist(),
                    "percentiles": crowding_data['percentile'].tolist(),
                    "current": {
                        "date": current_crowding['trade_date'].strftime('%Y-%m-%d'),
                        "ratio": float(current_crowding['industry_amount_ratio']),
                        "percentile": float(current_crowding['percentile']),
                        "level": current_crowding['crowding_level'],
                        # 添加行业成交额比例的历史分位信息
                        "ratio_stats": {
                            "min": float(crowding_data['industry_amount_ratio'].min()),
                            "max": float(crowding_data['industry_amount_ratio'].max()),
                            "mean": float(crowding_data['industry_amount_ratio'].mean()),
                            "median": float(crowding_data['industry_amount_ratio'].median()),
                            "q1": float(crowding_data['industry_amount_ratio'].quantile(0.25)),
                            "q3": float(crowding_data['industry_amount_ratio'].quantile(0.75)),
                        }
                    }
                }
            
            return result
            
        except Exception as e:
            logger.error(f"获取行业综合分析失败: {e}")
            return {"success": False, "message": f"获取行业综合分析失败: {e}"}
    
    def get_stock_concepts(self, stock_code: str) -> List[str]:
        """
        获取指定股票所属的概念板块列表
        
        Args:
            stock_code: 股票代码
            
        Returns:
            概念板块名称列表
        """
        try:
            # 转换股票代码格式
            formatted_code = self._convert_stock_code_format(stock_code)
            
            query = text("""
                SELECT DISTINCT bk_name 
                FROM gp_gnbk 
                WHERE gp_code = :stock_code
            """)
            
            with self.engine.connect() as conn:
                result = conn.execute(query, {"stock_code": formatted_code}).fetchall()
                
            if result:
                return [row[0] for row in result]
            else:
                logger.warning(f"未找到股票 {stock_code} 的概念板块数据")
                return []
                
        except Exception as e:
            logger.error(f"获取股票概念板块失败: {e}")
            return []
    
    def _convert_stock_code_format(self, stock_code: str) -> str:
        """
        转换股票代码格式
        
        Args:
            stock_code: 原始股票代码,格式如 "600519.SH"
            
        Returns:
            转换后的股票代码,格式如 "SH600519"
        """
        try:
            code, market = stock_code.split('.')
            return f"{market}{code}"
        except Exception as e:
            logger.error(f"转换股票代码格式失败: {str(e)}")
            return stock_code