From 8f73099e18617c8a3fb6401c6784fbb495e92ae2 Mon Sep 17 00:00:00 2001 From: liao Date: Wed, 21 May 2025 14:01:54 +0800 Subject: [PATCH] commit; --- requirements.txt | 3 +- src/app.py | 244 ++++- src/valuation_analysis/config.py | 10 + .../eastmoney_rzrq_collector.py | 9 +- src/valuation_analysis/financial_analysis.py | 869 ++++++++++++++++++ src/valuation_analysis/industry_analysis.py | 111 ++- .../stock_price_collector.py | 463 ++++++++++ 7 files changed, 1632 insertions(+), 77 deletions(-) create mode 100644 src/valuation_analysis/financial_analysis.py create mode 100644 src/valuation_analysis/stock_price_collector.py diff --git a/requirements.txt b/requirements.txt index 00f6b65..5216a58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ markdown2>=2.5.3 google-genai redis==5.2.1 pandas==2.2.3 -apscheduler==3.11.0 \ No newline at end of file +apscheduler==3.11.0 +pymongo==4.13.0 \ No newline at end of file diff --git a/src/app.py b/src/app.py index 604a912..328f446 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,6 @@ import sys import os -from datetime import datetime, timedelta +from datetime import datetime, timedelta, time import pandas as pd import uuid import json @@ -45,6 +45,9 @@ from src.scripts.stock_daily_data_collector import collect_stock_daily_data from utils.distributed_lock import DistributedLock from valuation_analysis.industry_analysis import redis_client +from valuation_analysis.financial_analysis import FinancialAnalyzer +from src.valuation_analysis.stock_price_collector import StockPriceCollector + # 设置日志 logging.basicConfig( level=logging.INFO, @@ -183,6 +186,77 @@ def run_backtest_task(task_id, stocks_buy_dates, end_date): backtest_tasks[task_id]['error'] = str(e) logger.error(f"回测任务 {task_id} 失败:{str(e)}") +def initialize_stock_price_schedule(): + """ + 初始化实时股价数据采集定时任务 + """ + # 创建分布式锁 + price_lock = DistributedLock(redis_client, "stock_price_collector", expire_time=3600) # 1小时过期 + + # 尝试获取锁 + if not price_lock.acquire(): + logger.info("其他服务器正在运行实时股价数据采集任务,本服务器跳过") + return None + + try: + from apscheduler.schedulers.background import BackgroundScheduler + from apscheduler.triggers.cron import CronTrigger + + # 创建定时任务调度器 + scheduler = BackgroundScheduler() + + def is_trading_time(): + """判断当前是否为交易时间""" + now = datetime.now() + current_time = now.time() + + # 定义交易时间段 + morning_start = time(9, 25) # 上午开盘前5分钟 + morning_end = time(11, 30) # 上午收盘 + afternoon_start = time(13, 0) # 下午开盘 + afternoon_end = time(15, 0) # 下午收盘 + + # 判断是否为工作日 + if now.weekday() >= 5: # 5是周六,6是周日 + return False + + # 判断是否在交易时间段内 + is_morning = morning_start <= current_time <= morning_end + is_afternoon = afternoon_start <= current_time <= afternoon_end + + return is_morning or is_afternoon + + def update_stock_price(): + """更新实时股价数据""" + if not is_trading_time(): + return + + try: + collector = StockPriceCollector() + collector.update_latest_data() + except Exception as e: + logger.error(f"更新实时股价数据失败: {e}") + + # 添加定时任务 + scheduler.add_job( + func=update_stock_price, + trigger='interval', + minutes=5, + id='stock_price_update', + name='实时股价数据采集', + replace_existing=True + ) + + # 启动调度器 + scheduler.start() + logger.info("实时股价数据采集定时任务已初始化,将在交易时间内每5分钟执行一次") + return scheduler + + except Exception as e: + logger.error(f"初始化实时股价数据采集定时任务失败: {str(e)}") + price_lock.release() + return None + def initialize_rzrq_collector_schedule(): """初始化融资融券数据采集定时任务""" # 创建分布式锁 @@ -203,7 +277,7 @@ def initialize_rzrq_collector_schedule(): # 添加每天下午5点执行的任务 scheduler.add_job( func=run_rzrq_initial_collection, - trigger=CronTrigger(hour=18, minute=0), + trigger=CronTrigger(hour=18, minute=40), id='rzrq_daily_update', name='每日更新融资融券数据', replace_existing=True @@ -284,9 +358,9 @@ def run_stock_daily_collection(): return False def run_rzrq_initial_collection(): - """执行融资融券数据初始全量采集""" + """执行融资融券数据更新采集""" try: - logger.info("开始执行融资融券数据初始全量采集") + logger.info("开始执行融资融券数据更新采集") # 生成任务ID task_id = f"rzrq-{uuid.uuid4().hex[:16]}" @@ -296,7 +370,7 @@ def run_rzrq_initial_collection(): 'status': 'running', 'created_at': datetime.now().isoformat(), 'type': 'initial_collection', - 'message': '开始执行融资融券数据初始全量采集' + 'message': '开始执行融资融券数据更新采集' } # 在新线程中执行采集任务 @@ -307,16 +381,16 @@ def run_rzrq_initial_collection(): if result: rzrq_tasks[task_id]['status'] = 'completed' - rzrq_tasks[task_id]['message'] = '融资融券数据初始全量采集完成' - logger.info(f"融资融券数据初始全量采集任务 {task_id} 完成") + rzrq_tasks[task_id]['message'] = '融资融券数据更新完成' + logger.info(f"融资融券数据更新任务 {task_id} 完成") else: rzrq_tasks[task_id]['status'] = 'failed' - rzrq_tasks[task_id]['message'] = '融资融券数据初始全量采集失败' - logger.error(f"融资融券数据初始全量采集任务 {task_id} 失败") + rzrq_tasks[task_id]['message'] = '融资融券数据更新失败' + logger.error(f"融资融券数据更新任务 {task_id} 失败") except Exception as e: rzrq_tasks[task_id]['status'] = 'failed' rzrq_tasks[task_id]['message'] = f'执行失败: {str(e)}' - logger.error(f"执行融资融券数据初始全量采集线程中出错: {str(e)}") + logger.error(f"执行融资融券数据更新线程中出错: {str(e)}") # 创建并启动线程 thread = Thread(target=collection_task) @@ -325,7 +399,7 @@ def run_rzrq_initial_collection(): return task_id except Exception as e: - logger.error(f"启动融资融券数据初始全量采集任务失败: {str(e)}") + logger.error(f"启动融资融券数据更新任务失败: {str(e)}") if 'task_id' in locals(): rzrq_tasks[task_id]['status'] = 'failed' rzrq_tasks[task_id]['message'] = f'启动失败: {str(e)}' @@ -2567,7 +2641,7 @@ def initialize_industry_crowding_schedule(): # 添加每天晚上10点执行的任务 scheduler.add_job( func=precalculate_industry_crowding, - trigger=CronTrigger(hour=22, minute=0), + trigger=CronTrigger(hour=20, minute=30), id='industry_crowding_precalc', name='预计算行业拥挤度指标', replace_existing=True @@ -2575,7 +2649,7 @@ def initialize_industry_crowding_schedule(): # 启动调度器 scheduler.start() - logger.info("行业拥挤度指标预计算定时任务已初始化,将在每天22:00执行") + logger.info("行业拥挤度指标预计算定时任务已初始化,将在每天20:30执行") return scheduler except Exception as e: logger.error(f"初始化行业拥挤度指标预计算定时任务失败: {str(e)}") @@ -2585,43 +2659,126 @@ def initialize_industry_crowding_schedule(): def precalculate_industry_crowding(): """预计算所有行业的拥挤度指标""" try: - logger.info("开始预计算所有行业的拥挤度指标") + from .valuation_analysis.industry_analysis import IndustryAnalyzer - # 获取所有行业列表 - industries = industry_analyzer.get_industry_list() - if not industries: - logger.error("获取行业列表失败") - return - - # 记录成功和失败的数量 - success_count = 0 - fail_count = 0 + analyzer = IndustryAnalyzer() + industries = analyzer.get_all_industries() - # 遍历所有行业 for industry in industries: try: - industry_name = industry['name'] - logger.info(f"正在计算行业 {industry_name} 的拥挤度指标") - - # 调用拥挤度计算方法 - df = industry_analyzer.get_industry_crowding_index(industry_name) - + # 调用时设置 use_cache=False,强制重新计算 + df = analyzer.get_industry_crowding_index(industry, use_cache=False) if not df.empty: - success_count += 1 - logger.info(f"成功计算行业 {industry_name} 的拥挤度指标") + logger.info(f"成功预计算行业 {industry} 的拥挤度指标") else: - fail_count += 1 - logger.warning(f"计算行业 {industry_name} 的拥挤度指标失败") - + logger.warning(f"行业 {industry} 的拥挤度指标计算失败") except Exception as e: - fail_count += 1 - logger.error(f"计算行业 {industry_name} 的拥挤度指标时出错: {str(e)}") + logger.error(f"预计算行业 {industry} 的拥挤度指标时出错: {str(e)}") continue - - logger.info(f"行业拥挤度指标预计算完成,成功: {success_count},失败: {fail_count}") - + + logger.info("所有行业的拥挤度指标预计算完成") except Exception as e: logger.error(f"预计算行业拥挤度指标失败: {str(e)}") + finally: + # 释放分布式锁 + industry_crowding_lock = DistributedLock(redis_client, "industry_crowding_calculator") + industry_crowding_lock.release() + +@app.route('/api/financial/analysis', methods=['GET']) +def financial_analysis(): + """ + 财务分析接口 + + 请求参数: + stock_code: 股票代码 + + 返回: + 分析结果JSON + """ + try: + stock_code = request.args.get('stock_code') + if not stock_code: + return jsonify({ + 'success': False, + 'message': '缺少必要参数:stock_code' + }), 400 + + analyzer = FinancialAnalyzer() + result = analyzer.analyze_financial_data(stock_code) + + return jsonify(result) + + except Exception as e: + logger.error(f"财务分析失败: {str(e)}") + return jsonify({ + 'success': False, + 'message': f'财务分析失败: {str(e)}' + }), 500 + +@app.route('/api/financial/indicators', methods=['GET']) +def get_financial_indicators(): + """ + 获取财务指标接口 + + 请求参数: + stock_code: 股票代码 + + 返回: + JSON格式的财务指标数据 + """ + try: + # 获取股票代码 + stock_code = request.args.get('stock_code') + if not stock_code: + return jsonify({ + 'success': False, + 'message': '缺少必要参数: stock_code' + }), 400 + + # 创建分析器实例 + analyzer = FinancialAnalyzer() + + # 获取财务指标 + result = analyzer.extract_financial_indicators(stock_code) + + return jsonify(result) + + except Exception as e: + logger.error(f"获取财务指标失败: {str(e)}") + return jsonify({ + 'success': False, + 'message': f'获取财务指标失败: {str(e)}' + }), 500 + +@app.route('/api/financial/test_structure', methods=['GET']) +def test_mongo_structure(): + """ + 测试MongoDB集合结构接口 + + 请求参数: + stock_code: 股票代码(可选) + + 返回: + JSON格式的集合结构信息 + """ + try: + # 获取股票代码(可选) + stock_code = request.args.get('stock_code') + + # 创建分析器实例 + analyzer = FinancialAnalyzer() + + # 获取集合结构 + result = analyzer.test_mongo_structure(stock_code) + + return jsonify(result) + + except Exception as e: + logger.error(f"测试MongoDB结构失败: {str(e)}") + return jsonify({ + 'success': False, + 'message': f'测试MongoDB结构失败: {str(e)}' + }), 500 if __name__ == '__main__': """ @@ -2644,9 +2801,9 @@ if __name__ == '__main__': else: print("股票日线采集器锁释放失败或不存在") if industry_crowding_lock.release(): - print("成功释放股票日线采集器锁") + print("成功释放行业拥挤度锁") else: - print("股票日线采集器锁释放失败或不存在") + print("行业拥挤度锁释放失败或不存在") print("锁释放操作完成") """ @@ -2660,5 +2817,8 @@ if __name__ == '__main__': # 初始化行业拥挤度指标预计算定时任务 industry_crowding_scheduler = initialize_industry_crowding_schedule() + # 初始化实时股价数据采集定时任务 + initialize_stock_price_schedule() + # 启动Web服务器 app.run(host='0.0.0.0', port=5000, debug=True) diff --git a/src/valuation_analysis/config.py b/src/valuation_analysis/config.py index a3954d1..8868189 100644 --- a/src/valuation_analysis/config.py +++ b/src/valuation_analysis/config.py @@ -17,6 +17,16 @@ DB_CONFIG = { # 创建数据库连接URL DB_URL = f"mysql+pymysql://{DB_CONFIG['user']}:{DB_CONFIG['password']}@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}" +# MongoDB配置 +MONGO_CONFIG = { + 'host': '192.168.18.75', + 'port': 27017, + 'db': 'judge', + 'username': 'root', + 'password': 'wlkj2018', + 'collection': 'wind_financial_analysis' +} + # 项目根目录 ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) diff --git a/src/valuation_analysis/eastmoney_rzrq_collector.py b/src/valuation_analysis/eastmoney_rzrq_collector.py index 967bffe..e180db6 100644 --- a/src/valuation_analysis/eastmoney_rzrq_collector.py +++ b/src/valuation_analysis/eastmoney_rzrq_collector.py @@ -284,7 +284,6 @@ class EastmoneyRzrqCollector: # 确保数据表存在 if not self._ensure_table_exists(): return False - # 将nan值转换为None(在SQL中会变成NULL) data = data.replace({pd.NA: None, pd.NaT: None}) data = data.where(pd.notnull(data), None) @@ -422,18 +421,18 @@ class EastmoneyRzrqCollector: def initial_data_collection(self) -> bool: """ - 首次全量采集融资融券数据 + 每日更新任务采集融资融券数据 Returns: 是否成功采集所有数据 """ try: logger.info("开始获取最新融资融券数据...") - df = collector.fetch_data(page=1) + df = self.fetch_data(page=1) if not df.empty: # 保存数据到数据库 - if collector.save_to_database(df): + if self.save_to_database(df): logger.info(f"成功更新最新数据,日期:{df.iloc[0]['trade_date']}") else: logger.error("更新最新数据失败") @@ -443,7 +442,7 @@ class EastmoneyRzrqCollector: return True except Exception as e: - logger.error(f"首次全量采集失败: {e}") + logger.error(f"每日更新任务采集失败: {e}") return False def get_chart_data(self, limit_days: int = 30) -> dict: diff --git a/src/valuation_analysis/financial_analysis.py b/src/valuation_analysis/financial_analysis.py new file mode 100644 index 0000000..2fec7f0 --- /dev/null +++ b/src/valuation_analysis/financial_analysis.py @@ -0,0 +1,869 @@ +""" +财务分析模块 + +提供从MySQL和MongoDB获取财务数据并进行分析的功能 +""" + +import pandas as pd +import numpy as np +from sqlalchemy import create_engine +from pymongo import MongoClient +import logging +from typing import Dict, List, Optional, Union, Tuple +import json + +from .config import DB_URL, MONGO_CONFIG, LOG_FILE +from .stock_price_collector import StockPriceCollector +from .industry_analysis import IndustryAnalyzer + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(LOG_FILE), + logging.StreamHandler() + ] +) +logger = logging.getLogger("financial_analysis") + +class FinancialAnalyzer: + """财务分析器类""" + + def __init__(self): + """初始化财务分析器""" + # 初始化MySQL连接 + self.mysql_engine = create_engine( + DB_URL, + pool_size=5, + max_overflow=10, + pool_recycle=3600 + ) + + # 初始化MongoDB连接 + self.mongo_client = MongoClient( + host=MONGO_CONFIG['host'], + port=MONGO_CONFIG['port'], + username=MONGO_CONFIG['username'], + password=MONGO_CONFIG['password'] + ) + self.mongo_db = self.mongo_client[MONGO_CONFIG['db']] + self.mongo_collection = self.mongo_db[MONGO_CONFIG['collection']] + self.wacc_collection = self.mongo_db['wind_stock_wacc_roic'] + + logger.info("财务分析器初始化完成") + + def _get_growth_indicators(self, wind_data: Dict) -> Dict[str, float]: + """ + 从wind_data中提取增长指标 + + Args: + wind_data: wind_data字典 + + Returns: + 包含增长指标的字典 + """ + indicators = { + 'rd_expense_growth': self._find_indicator(wind_data, 'grows', '研发费用同比增长'), + 'roe_growth': self._find_indicator(wind_data, 'grows', '净资产收益率(摊薄)(同比增长)'), + 'diluted_eps_growth': self._find_indicator(wind_data, 'grows', '稀释每股收益(同比增长率)'), + 'operating_cash_flow_per_share_growth': self._find_indicator(wind_data, 'grows', '每股经营活动产生的现金流量净额(同比增长率)'), + 'revenue_growth': self._find_indicator(wind_data, 'grows', '营业收入(同比增长率)'), + 'operating_profit_growth': self._find_indicator(wind_data, 'grows', '营业利润(同比增长率)'), + 'net_profit_growth_excl_nonrecurring': self._find_indicator(wind_data, 'grows', '归属母公司股东的净利润-扣除非经常性损益(同比增长率)'), + 'operating_cash_flow_growth': self._find_indicator(wind_data, 'grows', '经营活动产生的现金流量净额(同比增长率)') + } + return indicators + + def _calculate_growth_change(self, current_value: float, previous_value: float) -> Optional[float]: + """ + 计算增长率的增长变化 + + Args: + current_value: 当前值 + previous_value: 上一期值 + + Returns: + 增长变化率,如果无法计算则返回None + """ + try: + if current_value is None or previous_value is None: + return None + if previous_value == 0: + return None + return (current_value - previous_value) / abs(previous_value) + except Exception as e: + logger.error(f"计算增长变化率失败: {str(e)}") + return None + + def get_growth_change_indicators(self, stock_code: str) -> Dict: + """ + 获取增长指标的变化率 + + Args: + stock_code: 股票代码 + + Returns: + 包含增长指标变化率的字典 + """ + try: + # 查询MongoDB + record = self.mongo_collection.find_one({'code': stock_code}) + if not record or 'wind_data' not in record: + return { + 'success': False, + 'message': f'未找到股票 {stock_code} 的财务数据' + } + + # 获取2023-12-31和2024-12-31的数据 + wind_data_2023 = None + wind_data_2024 = None + + for data in record['wind_data']: + if data['time'] == '2023-12-31': + wind_data_2023 = data + elif data['time'] == '2024-12-31': + wind_data_2024 = data + + if not wind_data_2023 or not wind_data_2024: + return { + 'success': False, + 'message': f'未找到股票 {stock_code} 的2023或2024年财务数据' + } + + # 获取两个时间点的指标 + indicators_2023 = self._get_growth_indicators(wind_data_2023) + indicators_2024 = self._get_growth_indicators(wind_data_2024) + + # 计算变化率 + growth_changes = {} + for key in indicators_2023.keys(): + current_value = indicators_2024.get(key) + previous_value = indicators_2023.get(key) + growth_changes[key] = self._calculate_growth_change(current_value, previous_value) + + return { + 'success': True, + 'stock_code': stock_code, + 'data': { + 'indicators_2023': indicators_2023, + 'indicators_2024': indicators_2024, + 'growth_changes': growth_changes + } + } + + except Exception as e: + logger.error(f"获取增长指标变化率失败: {str(e)}") + return { + 'success': False, + 'message': f'获取增长指标变化率失败: {str(e)}' + } + + def get_wacc_data(self, stock_code: str) -> Optional[float]: + """ + 获取股票的WACC数据 + + Args: + stock_code: 股票代码 + + Returns: + WACC值,如果未找到则返回None + """ + try: + # 查询MongoDB + record = self.wacc_collection.find_one({ + 'code': stock_code, + 'endTime': '20241231' + }) + + if not record: + logger.warning(f"未找到股票 {stock_code} 的WACC数据") + return None + + return record['wacc'] + + except Exception as e: + logger.error(f"获取WACC数据失败: {str(e)}") + return None + + def _calculate_profit_years(self, wind_data_list: List[Dict]) -> int: + """ + 计算近五年的盈利年数 + + Args: + wind_data_list: 按时间排序的wind_data列表 + + Returns: + 盈利年数 + """ + try: + profit_years = 0 + # 获取最近5年的数据 + recent_years = sorted(wind_data_list, key=lambda x: x['time'], reverse=True)[:5] + + for year_data in recent_years: + # 获取销售净利率 + net_profit_ratio = self._find_indicator(year_data, 'profitability', '净利润/营业总收入') + if net_profit_ratio is not None and net_profit_ratio > 0: + profit_years += 1 + + return profit_years + + except Exception as e: + logger.error(f"计算盈利年数失败: {str(e)}") + return 0 + + def extract_financial_indicators(self, stock_code: str) -> Dict: + """ + 从MongoDB中提取指定的财务指标 + + Args: + stock_code: 股票代码 + + Returns: + 包含财务指标的字典 + """ + try: + # 查询MongoDB + record = self.mongo_collection.find_one({'code': stock_code}) + if not record or 'wind_data' not in record: + return { + 'success': False, + 'message': f'未找到股票 {stock_code} 的财务数据' + } + + # 获取最新的财务数据(按时间排序) + wind_data = sorted(record['wind_data'], key=lambda x: x['time'], reverse=True)[0] + + # 获取WACC数据 + wacc = self.get_wacc_data(stock_code) + + # 计算近五年盈利年数 + profit_years = self._calculate_profit_years(record['wind_data']) + + + # 定义指标映射 + indicators = { + # 偿债能力指标 + 'debt_equity_ratio': self._find_indicator(wind_data, 'solvency', '产权比率'), + 'debt_ebitda_ratio': self._find_indicator(wind_data, 'solvency', '全部债务/EBITDA'), + 'interest_coverage_ratio': self._find_indicator(wind_data, 'solvency', '已获利息倍数(EBIT/利息费用)'), + 'current_ratio': self._find_indicator(wind_data, 'solvency', '流动比率'), + 'quick_ratio': self._find_indicator(wind_data, 'solvency', '速动比率'), + 'cash_ratio': self._find_indicator(wind_data, 'solvency', '现金比率'), + 'cash_to_debt_ratio': self._find_indicator(wind_data, 'solvency', '经营活动产生的现金流量净额/负债合计'), + + # 资本结构指标 + 'equity_ratio': self._find_indicator(wind_data, 'capitalStructure', '股东权益比'), + + # 盈利能力指标 + 'profit_years': profit_years, # 添加近五年盈利年数 + + # 成长能力指标 + 'diluted_eps_growth': self._find_indicator(wind_data, 'grows', '稀释每股收益(同比增长率)'), + 'operating_cash_flow_per_share_growth': self._find_indicator(wind_data, 'grows', '每股经营活动产生的现金流量净额(同比增长率)'), + 'revenue_growth': self._find_indicator(wind_data, 'grows', '营业收入(同比增长率)'), + 'operating_profit_growth': self._find_indicator(wind_data, 'grows', '营业利润(同比增长率)'), + 'net_profit_growth_excl_nonrecurring': self._find_indicator(wind_data, 'grows', '归属母公司股东的净利润-扣除非经常性损益(同比增长率)'), + 'operating_cash_flow_growth': self._find_indicator(wind_data, 'grows', '经营活动产生的现金流量净额(同比增长率)'), + 'rd_expense_growth': self._find_indicator(wind_data, 'grows', '研发费用同比增长'), + 'roe_growth': self._find_indicator(wind_data, 'grows', '净资产收益率(摊薄)(同比增长)'), + + # Z值相关指标 + 'working_capital_to_assets': self._find_indicator(wind_data, 'ZValue', '营运资本/总资产'), + 'retained_earnings_to_assets': self._find_indicator(wind_data, 'ZValue', '留存收益/总资产'), + 'ebit_to_assets': self._find_indicator(wind_data, 'ZValue', '息税前利润(TTM)/总资产'), + 'market_value_to_liabilities': self._find_indicator(wind_data, 'ZValue', '当日总市值/负债总计'), + 'equity_to_liabilities': self._find_indicator(wind_data, 'ZValue', '股东权益合计(含少数)/负债总计'), + 'revenue_to_assets': self._find_indicator(wind_data, 'ZValue', '营业收入/总资产'), + 'z_score': self._find_indicator(wind_data, 'ZValue', 'Z值'), + + # 营运能力指标 + 'inventory_turnover_days': self._find_indicator(wind_data, 'operatingCapacity', '存货周转天数'), + 'receivables_turnover_days': self._find_indicator(wind_data, 'operatingCapacity', '应收账款周转天数'), + 'payables_turnover_days': self._find_indicator(wind_data, 'operatingCapacity', '应付账款周转天数'), + + # 盈利能力指标 + 'gross_profit_margin': self._find_indicator(wind_data, 'profitability', '销售毛利率'), + 'operating_profit_margin': self._find_indicator(wind_data, 'profitability', '营业利润/营业总收入'), + 'net_profit_margin': self._find_indicator(wind_data, 'profitability', '销售净利率'), + 'roe': self._find_indicator(wind_data, 'profitability', '净资产收益率ROE(平均)'), + 'roa': self._find_indicator(wind_data, 'profitability', '总资产净利率ROA'), + 'roic': self._find_indicator(wind_data, 'profitability', '投入资本回报率ROIC'), + + # WACC数据 + 'wacc': wacc + } + + # 对数值进行四舍五入处理 + for key, value in indicators.items(): + if isinstance(value, (int, float)) and value is not None: + indicators[key] = round(value, 3) + + # 添加数据时间 + indicators['data_time'] = wind_data['time'] + + + return { + 'success': True, + 'stock_code': stock_code, + 'indicators': indicators + } + + except Exception as e: + logger.error(f"提取财务指标失败: {str(e)}") + return { + 'success': False, + 'message': f'提取财务指标失败: {str(e)}' + } + + def _find_indicator(self, data: Dict, category: str, meaning: str) -> Optional[float]: + """ + 在指定类别中查找指标值 + + Args: + data: 财务数据字典 + category: 指标类别 + meaning: 指标含义 + + Returns: + 指标值,如果未找到则返回None + """ + try: + if category not in data: + return None + + for item in data[category]['list']: + if item['meaning'] == meaning: + return item['data'] + return None + + except Exception as e: + logger.error(f"查找指标 {category}.{meaning} 失败: {str(e)}") + return None + + def test_mongo_structure(self, stock_code: str = None) -> Dict: + """ + 测试方法:查看MongoDB集合的字段结构 + + Args: + stock_code: 股票代码,如果为None则返回第一条记录 + + Returns: + 包含字段结构的字典 + """ + try: + # 构建查询条件 + query = {} + if stock_code: + query['code'] = stock_code + + # 获取一条记录 + record = self.mongo_collection.find_one(query) + + if not record: + return { + 'success': False, + 'message': f'未找到股票 {stock_code} 的记录' if stock_code else '集合为空' + } + + # 移除MongoDB的_id字段 + if '_id' in record: + record.pop('_id') + + # 获取所有字段名 + fields = list(record.keys()) + + # 格式化输出 + result = { + 'success': True, + 'fields': fields, + 'sample_data': record + } + + # 打印字段信息 + logger.info(f"集合字段列表: {json.dumps(fields, ensure_ascii=False, indent=2)}") + logger.info(f"示例数据: {json.dumps(record, ensure_ascii=False, indent=2)}") + + return result + + except Exception as e: + logger.error(f"查询MongoDB结构失败: {str(e)}") + return { + 'success': False, + 'message': f'查询失败: {str(e)}' + } + + def analyze_financial_data(self, stock_code: str) -> Dict: + """ + 分析财务数据 + + Args: + stock_code: 股票代码 + + Returns: + 分析结果字典,包含所有财务指标及其排名得分 + """ + try: + # 获取股票价格数据 + price_collector = StockPriceCollector() + price_data = price_collector.get_stock_price_data(stock_code) + + # 获取概念板块数据 + industry_analyzer = IndustryAnalyzer() + concepts = industry_analyzer.get_stock_concepts(stock_code) + + # 获取基础财务指标 + base_result = self.extract_financial_indicators(stock_code) + if not base_result.get('success'): + return base_result + + # 获取增长指标变化 + growth_result = self.get_growth_change_indicators(stock_code) + if not growth_result.get('success'): + return growth_result + + # 获取行业排名 + rank_result = self.calculate_industry_rankings(stock_code) + if not rank_result.get('success'): + return rank_result + + # 定义指标说明映射 + indicator_descriptions = { + # 财务实力指标 + 'debt_equity_ratio': '债务股本比率', + 'debt_ebitda_ratio': '债务/税息折旧及摊销前利润', + 'interest_coverage_ratio': '利息保障倍数', + 'cash_to_debt_ratio': '现金负债率', + 'equity_ratio': '股东权益比', + 'wacc': '加权平均资本成本', + 'roic': '资本回报率', + + # 盈利能力指标 + 'profit_years': '过去五年盈利年数', + 'gross_profit_margin': '毛利率', + 'operating_profit_margin': '营业利润率', + 'net_profit_margin': '净利率', + 'roe': '股本回报率ROE', + 'roa': '资产收益率ROA', + + # 成长能力指标 + 'diluted_eps_growth': '稀释每股收益增长率', + 'operating_cash_flow_per_share_growth': '每股经营活动产生的现金流量净额增长率', + 'revenue_growth': '营业收入增长率', + 'operating_profit_growth': '营业利润增长率', + 'net_profit_growth_excl_nonrecurring': '扣非净利润增长率', + 'operating_cash_flow_growth': '经营活动产生的现金流量净额增长率', + 'rd_expense_growth': '研发费用增长率', + 'roe_growth': '净资产收益率增长率', + + # 价值评级指标 + 'z_score': 'Z值', + 'working_capital_to_assets': '营运资本/总资产', + 'retained_earnings_to_assets': '留存收益/总资产', + 'ebit_to_assets': '息税前利润/总资产', + 'market_value_to_liabilities': '总市值/负债总计', + 'equity_to_liabilities': '股东权益/负债总计', + 'revenue_to_assets': '营业收入/总资产', + + # 流动性指标 + 'inventory_turnover_days': '存货周转天数', + 'receivables_turnover_days': '应收账款周转天数', + 'payables_turnover_days': '应付账款周转天数', + 'current_ratio': '流动比率', + 'quick_ratio': '速动比率', + 'cash_ratio': '现金比率' + } + + # 构建基础指标数据 + base_indicators = base_result.get('indicators', {}) + rankings = rank_result.get('rankings', {}) + + # 定义各板块的指标列表 + financial_strength_indicators = [ + 'debt_equity_ratio', 'debt_ebitda_ratio', 'interest_coverage_ratio', + 'cash_to_debt_ratio', 'equity_ratio', 'wacc', 'roic' + ] + + profitability_indicators = [ + 'gross_profit_margin', 'operating_profit_margin', + 'net_profit_margin', 'roe', 'roa', 'roic', 'profit_years' + ] + + growth_indicators = [ + 'diluted_eps_growth', 'operating_cash_flow_per_share_growth', + 'revenue_growth', 'operating_profit_growth', + 'net_profit_growth_excl_nonrecurring', 'operating_cash_flow_growth', + 'rd_expense_growth', 'roe_growth' + ] + + value_rating_indicators = [ + 'z_score', 'working_capital_to_assets', 'retained_earnings_to_assets', + 'ebit_to_assets', 'market_value_to_liabilities', + 'equity_to_liabilities', 'revenue_to_assets' + ] + + liquidity_indicators = [ + 'inventory_turnover_days', 'receivables_turnover_days', + 'payables_turnover_days', 'current_ratio', 'quick_ratio', 'cash_ratio' + ] + + # 处理各板块指标 + def process_indicators(indicator_list): + result = [] + total_score = 0 + valid_scores = 0 + + for key in indicator_list: + if key in base_indicators: + rank_score = rankings.get(key, 0) + if rank_score is not None: + total_score += rank_score + valid_scores += 1 + + result.append({ + 'key': key, + 'desc': indicator_descriptions.get(key, key), + 'value': base_indicators[key], + 'rank_score': rank_score + }) + + # 计算平均得分 + avg_score = round(total_score / valid_scores, 2) if valid_scores > 0 else 0 + + return { + 'indicators': result, + 'avg_score': avg_score + } + + # 处理增长指标变化 + growth_changes = growth_result.get('data', {}).get('growth_changes', {}) + growth_changes_list = [] + total_growth_score = 0 + valid_growth_scores = 0 + + for key in growth_indicators: + if key in growth_changes: + value = growth_changes[key] + if isinstance(value, (int, float)) and value is not None: + value = round(value, 3) + + rank_score = rankings.get(f'{key}_change', 0) + if rank_score is not None: + total_growth_score += rank_score + valid_growth_scores += 1 + + growth_changes_list.append({ + 'key': f'{key}_change', + 'desc': f'{indicator_descriptions.get(key, key)}增长变化率', + 'value': value, + 'rank_score': rank_score + }) + + # 计算增长指标的平均得分 + growth_avg_score = round(total_growth_score / valid_growth_scores, 2) if valid_growth_scores > 0 else 0 + + # 构建增长指标数据 + growth_data = { + 'indicators': growth_changes_list, + 'avg_score': growth_avg_score + } + + return { + 'success': True, + 'stock_code': stock_code, + 'data_time': base_indicators.get('data_time'), + 'financial_strength': process_indicators(financial_strength_indicators), + 'profitability': process_indicators(profitability_indicators), + 'growth': growth_data, + 'value_rating': process_indicators(value_rating_indicators), + 'liquidity': process_indicators(liquidity_indicators), + 'concepts': concepts, # 添加概念板块数据 + 'price_data': price_data # 添加实时股价数据 + } + + except Exception as e: + logger.error(f"分析财务数据失败: {str(e)}") + return { + 'success': False, + 'message': f'分析财务数据失败: {str(e)}' + } + + def __del__(self): + """析构函数,关闭数据库连接""" + if hasattr(self, 'mongo_client'): + self.mongo_client.close() + logger.info("数据库连接已关闭") + + def _convert_stock_code_format(self, stock_code: str) -> str: + """ + 转换股票代码格式 + + Args: + stock_code: 原始股票代码,格式如 "603507.SH" + + Returns: + 转换后的股票代码,格式如 "SH603507" + """ + try: + code, market = stock_code.split('.') + return f"{market}{code}" + except Exception as e: + logger.error(f"转换股票代码格式失败: {str(e)}") + return stock_code + + def get_industry_stocks(self, stock_code: str) -> List: + """ + 获取同行业股票列表 + + Args: + stock_code: 股票代码,格式如 "603507.SH" + + Returns: + 包含同行业股票列表的字典 + """ + try: + # 转换股票代码格式 + formatted_code = self._convert_stock_code_format(stock_code) + + # 查询行业板块 + query = """ + SELECT bk_name + FROM gp_hybk + WHERE gp_code = %s + """ + result = pd.read_sql(query, self.mysql_engine, params=(formatted_code,)) + + if result.empty: + return { + 'success': False, + 'message': f'未找到股票 {stock_code} 的行业信息' + } + + bk_name = result.iloc[0]['bk_name'] + + # 查询同行业股票 + query = """ + SELECT gp_code + FROM gp_hybk + WHERE bk_name = %s + """ + stocks = pd.read_sql(query, self.mysql_engine, params=(bk_name,)) + + # 转换回原始格式 + stock_list = [] + for code in stocks['gp_code']: + if code.startswith('SH'): + stock_list.append(f"{code[2:]}.SH") + elif code.startswith('SZ'): + stock_list.append(f"{code[2:]}.SZ") + + return stock_list + + except Exception as e: + logger.error(f"获取同行业股票列表失败: {str(e)}") + return [] + + def _calculate_industry_rank_score(self, value: float, values: List[float], is_higher_better: bool = True) -> float: + """ + 计算行业排名得分 + + Args: + value: 当前值 + values: 行业所有值列表 + is_higher_better: 是否越高越好,默认为True + + Returns: + 0-10的得分,10表示第一名,0表示最后一名 + """ + try: + if value is None or not values: + return 0 + + # 过滤掉None值 + valid_values = [v for v in values if v is not None] + if not valid_values: + return 0 + + # 计算排名 + if is_higher_better: + rank = sum(1 for x in valid_values if x > value) + 1 + else: + rank = sum(1 for x in valid_values if x < value) + 1 + + # 计算得分 (10 * (1 - (rank - 1) / (n - 1))) + n = len(valid_values) + if n == 1: + return 10 + score = 10 * (1 - (rank - 1) / (n - 1)) + return round(score, 2) + + except Exception as e: + logger.error(f"计算行业排名得分失败: {str(e)}") + return 0 + + def _get_industry_indicators(self, stock_list: List[str]) -> Dict[str, List[float]]: + """ + 获取行业所有公司的指标值 + + Args: + stock_list: 股票代码列表 + + Returns: + 包含所有指标值的字典,key为指标名,value为该指标的所有公司值列表 + """ + try: + industry_indicators = {} + + # 遍历所有股票获取指标 + for stock_code in stock_list: + result = self.extract_financial_indicators(stock_code) + if not result.get('success'): + continue + + indicators = result.get('indicators', {}) + for key, value in indicators.items(): + if key != 'data_time': + if key not in industry_indicators: + industry_indicators[key] = [] + industry_indicators[key].append(value) + + return industry_indicators + + except Exception as e: + logger.error(f"获取行业指标失败: {str(e)}") + return {} + + def calculate_industry_rankings(self, stock_code: str) -> Dict: + """ + 计算公司在行业中的排名得分 + + Args: + stock_code: 股票代码 + + Returns: + 包含所有指标排名得分的字典 + """ + try: + # 获取同行业股票列表 + stock_list = self.get_industry_stocks(stock_code) + if not stock_list: + return { + 'success': False, + 'message': f'未找到股票 {stock_code} 的同行业公司' + } + + # 获取当前公司的指标 + current_result = self.extract_financial_indicators(stock_code) + if not current_result.get('success'): + return current_result + + # 获取当前公司的增长指标变化 + current_growth_result = self.get_growth_change_indicators(stock_code) + if not current_growth_result.get('success'): + return current_growth_result + + # 获取行业所有公司的指标 + industry_indicators = self._get_industry_indicators(stock_list) + + # 获取行业所有公司的增长指标变化 + industry_growth_indicators = {} + for stock in stock_list: + growth_result = self.get_growth_change_indicators(stock) + if growth_result.get('success'): + growth_changes = growth_result.get('data', {}).get('growth_changes', {}) + for key, value in growth_changes.items(): + if key not in industry_growth_indicators: + industry_growth_indicators[key] = [] + industry_growth_indicators[key].append(value) + + # 定义指标是否越高越好 + higher_better_indicators = { + # 偿债能力指标 + 'current_ratio': True, + 'quick_ratio': True, + 'cash_ratio': True, + 'interest_coverage_ratio': True, + + # 资本结构指标 + 'equity_ratio': True, + + # 盈利能力指标 + 'profit_years': True, # 盈利年数越高越好 + + # 成长能力指标 + 'diluted_eps_growth': True, + 'operating_cash_flow_per_share_growth': True, + 'revenue_growth': True, + 'operating_profit_growth': True, + 'net_profit_growth_excl_nonrecurring': True, + 'operating_cash_flow_growth': True, + 'rd_expense_growth': True, + 'roe_growth': True, + + # Z值相关指标 + 'working_capital_to_assets': True, + 'retained_earnings_to_assets': True, + 'ebit_to_assets': True, + 'market_value_to_liabilities': True, + 'equity_to_liabilities': True, + 'revenue_to_assets': True, + 'z_score': True, + + # 营运能力指标 + 'inventory_turnover_days': False, + 'receivables_turnover_days': False, + 'payables_turnover_days': False, + + # 盈利能力指标 + 'gross_profit_margin': True, + 'operating_profit_margin': True, + 'net_profit_margin': True, + 'roe': True, + 'roa': True, + 'roic': True, + + # WACC数据 + 'wacc': False + } + + # 计算每个指标的排名得分 + current_indicators = current_result.get('indicators', {}) + rankings = {} + + # 计算基础指标的排名 + for key, value in current_indicators.items(): + if key != 'data_time' and key in industry_indicators: + is_higher_better = higher_better_indicators.get(key, True) + score = self._calculate_industry_rank_score( + value, + industry_indicators[key], + is_higher_better + ) + rankings[key] = score + + # 计算增长指标的排名 + current_growth_changes = current_growth_result.get('data', {}).get('growth_changes', {}) + for key, value in current_growth_changes.items(): + if key in industry_growth_indicators: + is_higher_better = higher_better_indicators.get(key, True) + score = self._calculate_industry_rank_score( + value, + industry_growth_indicators[key], + is_higher_better + ) + rankings[f'{key}_change'] = score + + return { + 'success': True, + 'stock_code': stock_code, + 'data_time': current_indicators.get('data_time'), + 'rankings': rankings + } + + except Exception as e: + logger.error(f"计算行业排名失败: {str(e)}") + return { + 'success': False, + 'message': f'计算行业排名失败: {str(e)}' + } \ No newline at end of file diff --git a/src/valuation_analysis/industry_analysis.py b/src/valuation_analysis/industry_analysis.py index 944cf69..bc100a7 100644 --- a/src/valuation_analysis/industry_analysis.py +++ b/src/valuation_analysis/industry_analysis.py @@ -274,7 +274,7 @@ class IndustryAnalyzer: logger.info(f"计算行业 {metric} 分位数完成: 当前{metric}={result['current']:.2f}, 百分位={result['percentile']:.2f}%") return result - def get_industry_crowding_index(self, industry_name: str, start_date: str = None, end_date: str = None) -> pd.DataFrame: + def get_industry_crowding_index(self, industry_name: str, start_date: str = None, end_date: str = None, use_cache: bool = True) -> pd.DataFrame: """ 计算行业交易拥挤度指标,并使用Redis缓存结果 @@ -285,6 +285,7 @@ class IndustryAnalyzer: industry_name: 行业名称 start_date: 不再使用此参数,保留是为了兼容性 end_date: 结束日期(默认为当前日期) + use_cache: 是否使用缓存,默认为True Returns: 包含行业拥挤度指标的DataFrame @@ -297,25 +298,26 @@ class IndustryAnalyzer: end_date = datetime.datetime.now().strftime('%Y-%m-%d') # 检查缓存 - cache_key = f"industry_crowding:{industry_name}" - cached_data = redis_client.get(cache_key) - - if cached_data: - try: - # 尝试解析缓存的JSON数据 - cached_df_dict = json.loads(cached_data) - logger.info(f"从缓存获取行业 {industry_name} 的拥挤度数据") - - # 将缓存的字典转换回DataFrame - df = pd.DataFrame(cached_df_dict) - - # 确保trade_date列是日期类型 - df['trade_date'] = pd.to_datetime(df['trade_date']) - - return df - except Exception as cache_error: - logger.warning(f"解析缓存的拥挤度数据失败,将重新查询: {cache_error}") + if use_cache: + cache_key = f"industry_crowding:{industry_name}" + cached_data = redis_client.get(cache_key) + if cached_data: + try: + # 尝试解析缓存的JSON数据 + cached_df_dict = json.loads(cached_data) + logger.info(f"从缓存获取行业 {industry_name} 的拥挤度数据") + + # 将缓存的字典转换回DataFrame + df = pd.DataFrame(cached_df_dict) + + # 确保trade_date列是日期类型 + df['trade_date'] = pd.to_datetime(df['trade_date']) + + return df + except Exception as cache_error: + logger.warning(f"解析缓存的拥挤度数据失败,将重新查询: {cache_error}") + # 获取行业所有股票 stock_codes = self.get_industry_stocks(industry_name) if not stock_codes: @@ -397,15 +399,16 @@ class IndustryAnalyzer: df_dict = df.to_dict(orient='records') # 缓存结果,有效期1天(86400秒) - try: - redis_client.set( - cache_key, - json.dumps(df_dict, default=str), # 使用default=str处理日期等特殊类型 - ex=86400 # 1天的秒数 - ) - logger.info(f"已缓存行业 {industry_name} 的拥挤度数据,有效期为1天") - except Exception as cache_error: - logger.warning(f"缓存行业拥挤度数据失败: {cache_error}") + if use_cache: + try: + redis_client.set( + cache_key, + json.dumps(df_dict, default=str), # 使用default=str处理日期等特殊类型 + ex=86400 # 1天的秒数 + ) + logger.info(f"已缓存行业 {industry_name} 的拥挤度数据,有效期为1天") + except Exception as cache_error: + logger.warning(f"缓存行业拥挤度数据失败: {cache_error}") logger.info(f"成功计算行业 {industry_name} 的拥挤度指标,共 {len(df)} 条记录") return df @@ -509,4 +512,54 @@ class IndustryAnalyzer: except Exception as e: logger.error(f"获取行业综合分析失败: {e}") - return {"success": False, "message": f"获取行业综合分析失败: {e}"} \ No newline at end of file + return {"success": False, "message": f"获取行业综合分析失败: {e}"} + + def get_stock_concepts(self, stock_code: str) -> List[str]: + """ + 获取指定股票所属的概念板块列表 + + Args: + stock_code: 股票代码 + + Returns: + 概念板块名称列表 + """ + try: + # 转换股票代码格式 + formatted_code = self._convert_stock_code_format(stock_code) + + query = text(""" + SELECT DISTINCT bk_name + FROM gp_gnbk + WHERE gp_code = :stock_code + """) + + with self.engine.connect() as conn: + result = conn.execute(query, {"stock_code": formatted_code}).fetchall() + + if result: + return [row[0] for row in result] + else: + logger.warning(f"未找到股票 {stock_code} 的概念板块数据") + return [] + + except Exception as e: + logger.error(f"获取股票概念板块失败: {e}") + return [] + + def _convert_stock_code_format(self, stock_code: str) -> str: + """ + 转换股票代码格式 + + Args: + stock_code: 原始股票代码,格式如 "600519.SH" + + Returns: + 转换后的股票代码,格式如 "SH600519" + """ + try: + code, market = stock_code.split('.') + return f"{market}{code}" + except Exception as e: + logger.error(f"转换股票代码格式失败: {str(e)}") + return stock_code \ No newline at end of file diff --git a/src/valuation_analysis/stock_price_collector.py b/src/valuation_analysis/stock_price_collector.py new file mode 100644 index 0000000..88a4808 --- /dev/null +++ b/src/valuation_analysis/stock_price_collector.py @@ -0,0 +1,463 @@ +""" +东方财富实时股价数据采集模块 +提供从东方财富网站采集实时股价数据并存储到数据库的功能 +功能包括: +1. 采集实时股价数据 +2. 存储数据到数据库 +3. 定时自动更新数据 +""" + +import requests +import pandas as pd +import datetime +import logging +import time +import os +import sys +from pathlib import Path +from sqlalchemy import create_engine, text +from typing import Dict + +# 添加项目根目录到Python路径 +current_file = Path(__file__) +project_root = current_file.parent.parent.parent +sys.path.append(str(project_root)) + +from src.valuation_analysis.config import DB_URL, LOG_FILE + +# 获取项目根目录 +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# 确保日志目录存在 +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(LOG_FILE), + logging.StreamHandler() + ] +) +logger = logging.getLogger("stock_price_collector") + + +def get_create_table_sql() -> str: + """ + 获取创建实时股价数据表的SQL语句 + + Returns: + 创建表的SQL语句 + """ + return """ + CREATE TABLE IF NOT EXISTS stock_price_data ( + stock_code VARCHAR(10) PRIMARY KEY COMMENT '股票代码', + stock_name VARCHAR(50) COMMENT '股票名称', + latest_price DECIMAL(10,2) COMMENT '最新价', + change_percent DECIMAL(10,2) COMMENT '涨跌幅', + change_amount DECIMAL(10,2) COMMENT '涨跌额', + volume BIGINT COMMENT '成交量(手)', + amount DECIMAL(20,2) COMMENT '成交额', + amplitude DECIMAL(10,2) COMMENT '振幅', + turnover_rate DECIMAL(10,2) COMMENT '换手率', + pe_ratio DECIMAL(10,2) COMMENT '市盈率', + high_price DECIMAL(10,2) COMMENT '最高价', + low_price DECIMAL(10,2) COMMENT '最低价', + open_price DECIMAL(10,2) COMMENT '开盘价', + pre_close DECIMAL(10,2) COMMENT '昨收价', + total_market_value DECIMAL(20,2) COMMENT '总市值', + float_market_value DECIMAL(20,2) COMMENT '流通市值', + pb_ratio DECIMAL(10,2) COMMENT '市净率', + list_date DATE COMMENT '上市日期', + update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间' + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='实时股价数据表'; + """ + + +class StockPriceCollector: + """东方财富实时股价数据采集器类""" + + def __init__(self, db_url: str = DB_URL): + """ + 初始化东方财富实时股价数据采集器 + + Args: + db_url: 数据库连接URL + """ + self.engine = create_engine( + db_url, + pool_size=5, + max_overflow=10, + pool_recycle=3600 + ) + self.base_url = "https://push2.eastmoney.com/api/qt/clist/get" + self.headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Accept": "application/json, text/plain, */*", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Origin": "https://quote.eastmoney.com", + "Referer": "https://quote.eastmoney.com/", + } + logger.info("东方财富实时股价数据采集器初始化完成") + + def _ensure_table_exists(self) -> bool: + """ + 确保数据表存在,如果不存在则创建 + + Returns: + 是否成功确保表存在 + """ + try: + create_table_query = text(get_create_table_sql()) + + with self.engine.connect() as conn: + conn.execute(create_table_query) + conn.commit() + + logger.info("实时股价数据表创建成功") + return True + + except Exception as e: + logger.error(f"确保数据表存在失败: {e}") + return False + + def _convert_stock_code(self, code: str) -> str: + """ + 转换股票代码格式 + + Args: + code: 原始股票代码 + + Returns: + 转换后的股票代码 + """ + if code.startswith(('0', '3')): + return f"{code}.SZ" + else: + return f"{code}.SH" + + def _parse_list_date(self, date_str: str) -> datetime.date: + """ + 解析上市日期 + + Args: + date_str: 日期字符串 + + Returns: + 日期对象 + """ + if not date_str or date_str == '-': + return None + try: + # 如果输入是整数,先转换为字符串 + if isinstance(date_str, int): + date_str = str(date_str) + return datetime.datetime.strptime(date_str, "%Y%m%d").date() + except ValueError: + logger.warning(f"无法解析日期: {date_str}") + return None + + def fetch_data(self, page: int = 1) -> pd.DataFrame: + """ + 获取指定页码的实时股价数据 + + Args: + page: 页码 + + Returns: + 包含实时股价数据的DataFrame + """ + try: + params = { + "np": 1, + "fltt": 2, + "invt": 2, + "fs": "m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048", + "fid": "f12", + "pn": page, + "pz": 100, + "po": 0, + "dect": 1 + } + + logger.info(f"开始获取第 {page} 页数据") + + response = requests.get(self.base_url, params=params, headers=self.headers) + if response.status_code != 200: + logger.error(f"获取第 {page} 页数据失败: HTTP {response.status_code}") + return pd.DataFrame() + + data = response.json() + if not data.get("rc") == 0: + logger.error(f"获取数据失败: {data.get('message', '未知错误')}") + return pd.DataFrame() + + # 提取数据列表 + items = data.get("data", {}).get("diff", []) + if not items: + logger.warning(f"第 {page} 页未找到有效数据") + return pd.DataFrame() + + # 转换为DataFrame + df = pd.DataFrame(items) + + # 重命名列 + column_mapping = { + "f12": "stock_code", + "f14": "stock_name", + "f2": "latest_price", + "f3": "change_percent", + "f4": "change_amount", + "f5": "volume", + "f6": "amount", + "f7": "amplitude", + "f8": "turnover_rate", + "f9": "pe_ratio", + "f15": "high_price", + "f16": "low_price", + "f17": "open_price", + "f18": "pre_close", + "f20": "total_market_value", + "f21": "float_market_value", + "f23": "pb_ratio", + "f26": "list_date" + } + + df = df.rename(columns=column_mapping) + + # 转换股票代码格式 + df['stock_code'] = df['stock_code'].apply(self._convert_stock_code) + + # 转换上市日期 + df['list_date'] = df['list_date'].apply(self._parse_list_date) + + logger.info(f"第 {page} 页数据获取成功,包含 {len(df)} 条记录") + return df + + except Exception as e: + logger.error(f"获取第 {page} 页数据失败: {e}") + return pd.DataFrame() + + def fetch_all_data(self) -> pd.DataFrame: + """ + 获取所有页的实时股价数据 + + Returns: + 包含所有实时股价数据的DataFrame + """ + all_data = [] + page = 1 + + while True: + page_data = self.fetch_data(page) + if page_data.empty: + logger.info(f"第 {page} 页数据为空,停止采集") + break + + all_data.append(page_data) + + # 如果返回的数据少于100条,说明是最后一页 + if len(page_data) < 100: + break + + page += 1 + # 添加延迟,避免请求过于频繁 + time.sleep(1) + + if all_data: + combined_df = pd.concat(all_data, ignore_index=True) + logger.info(f"数据采集完成,共采集 {len(combined_df)} 条记录") + return combined_df + else: + logger.warning("未获取到任何有效数据") + return pd.DataFrame() + + def save_to_database(self, data: pd.DataFrame) -> bool: + """ + 将数据保存到数据库 + + Args: + data: 要保存的数据DataFrame + + Returns: + 是否成功保存数据 + """ + if data.empty: + logger.warning("没有数据需要保存") + return False + + try: + # 确保数据表存在 + if not self._ensure_table_exists(): + return False + data = data.replace('-', None) + # 将nan值转换为None(在SQL中会变成NULL) + data = data.replace({pd.NA: None, pd.NaT: None}) + data = data.where(pd.notnull(data), None) + + # 添加数据或更新已有数据 + inserted_count = 0 + updated_count = 0 + + with self.engine.connect() as conn: + for _, row in data.iterrows(): + # 将Series转换为dict,并处理nan值 + row_dict = {k: (None if pd.isna(v) else v) for k, v in row.items()} + + # 检查该股票的数据是否已存在 + check_query = text(""" + SELECT COUNT(*) FROM stock_price_data WHERE stock_code = :stock_code + """) + result = conn.execute(check_query, {"stock_code": row_dict['stock_code']}).scalar() + + if result > 0: # 数据已存在,执行更新 + update_query = text(""" + UPDATE stock_price_data SET + stock_name = :stock_name, + latest_price = :latest_price, + change_percent = :change_percent, + change_amount = :change_amount, + volume = :volume, + amount = :amount, + amplitude = :amplitude, + turnover_rate = :turnover_rate, + pe_ratio = :pe_ratio, + high_price = :high_price, + low_price = :low_price, + open_price = :open_price, + pre_close = :pre_close, + total_market_value = :total_market_value, + float_market_value = :float_market_value, + pb_ratio = :pb_ratio, + list_date = :list_date + WHERE stock_code = :stock_code + """) + conn.execute(update_query, row_dict) + updated_count += 1 + else: # 数据不存在,执行插入 + insert_query = text(""" + INSERT INTO stock_price_data ( + stock_code, stock_name, latest_price, change_percent, + change_amount, volume, amount, amplitude, turnover_rate, + pe_ratio, high_price, low_price, open_price, pre_close, + total_market_value, float_market_value, pb_ratio, list_date + ) VALUES ( + :stock_code, :stock_name, :latest_price, :change_percent, + :change_amount, :volume, :amount, :amplitude, :turnover_rate, + :pe_ratio, :high_price, :low_price, :open_price, :pre_close, + :total_market_value, :float_market_value, :pb_ratio, :list_date + ) + """) + conn.execute(insert_query, row_dict) + inserted_count += 1 + + conn.commit() + + logger.info(f"数据保存成功:新增 {inserted_count} 条记录,更新 {updated_count} 条记录") + return True + + except Exception as e: + logger.error(f"保存数据到数据库失败: {e}") + return False + + def update_latest_data(self) -> bool: + """ + 更新最新实时股价数据 + + Returns: + 是否成功更新最新数据 + """ + try: + logger.info("开始更新最新实时股价数据") + + # 获取所有数据 + df = self.fetch_all_data() + if df.empty: + logger.warning("未获取到最新数据") + return False + + # 保存数据到数据库 + result = self.save_to_database(df) + + if result: + logger.info(f"最新数据更新成功,共更新 {len(df)} 条记录") + else: + logger.warning("最新数据更新失败") + + return result + + except Exception as e: + logger.error(f"更新最新数据失败: {e}") + return False + + def get_stock_price_data(self, stock_code: str, convert_code: bool = False) -> Dict: + """ + 获取指定股票的最新价格数据 + + Args: + stock_code: 股票代码 + convert_code: 是否需要转换股票代码格式,默认为False + + Returns: + 包含股票价格数据的字典 + """ + try: + # 转换股票代码格式 + formatted_code = self._convert_stock_code(stock_code) if convert_code else stock_code + + query = text(""" + SELECT + stock_code, + stock_name, + latest_price, + change_percent, + change_amount, + volume, + amount, + amplitude, + turnover_rate, + pe_ratio, + high_price, + low_price, + total_market_value, + float_market_value, + pb_ratio, + list_date, + update_time + FROM + stock_price_data + WHERE + stock_code = :stock_code + """) + + with self.engine.connect() as conn: + result = conn.execute(query, {"stock_code": formatted_code}).fetchone() + + if result: + # 将结果转换为字典 + data = dict(result._mapping) + # 处理日期类型 + if data['list_date']: + data['list_date'] = data['list_date'].strftime('%Y-%m-%d') + if data['update_time']: + data['update_time'] = data['update_time'].strftime('%Y-%m-%d %H:%M:%S') + return data + else: + logger.warning(f"未找到股票 {stock_code} 的价格数据") + return None + + except Exception as e: + logger.error(f"获取股票价格数据失败: {e}") + return None + + +# 示例使用方式 +if __name__ == "__main__": + # 创建实时股价数据采集器 + collector = StockPriceCollector() + + # 更新最新数据 + logger.info("开始更新最新实时股价数据...") + collector.update_latest_data() \ No newline at end of file