This commit is contained in:
liao 2025-05-21 14:01:54 +08:00
parent 5b9ae03000
commit 8f73099e18
7 changed files with 1632 additions and 77 deletions

View File

@ -16,4 +16,5 @@ markdown2>=2.5.3
google-genai google-genai
redis==5.2.1 redis==5.2.1
pandas==2.2.3 pandas==2.2.3
apscheduler==3.11.0 apscheduler==3.11.0
pymongo==4.13.0

View File

@ -1,6 +1,6 @@
import sys import sys
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta, time
import pandas as pd import pandas as pd
import uuid import uuid
import json import json
@ -45,6 +45,9 @@ from src.scripts.stock_daily_data_collector import collect_stock_daily_data
from utils.distributed_lock import DistributedLock from utils.distributed_lock import DistributedLock
from valuation_analysis.industry_analysis import redis_client from valuation_analysis.industry_analysis import redis_client
from valuation_analysis.financial_analysis import FinancialAnalyzer
from src.valuation_analysis.stock_price_collector import StockPriceCollector
# 设置日志 # 设置日志
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -183,6 +186,77 @@ def run_backtest_task(task_id, stocks_buy_dates, end_date):
backtest_tasks[task_id]['error'] = str(e) backtest_tasks[task_id]['error'] = str(e)
logger.error(f"回测任务 {task_id} 失败:{str(e)}") logger.error(f"回测任务 {task_id} 失败:{str(e)}")
def initialize_stock_price_schedule():
"""
初始化实时股价数据采集定时任务
"""
# 创建分布式锁
price_lock = DistributedLock(redis_client, "stock_price_collector", expire_time=3600) # 1小时过期
# 尝试获取锁
if not price_lock.acquire():
logger.info("其他服务器正在运行实时股价数据采集任务,本服务器跳过")
return None
try:
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
# 创建定时任务调度器
scheduler = BackgroundScheduler()
def is_trading_time():
"""判断当前是否为交易时间"""
now = datetime.now()
current_time = now.time()
# 定义交易时间段
morning_start = time(9, 25) # 上午开盘前5分钟
morning_end = time(11, 30) # 上午收盘
afternoon_start = time(13, 0) # 下午开盘
afternoon_end = time(15, 0) # 下午收盘
# 判断是否为工作日
if now.weekday() >= 5: # 5是周六6是周日
return False
# 判断是否在交易时间段内
is_morning = morning_start <= current_time <= morning_end
is_afternoon = afternoon_start <= current_time <= afternoon_end
return is_morning or is_afternoon
def update_stock_price():
"""更新实时股价数据"""
if not is_trading_time():
return
try:
collector = StockPriceCollector()
collector.update_latest_data()
except Exception as e:
logger.error(f"更新实时股价数据失败: {e}")
# 添加定时任务
scheduler.add_job(
func=update_stock_price,
trigger='interval',
minutes=5,
id='stock_price_update',
name='实时股价数据采集',
replace_existing=True
)
# 启动调度器
scheduler.start()
logger.info("实时股价数据采集定时任务已初始化将在交易时间内每5分钟执行一次")
return scheduler
except Exception as e:
logger.error(f"初始化实时股价数据采集定时任务失败: {str(e)}")
price_lock.release()
return None
def initialize_rzrq_collector_schedule(): def initialize_rzrq_collector_schedule():
"""初始化融资融券数据采集定时任务""" """初始化融资融券数据采集定时任务"""
# 创建分布式锁 # 创建分布式锁
@ -203,7 +277,7 @@ def initialize_rzrq_collector_schedule():
# 添加每天下午5点执行的任务 # 添加每天下午5点执行的任务
scheduler.add_job( scheduler.add_job(
func=run_rzrq_initial_collection, func=run_rzrq_initial_collection,
trigger=CronTrigger(hour=18, minute=0), trigger=CronTrigger(hour=18, minute=40),
id='rzrq_daily_update', id='rzrq_daily_update',
name='每日更新融资融券数据', name='每日更新融资融券数据',
replace_existing=True replace_existing=True
@ -284,9 +358,9 @@ def run_stock_daily_collection():
return False return False
def run_rzrq_initial_collection(): def run_rzrq_initial_collection():
"""执行融资融券数据初始全量采集""" """执行融资融券数据更新采集"""
try: try:
logger.info("开始执行融资融券数据初始全量采集") logger.info("开始执行融资融券数据更新采集")
# 生成任务ID # 生成任务ID
task_id = f"rzrq-{uuid.uuid4().hex[:16]}" task_id = f"rzrq-{uuid.uuid4().hex[:16]}"
@ -296,7 +370,7 @@ def run_rzrq_initial_collection():
'status': 'running', 'status': 'running',
'created_at': datetime.now().isoformat(), 'created_at': datetime.now().isoformat(),
'type': 'initial_collection', 'type': 'initial_collection',
'message': '开始执行融资融券数据初始全量采集' 'message': '开始执行融资融券数据更新采集'
} }
# 在新线程中执行采集任务 # 在新线程中执行采集任务
@ -307,16 +381,16 @@ def run_rzrq_initial_collection():
if result: if result:
rzrq_tasks[task_id]['status'] = 'completed' rzrq_tasks[task_id]['status'] = 'completed'
rzrq_tasks[task_id]['message'] = '融资融券数据初始全量采集完成' rzrq_tasks[task_id]['message'] = '融资融券数据更新完成'
logger.info(f"融资融券数据初始全量采集任务 {task_id} 完成") logger.info(f"融资融券数据更新任务 {task_id} 完成")
else: else:
rzrq_tasks[task_id]['status'] = 'failed' rzrq_tasks[task_id]['status'] = 'failed'
rzrq_tasks[task_id]['message'] = '融资融券数据初始全量采集失败' rzrq_tasks[task_id]['message'] = '融资融券数据更新失败'
logger.error(f"融资融券数据初始全量采集任务 {task_id} 失败") logger.error(f"融资融券数据更新任务 {task_id} 失败")
except Exception as e: except Exception as e:
rzrq_tasks[task_id]['status'] = 'failed' rzrq_tasks[task_id]['status'] = 'failed'
rzrq_tasks[task_id]['message'] = f'执行失败: {str(e)}' rzrq_tasks[task_id]['message'] = f'执行失败: {str(e)}'
logger.error(f"执行融资融券数据初始全量采集线程中出错: {str(e)}") logger.error(f"执行融资融券数据更新线程中出错: {str(e)}")
# 创建并启动线程 # 创建并启动线程
thread = Thread(target=collection_task) thread = Thread(target=collection_task)
@ -325,7 +399,7 @@ def run_rzrq_initial_collection():
return task_id return task_id
except Exception as e: except Exception as e:
logger.error(f"启动融资融券数据初始全量采集任务失败: {str(e)}") logger.error(f"启动融资融券数据更新任务失败: {str(e)}")
if 'task_id' in locals(): if 'task_id' in locals():
rzrq_tasks[task_id]['status'] = 'failed' rzrq_tasks[task_id]['status'] = 'failed'
rzrq_tasks[task_id]['message'] = f'启动失败: {str(e)}' rzrq_tasks[task_id]['message'] = f'启动失败: {str(e)}'
@ -2567,7 +2641,7 @@ def initialize_industry_crowding_schedule():
# 添加每天晚上10点执行的任务 # 添加每天晚上10点执行的任务
scheduler.add_job( scheduler.add_job(
func=precalculate_industry_crowding, func=precalculate_industry_crowding,
trigger=CronTrigger(hour=22, minute=0), trigger=CronTrigger(hour=20, minute=30),
id='industry_crowding_precalc', id='industry_crowding_precalc',
name='预计算行业拥挤度指标', name='预计算行业拥挤度指标',
replace_existing=True replace_existing=True
@ -2575,7 +2649,7 @@ def initialize_industry_crowding_schedule():
# 启动调度器 # 启动调度器
scheduler.start() scheduler.start()
logger.info("行业拥挤度指标预计算定时任务已初始化将在每天22:00执行") logger.info("行业拥挤度指标预计算定时任务已初始化将在每天20:30执行")
return scheduler return scheduler
except Exception as e: except Exception as e:
logger.error(f"初始化行业拥挤度指标预计算定时任务失败: {str(e)}") logger.error(f"初始化行业拥挤度指标预计算定时任务失败: {str(e)}")
@ -2585,43 +2659,126 @@ def initialize_industry_crowding_schedule():
def precalculate_industry_crowding(): def precalculate_industry_crowding():
"""预计算所有行业的拥挤度指标""" """预计算所有行业的拥挤度指标"""
try: try:
logger.info("开始预计算所有行业的拥挤度指标") from .valuation_analysis.industry_analysis import IndustryAnalyzer
# 获取所有行业列表 analyzer = IndustryAnalyzer()
industries = industry_analyzer.get_industry_list() industries = analyzer.get_all_industries()
if not industries:
logger.error("获取行业列表失败")
return
# 记录成功和失败的数量
success_count = 0
fail_count = 0
# 遍历所有行业
for industry in industries: for industry in industries:
try: try:
industry_name = industry['name'] # 调用时设置 use_cache=False强制重新计算
logger.info(f"正在计算行业 {industry_name} 的拥挤度指标") df = analyzer.get_industry_crowding_index(industry, use_cache=False)
# 调用拥挤度计算方法
df = industry_analyzer.get_industry_crowding_index(industry_name)
if not df.empty: if not df.empty:
success_count += 1 logger.info(f"成功预计算行业 {industry} 的拥挤度指标")
logger.info(f"成功计算行业 {industry_name} 的拥挤度指标")
else: else:
fail_count += 1 logger.warning(f"行业 {industry} 的拥挤度指标计算失败")
logger.warning(f"计算行业 {industry_name} 的拥挤度指标失败")
except Exception as e: except Exception as e:
fail_count += 1 logger.error(f"预计算行业 {industry} 的拥挤度指标时出错: {str(e)}")
logger.error(f"计算行业 {industry_name} 的拥挤度指标时出错: {str(e)}")
continue continue
logger.info(f"行业拥挤度指标预计算完成,成功: {success_count},失败: {fail_count}") logger.info("所有行业的拥挤度指标预计算完成")
except Exception as e: except Exception as e:
logger.error(f"预计算行业拥挤度指标失败: {str(e)}") logger.error(f"预计算行业拥挤度指标失败: {str(e)}")
finally:
# 释放分布式锁
industry_crowding_lock = DistributedLock(redis_client, "industry_crowding_calculator")
industry_crowding_lock.release()
@app.route('/api/financial/analysis', methods=['GET'])
def financial_analysis():
"""
财务分析接口
请求参数:
stock_code: 股票代码
返回:
分析结果JSON
"""
try:
stock_code = request.args.get('stock_code')
if not stock_code:
return jsonify({
'success': False,
'message': '缺少必要参数stock_code'
}), 400
analyzer = FinancialAnalyzer()
result = analyzer.analyze_financial_data(stock_code)
return jsonify(result)
except Exception as e:
logger.error(f"财务分析失败: {str(e)}")
return jsonify({
'success': False,
'message': f'财务分析失败: {str(e)}'
}), 500
@app.route('/api/financial/indicators', methods=['GET'])
def get_financial_indicators():
"""
获取财务指标接口
请求参数:
stock_code: 股票代码
返回:
JSON格式的财务指标数据
"""
try:
# 获取股票代码
stock_code = request.args.get('stock_code')
if not stock_code:
return jsonify({
'success': False,
'message': '缺少必要参数: stock_code'
}), 400
# 创建分析器实例
analyzer = FinancialAnalyzer()
# 获取财务指标
result = analyzer.extract_financial_indicators(stock_code)
return jsonify(result)
except Exception as e:
logger.error(f"获取财务指标失败: {str(e)}")
return jsonify({
'success': False,
'message': f'获取财务指标失败: {str(e)}'
}), 500
@app.route('/api/financial/test_structure', methods=['GET'])
def test_mongo_structure():
"""
测试MongoDB集合结构接口
请求参数:
stock_code: 股票代码可选
返回:
JSON格式的集合结构信息
"""
try:
# 获取股票代码(可选)
stock_code = request.args.get('stock_code')
# 创建分析器实例
analyzer = FinancialAnalyzer()
# 获取集合结构
result = analyzer.test_mongo_structure(stock_code)
return jsonify(result)
except Exception as e:
logger.error(f"测试MongoDB结构失败: {str(e)}")
return jsonify({
'success': False,
'message': f'测试MongoDB结构失败: {str(e)}'
}), 500
if __name__ == '__main__': if __name__ == '__main__':
""" """
@ -2644,9 +2801,9 @@ if __name__ == '__main__':
else: else:
print("股票日线采集器锁释放失败或不存在") print("股票日线采集器锁释放失败或不存在")
if industry_crowding_lock.release(): if industry_crowding_lock.release():
print("成功释放股票日线采集器") print("成功释放行业拥挤度")
else: else:
print("股票日线采集器锁释放失败或不存在") print("行业拥挤度锁释放失败或不存在")
print("锁释放操作完成") print("锁释放操作完成")
""" """
@ -2660,5 +2817,8 @@ if __name__ == '__main__':
# 初始化行业拥挤度指标预计算定时任务 # 初始化行业拥挤度指标预计算定时任务
industry_crowding_scheduler = initialize_industry_crowding_schedule() industry_crowding_scheduler = initialize_industry_crowding_schedule()
# 初始化实时股价数据采集定时任务
initialize_stock_price_schedule()
# 启动Web服务器 # 启动Web服务器
app.run(host='0.0.0.0', port=5000, debug=True) app.run(host='0.0.0.0', port=5000, debug=True)

View File

@ -17,6 +17,16 @@ DB_CONFIG = {
# 创建数据库连接URL # 创建数据库连接URL
DB_URL = f"mysql+pymysql://{DB_CONFIG['user']}:{DB_CONFIG['password']}@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}" DB_URL = f"mysql+pymysql://{DB_CONFIG['user']}:{DB_CONFIG['password']}@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
# MongoDB配置
MONGO_CONFIG = {
'host': '192.168.18.75',
'port': 27017,
'db': 'judge',
'username': 'root',
'password': 'wlkj2018',
'collection': 'wind_financial_analysis'
}
# 项目根目录 # 项目根目录
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

View File

@ -284,7 +284,6 @@ class EastmoneyRzrqCollector:
# 确保数据表存在 # 确保数据表存在
if not self._ensure_table_exists(): if not self._ensure_table_exists():
return False return False
# 将nan值转换为None在SQL中会变成NULL # 将nan值转换为None在SQL中会变成NULL
data = data.replace({pd.NA: None, pd.NaT: None}) data = data.replace({pd.NA: None, pd.NaT: None})
data = data.where(pd.notnull(data), None) data = data.where(pd.notnull(data), None)
@ -422,18 +421,18 @@ class EastmoneyRzrqCollector:
def initial_data_collection(self) -> bool: def initial_data_collection(self) -> bool:
""" """
首次全量采集融资融券数据 每日更新任务采集融资融券数据
Returns: Returns:
是否成功采集所有数据 是否成功采集所有数据
""" """
try: try:
logger.info("开始获取最新融资融券数据...") logger.info("开始获取最新融资融券数据...")
df = collector.fetch_data(page=1) df = self.fetch_data(page=1)
if not df.empty: if not df.empty:
# 保存数据到数据库 # 保存数据到数据库
if collector.save_to_database(df): if self.save_to_database(df):
logger.info(f"成功更新最新数据,日期:{df.iloc[0]['trade_date']}") logger.info(f"成功更新最新数据,日期:{df.iloc[0]['trade_date']}")
else: else:
logger.error("更新最新数据失败") logger.error("更新最新数据失败")
@ -443,7 +442,7 @@ class EastmoneyRzrqCollector:
return True return True
except Exception as e: except Exception as e:
logger.error(f"首次全量采集失败: {e}") logger.error(f"每日更新任务采集失败: {e}")
return False return False
def get_chart_data(self, limit_days: int = 30) -> dict: def get_chart_data(self, limit_days: int = 30) -> dict:

View File

@ -0,0 +1,869 @@
"""
财务分析模块
提供从MySQL和MongoDB获取财务数据并进行分析的功能
"""
import pandas as pd
import numpy as np
from sqlalchemy import create_engine
from pymongo import MongoClient
import logging
from typing import Dict, List, Optional, Union, Tuple
import json
from .config import DB_URL, MONGO_CONFIG, LOG_FILE
from .stock_price_collector import StockPriceCollector
from .industry_analysis import IndustryAnalyzer
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(LOG_FILE),
logging.StreamHandler()
]
)
logger = logging.getLogger("financial_analysis")
class FinancialAnalyzer:
"""财务分析器类"""
def __init__(self):
"""初始化财务分析器"""
# 初始化MySQL连接
self.mysql_engine = create_engine(
DB_URL,
pool_size=5,
max_overflow=10,
pool_recycle=3600
)
# 初始化MongoDB连接
self.mongo_client = MongoClient(
host=MONGO_CONFIG['host'],
port=MONGO_CONFIG['port'],
username=MONGO_CONFIG['username'],
password=MONGO_CONFIG['password']
)
self.mongo_db = self.mongo_client[MONGO_CONFIG['db']]
self.mongo_collection = self.mongo_db[MONGO_CONFIG['collection']]
self.wacc_collection = self.mongo_db['wind_stock_wacc_roic']
logger.info("财务分析器初始化完成")
def _get_growth_indicators(self, wind_data: Dict) -> Dict[str, float]:
"""
从wind_data中提取增长指标
Args:
wind_data: wind_data字典
Returns:
包含增长指标的字典
"""
indicators = {
'rd_expense_growth': self._find_indicator(wind_data, 'grows', '研发费用同比增长'),
'roe_growth': self._find_indicator(wind_data, 'grows', '净资产收益率(摊薄)(同比增长)'),
'diluted_eps_growth': self._find_indicator(wind_data, 'grows', '稀释每股收益(同比增长率)'),
'operating_cash_flow_per_share_growth': self._find_indicator(wind_data, 'grows', '每股经营活动产生的现金流量净额(同比增长率)'),
'revenue_growth': self._find_indicator(wind_data, 'grows', '营业收入(同比增长率)'),
'operating_profit_growth': self._find_indicator(wind_data, 'grows', '营业利润(同比增长率)'),
'net_profit_growth_excl_nonrecurring': self._find_indicator(wind_data, 'grows', '归属母公司股东的净利润-扣除非经常性损益(同比增长率)'),
'operating_cash_flow_growth': self._find_indicator(wind_data, 'grows', '经营活动产生的现金流量净额(同比增长率)')
}
return indicators
def _calculate_growth_change(self, current_value: float, previous_value: float) -> Optional[float]:
"""
计算增长率的增长变化
Args:
current_value: 当前值
previous_value: 上一期值
Returns:
增长变化率如果无法计算则返回None
"""
try:
if current_value is None or previous_value is None:
return None
if previous_value == 0:
return None
return (current_value - previous_value) / abs(previous_value)
except Exception as e:
logger.error(f"计算增长变化率失败: {str(e)}")
return None
def get_growth_change_indicators(self, stock_code: str) -> Dict:
"""
获取增长指标的变化率
Args:
stock_code: 股票代码
Returns:
包含增长指标变化率的字典
"""
try:
# 查询MongoDB
record = self.mongo_collection.find_one({'code': stock_code})
if not record or 'wind_data' not in record:
return {
'success': False,
'message': f'未找到股票 {stock_code} 的财务数据'
}
# 获取2023-12-31和2024-12-31的数据
wind_data_2023 = None
wind_data_2024 = None
for data in record['wind_data']:
if data['time'] == '2023-12-31':
wind_data_2023 = data
elif data['time'] == '2024-12-31':
wind_data_2024 = data
if not wind_data_2023 or not wind_data_2024:
return {
'success': False,
'message': f'未找到股票 {stock_code} 的2023或2024年财务数据'
}
# 获取两个时间点的指标
indicators_2023 = self._get_growth_indicators(wind_data_2023)
indicators_2024 = self._get_growth_indicators(wind_data_2024)
# 计算变化率
growth_changes = {}
for key in indicators_2023.keys():
current_value = indicators_2024.get(key)
previous_value = indicators_2023.get(key)
growth_changes[key] = self._calculate_growth_change(current_value, previous_value)
return {
'success': True,
'stock_code': stock_code,
'data': {
'indicators_2023': indicators_2023,
'indicators_2024': indicators_2024,
'growth_changes': growth_changes
}
}
except Exception as e:
logger.error(f"获取增长指标变化率失败: {str(e)}")
return {
'success': False,
'message': f'获取增长指标变化率失败: {str(e)}'
}
def get_wacc_data(self, stock_code: str) -> Optional[float]:
"""
获取股票的WACC数据
Args:
stock_code: 股票代码
Returns:
WACC值如果未找到则返回None
"""
try:
# 查询MongoDB
record = self.wacc_collection.find_one({
'code': stock_code,
'endTime': '20241231'
})
if not record:
logger.warning(f"未找到股票 {stock_code} 的WACC数据")
return None
return record['wacc']
except Exception as e:
logger.error(f"获取WACC数据失败: {str(e)}")
return None
def _calculate_profit_years(self, wind_data_list: List[Dict]) -> int:
"""
计算近五年的盈利年数
Args:
wind_data_list: 按时间排序的wind_data列表
Returns:
盈利年数
"""
try:
profit_years = 0
# 获取最近5年的数据
recent_years = sorted(wind_data_list, key=lambda x: x['time'], reverse=True)[:5]
for year_data in recent_years:
# 获取销售净利率
net_profit_ratio = self._find_indicator(year_data, 'profitability', '净利润/营业总收入')
if net_profit_ratio is not None and net_profit_ratio > 0:
profit_years += 1
return profit_years
except Exception as e:
logger.error(f"计算盈利年数失败: {str(e)}")
return 0
def extract_financial_indicators(self, stock_code: str) -> Dict:
"""
从MongoDB中提取指定的财务指标
Args:
stock_code: 股票代码
Returns:
包含财务指标的字典
"""
try:
# 查询MongoDB
record = self.mongo_collection.find_one({'code': stock_code})
if not record or 'wind_data' not in record:
return {
'success': False,
'message': f'未找到股票 {stock_code} 的财务数据'
}
# 获取最新的财务数据(按时间排序)
wind_data = sorted(record['wind_data'], key=lambda x: x['time'], reverse=True)[0]
# 获取WACC数据
wacc = self.get_wacc_data(stock_code)
# 计算近五年盈利年数
profit_years = self._calculate_profit_years(record['wind_data'])
# 定义指标映射
indicators = {
# 偿债能力指标
'debt_equity_ratio': self._find_indicator(wind_data, 'solvency', '产权比率'),
'debt_ebitda_ratio': self._find_indicator(wind_data, 'solvency', '全部债务/EBITDA'),
'interest_coverage_ratio': self._find_indicator(wind_data, 'solvency', '已获利息倍数(EBIT/利息费用)'),
'current_ratio': self._find_indicator(wind_data, 'solvency', '流动比率'),
'quick_ratio': self._find_indicator(wind_data, 'solvency', '速动比率'),
'cash_ratio': self._find_indicator(wind_data, 'solvency', '现金比率'),
'cash_to_debt_ratio': self._find_indicator(wind_data, 'solvency', '经营活动产生的现金流量净额/负债合计'),
# 资本结构指标
'equity_ratio': self._find_indicator(wind_data, 'capitalStructure', '股东权益比'),
# 盈利能力指标
'profit_years': profit_years, # 添加近五年盈利年数
# 成长能力指标
'diluted_eps_growth': self._find_indicator(wind_data, 'grows', '稀释每股收益(同比增长率)'),
'operating_cash_flow_per_share_growth': self._find_indicator(wind_data, 'grows', '每股经营活动产生的现金流量净额(同比增长率)'),
'revenue_growth': self._find_indicator(wind_data, 'grows', '营业收入(同比增长率)'),
'operating_profit_growth': self._find_indicator(wind_data, 'grows', '营业利润(同比增长率)'),
'net_profit_growth_excl_nonrecurring': self._find_indicator(wind_data, 'grows', '归属母公司股东的净利润-扣除非经常性损益(同比增长率)'),
'operating_cash_flow_growth': self._find_indicator(wind_data, 'grows', '经营活动产生的现金流量净额(同比增长率)'),
'rd_expense_growth': self._find_indicator(wind_data, 'grows', '研发费用同比增长'),
'roe_growth': self._find_indicator(wind_data, 'grows', '净资产收益率(摊薄)(同比增长)'),
# Z值相关指标
'working_capital_to_assets': self._find_indicator(wind_data, 'ZValue', '营运资本/总资产'),
'retained_earnings_to_assets': self._find_indicator(wind_data, 'ZValue', '留存收益/总资产'),
'ebit_to_assets': self._find_indicator(wind_data, 'ZValue', '息税前利润(TTM)/总资产'),
'market_value_to_liabilities': self._find_indicator(wind_data, 'ZValue', '当日总市值/负债总计'),
'equity_to_liabilities': self._find_indicator(wind_data, 'ZValue', '股东权益合计(含少数)/负债总计'),
'revenue_to_assets': self._find_indicator(wind_data, 'ZValue', '营业收入/总资产'),
'z_score': self._find_indicator(wind_data, 'ZValue', 'Z值'),
# 营运能力指标
'inventory_turnover_days': self._find_indicator(wind_data, 'operatingCapacity', '存货周转天数'),
'receivables_turnover_days': self._find_indicator(wind_data, 'operatingCapacity', '应收账款周转天数'),
'payables_turnover_days': self._find_indicator(wind_data, 'operatingCapacity', '应付账款周转天数'),
# 盈利能力指标
'gross_profit_margin': self._find_indicator(wind_data, 'profitability', '销售毛利率'),
'operating_profit_margin': self._find_indicator(wind_data, 'profitability', '营业利润/营业总收入'),
'net_profit_margin': self._find_indicator(wind_data, 'profitability', '销售净利率'),
'roe': self._find_indicator(wind_data, 'profitability', '净资产收益率ROE(平均)'),
'roa': self._find_indicator(wind_data, 'profitability', '总资产净利率ROA'),
'roic': self._find_indicator(wind_data, 'profitability', '投入资本回报率ROIC'),
# WACC数据
'wacc': wacc
}
# 对数值进行四舍五入处理
for key, value in indicators.items():
if isinstance(value, (int, float)) and value is not None:
indicators[key] = round(value, 3)
# 添加数据时间
indicators['data_time'] = wind_data['time']
return {
'success': True,
'stock_code': stock_code,
'indicators': indicators
}
except Exception as e:
logger.error(f"提取财务指标失败: {str(e)}")
return {
'success': False,
'message': f'提取财务指标失败: {str(e)}'
}
def _find_indicator(self, data: Dict, category: str, meaning: str) -> Optional[float]:
"""
在指定类别中查找指标值
Args:
data: 财务数据字典
category: 指标类别
meaning: 指标含义
Returns:
指标值如果未找到则返回None
"""
try:
if category not in data:
return None
for item in data[category]['list']:
if item['meaning'] == meaning:
return item['data']
return None
except Exception as e:
logger.error(f"查找指标 {category}.{meaning} 失败: {str(e)}")
return None
def test_mongo_structure(self, stock_code: str = None) -> Dict:
"""
测试方法查看MongoDB集合的字段结构
Args:
stock_code: 股票代码如果为None则返回第一条记录
Returns:
包含字段结构的字典
"""
try:
# 构建查询条件
query = {}
if stock_code:
query['code'] = stock_code
# 获取一条记录
record = self.mongo_collection.find_one(query)
if not record:
return {
'success': False,
'message': f'未找到股票 {stock_code} 的记录' if stock_code else '集合为空'
}
# 移除MongoDB的_id字段
if '_id' in record:
record.pop('_id')
# 获取所有字段名
fields = list(record.keys())
# 格式化输出
result = {
'success': True,
'fields': fields,
'sample_data': record
}
# 打印字段信息
logger.info(f"集合字段列表: {json.dumps(fields, ensure_ascii=False, indent=2)}")
logger.info(f"示例数据: {json.dumps(record, ensure_ascii=False, indent=2)}")
return result
except Exception as e:
logger.error(f"查询MongoDB结构失败: {str(e)}")
return {
'success': False,
'message': f'查询失败: {str(e)}'
}
def analyze_financial_data(self, stock_code: str) -> Dict:
"""
分析财务数据
Args:
stock_code: 股票代码
Returns:
分析结果字典包含所有财务指标及其排名得分
"""
try:
# 获取股票价格数据
price_collector = StockPriceCollector()
price_data = price_collector.get_stock_price_data(stock_code)
# 获取概念板块数据
industry_analyzer = IndustryAnalyzer()
concepts = industry_analyzer.get_stock_concepts(stock_code)
# 获取基础财务指标
base_result = self.extract_financial_indicators(stock_code)
if not base_result.get('success'):
return base_result
# 获取增长指标变化
growth_result = self.get_growth_change_indicators(stock_code)
if not growth_result.get('success'):
return growth_result
# 获取行业排名
rank_result = self.calculate_industry_rankings(stock_code)
if not rank_result.get('success'):
return rank_result
# 定义指标说明映射
indicator_descriptions = {
# 财务实力指标
'debt_equity_ratio': '债务股本比率',
'debt_ebitda_ratio': '债务/税息折旧及摊销前利润',
'interest_coverage_ratio': '利息保障倍数',
'cash_to_debt_ratio': '现金负债率',
'equity_ratio': '股东权益比',
'wacc': '加权平均资本成本',
'roic': '资本回报率',
# 盈利能力指标
'profit_years': '过去五年盈利年数',
'gross_profit_margin': '毛利率',
'operating_profit_margin': '营业利润率',
'net_profit_margin': '净利率',
'roe': '股本回报率ROE',
'roa': '资产收益率ROA',
# 成长能力指标
'diluted_eps_growth': '稀释每股收益增长率',
'operating_cash_flow_per_share_growth': '每股经营活动产生的现金流量净额增长率',
'revenue_growth': '营业收入增长率',
'operating_profit_growth': '营业利润增长率',
'net_profit_growth_excl_nonrecurring': '扣非净利润增长率',
'operating_cash_flow_growth': '经营活动产生的现金流量净额增长率',
'rd_expense_growth': '研发费用增长率',
'roe_growth': '净资产收益率增长率',
# 价值评级指标
'z_score': 'Z值',
'working_capital_to_assets': '营运资本/总资产',
'retained_earnings_to_assets': '留存收益/总资产',
'ebit_to_assets': '息税前利润/总资产',
'market_value_to_liabilities': '总市值/负债总计',
'equity_to_liabilities': '股东权益/负债总计',
'revenue_to_assets': '营业收入/总资产',
# 流动性指标
'inventory_turnover_days': '存货周转天数',
'receivables_turnover_days': '应收账款周转天数',
'payables_turnover_days': '应付账款周转天数',
'current_ratio': '流动比率',
'quick_ratio': '速动比率',
'cash_ratio': '现金比率'
}
# 构建基础指标数据
base_indicators = base_result.get('indicators', {})
rankings = rank_result.get('rankings', {})
# 定义各板块的指标列表
financial_strength_indicators = [
'debt_equity_ratio', 'debt_ebitda_ratio', 'interest_coverage_ratio',
'cash_to_debt_ratio', 'equity_ratio', 'wacc', 'roic'
]
profitability_indicators = [
'gross_profit_margin', 'operating_profit_margin',
'net_profit_margin', 'roe', 'roa', 'roic', 'profit_years'
]
growth_indicators = [
'diluted_eps_growth', 'operating_cash_flow_per_share_growth',
'revenue_growth', 'operating_profit_growth',
'net_profit_growth_excl_nonrecurring', 'operating_cash_flow_growth',
'rd_expense_growth', 'roe_growth'
]
value_rating_indicators = [
'z_score', 'working_capital_to_assets', 'retained_earnings_to_assets',
'ebit_to_assets', 'market_value_to_liabilities',
'equity_to_liabilities', 'revenue_to_assets'
]
liquidity_indicators = [
'inventory_turnover_days', 'receivables_turnover_days',
'payables_turnover_days', 'current_ratio', 'quick_ratio', 'cash_ratio'
]
# 处理各板块指标
def process_indicators(indicator_list):
result = []
total_score = 0
valid_scores = 0
for key in indicator_list:
if key in base_indicators:
rank_score = rankings.get(key, 0)
if rank_score is not None:
total_score += rank_score
valid_scores += 1
result.append({
'key': key,
'desc': indicator_descriptions.get(key, key),
'value': base_indicators[key],
'rank_score': rank_score
})
# 计算平均得分
avg_score = round(total_score / valid_scores, 2) if valid_scores > 0 else 0
return {
'indicators': result,
'avg_score': avg_score
}
# 处理增长指标变化
growth_changes = growth_result.get('data', {}).get('growth_changes', {})
growth_changes_list = []
total_growth_score = 0
valid_growth_scores = 0
for key in growth_indicators:
if key in growth_changes:
value = growth_changes[key]
if isinstance(value, (int, float)) and value is not None:
value = round(value, 3)
rank_score = rankings.get(f'{key}_change', 0)
if rank_score is not None:
total_growth_score += rank_score
valid_growth_scores += 1
growth_changes_list.append({
'key': f'{key}_change',
'desc': f'{indicator_descriptions.get(key, key)}增长变化率',
'value': value,
'rank_score': rank_score
})
# 计算增长指标的平均得分
growth_avg_score = round(total_growth_score / valid_growth_scores, 2) if valid_growth_scores > 0 else 0
# 构建增长指标数据
growth_data = {
'indicators': growth_changes_list,
'avg_score': growth_avg_score
}
return {
'success': True,
'stock_code': stock_code,
'data_time': base_indicators.get('data_time'),
'financial_strength': process_indicators(financial_strength_indicators),
'profitability': process_indicators(profitability_indicators),
'growth': growth_data,
'value_rating': process_indicators(value_rating_indicators),
'liquidity': process_indicators(liquidity_indicators),
'concepts': concepts, # 添加概念板块数据
'price_data': price_data # 添加实时股价数据
}
except Exception as e:
logger.error(f"分析财务数据失败: {str(e)}")
return {
'success': False,
'message': f'分析财务数据失败: {str(e)}'
}
def __del__(self):
"""析构函数,关闭数据库连接"""
if hasattr(self, 'mongo_client'):
self.mongo_client.close()
logger.info("数据库连接已关闭")
def _convert_stock_code_format(self, stock_code: str) -> str:
"""
转换股票代码格式
Args:
stock_code: 原始股票代码格式如 "603507.SH"
Returns:
转换后的股票代码格式如 "SH603507"
"""
try:
code, market = stock_code.split('.')
return f"{market}{code}"
except Exception as e:
logger.error(f"转换股票代码格式失败: {str(e)}")
return stock_code
def get_industry_stocks(self, stock_code: str) -> List:
"""
获取同行业股票列表
Args:
stock_code: 股票代码格式如 "603507.SH"
Returns:
包含同行业股票列表的字典
"""
try:
# 转换股票代码格式
formatted_code = self._convert_stock_code_format(stock_code)
# 查询行业板块
query = """
SELECT bk_name
FROM gp_hybk
WHERE gp_code = %s
"""
result = pd.read_sql(query, self.mysql_engine, params=(formatted_code,))
if result.empty:
return {
'success': False,
'message': f'未找到股票 {stock_code} 的行业信息'
}
bk_name = result.iloc[0]['bk_name']
# 查询同行业股票
query = """
SELECT gp_code
FROM gp_hybk
WHERE bk_name = %s
"""
stocks = pd.read_sql(query, self.mysql_engine, params=(bk_name,))
# 转换回原始格式
stock_list = []
for code in stocks['gp_code']:
if code.startswith('SH'):
stock_list.append(f"{code[2:]}.SH")
elif code.startswith('SZ'):
stock_list.append(f"{code[2:]}.SZ")
return stock_list
except Exception as e:
logger.error(f"获取同行业股票列表失败: {str(e)}")
return []
def _calculate_industry_rank_score(self, value: float, values: List[float], is_higher_better: bool = True) -> float:
"""
计算行业排名得分
Args:
value: 当前值
values: 行业所有值列表
is_higher_better: 是否越高越好默认为True
Returns:
0-10的得分10表示第一名0表示最后一名
"""
try:
if value is None or not values:
return 0
# 过滤掉None值
valid_values = [v for v in values if v is not None]
if not valid_values:
return 0
# 计算排名
if is_higher_better:
rank = sum(1 for x in valid_values if x > value) + 1
else:
rank = sum(1 for x in valid_values if x < value) + 1
# 计算得分 (10 * (1 - (rank - 1) / (n - 1)))
n = len(valid_values)
if n == 1:
return 10
score = 10 * (1 - (rank - 1) / (n - 1))
return round(score, 2)
except Exception as e:
logger.error(f"计算行业排名得分失败: {str(e)}")
return 0
def _get_industry_indicators(self, stock_list: List[str]) -> Dict[str, List[float]]:
"""
获取行业所有公司的指标值
Args:
stock_list: 股票代码列表
Returns:
包含所有指标值的字典key为指标名value为该指标的所有公司值列表
"""
try:
industry_indicators = {}
# 遍历所有股票获取指标
for stock_code in stock_list:
result = self.extract_financial_indicators(stock_code)
if not result.get('success'):
continue
indicators = result.get('indicators', {})
for key, value in indicators.items():
if key != 'data_time':
if key not in industry_indicators:
industry_indicators[key] = []
industry_indicators[key].append(value)
return industry_indicators
except Exception as e:
logger.error(f"获取行业指标失败: {str(e)}")
return {}
def calculate_industry_rankings(self, stock_code: str) -> Dict:
"""
计算公司在行业中的排名得分
Args:
stock_code: 股票代码
Returns:
包含所有指标排名得分的字典
"""
try:
# 获取同行业股票列表
stock_list = self.get_industry_stocks(stock_code)
if not stock_list:
return {
'success': False,
'message': f'未找到股票 {stock_code} 的同行业公司'
}
# 获取当前公司的指标
current_result = self.extract_financial_indicators(stock_code)
if not current_result.get('success'):
return current_result
# 获取当前公司的增长指标变化
current_growth_result = self.get_growth_change_indicators(stock_code)
if not current_growth_result.get('success'):
return current_growth_result
# 获取行业所有公司的指标
industry_indicators = self._get_industry_indicators(stock_list)
# 获取行业所有公司的增长指标变化
industry_growth_indicators = {}
for stock in stock_list:
growth_result = self.get_growth_change_indicators(stock)
if growth_result.get('success'):
growth_changes = growth_result.get('data', {}).get('growth_changes', {})
for key, value in growth_changes.items():
if key not in industry_growth_indicators:
industry_growth_indicators[key] = []
industry_growth_indicators[key].append(value)
# 定义指标是否越高越好
higher_better_indicators = {
# 偿债能力指标
'current_ratio': True,
'quick_ratio': True,
'cash_ratio': True,
'interest_coverage_ratio': True,
# 资本结构指标
'equity_ratio': True,
# 盈利能力指标
'profit_years': True, # 盈利年数越高越好
# 成长能力指标
'diluted_eps_growth': True,
'operating_cash_flow_per_share_growth': True,
'revenue_growth': True,
'operating_profit_growth': True,
'net_profit_growth_excl_nonrecurring': True,
'operating_cash_flow_growth': True,
'rd_expense_growth': True,
'roe_growth': True,
# Z值相关指标
'working_capital_to_assets': True,
'retained_earnings_to_assets': True,
'ebit_to_assets': True,
'market_value_to_liabilities': True,
'equity_to_liabilities': True,
'revenue_to_assets': True,
'z_score': True,
# 营运能力指标
'inventory_turnover_days': False,
'receivables_turnover_days': False,
'payables_turnover_days': False,
# 盈利能力指标
'gross_profit_margin': True,
'operating_profit_margin': True,
'net_profit_margin': True,
'roe': True,
'roa': True,
'roic': True,
# WACC数据
'wacc': False
}
# 计算每个指标的排名得分
current_indicators = current_result.get('indicators', {})
rankings = {}
# 计算基础指标的排名
for key, value in current_indicators.items():
if key != 'data_time' and key in industry_indicators:
is_higher_better = higher_better_indicators.get(key, True)
score = self._calculate_industry_rank_score(
value,
industry_indicators[key],
is_higher_better
)
rankings[key] = score
# 计算增长指标的排名
current_growth_changes = current_growth_result.get('data', {}).get('growth_changes', {})
for key, value in current_growth_changes.items():
if key in industry_growth_indicators:
is_higher_better = higher_better_indicators.get(key, True)
score = self._calculate_industry_rank_score(
value,
industry_growth_indicators[key],
is_higher_better
)
rankings[f'{key}_change'] = score
return {
'success': True,
'stock_code': stock_code,
'data_time': current_indicators.get('data_time'),
'rankings': rankings
}
except Exception as e:
logger.error(f"计算行业排名失败: {str(e)}")
return {
'success': False,
'message': f'计算行业排名失败: {str(e)}'
}

View File

@ -274,7 +274,7 @@ class IndustryAnalyzer:
logger.info(f"计算行业 {metric} 分位数完成: 当前{metric}={result['current']:.2f}, 百分位={result['percentile']:.2f}%") logger.info(f"计算行业 {metric} 分位数完成: 当前{metric}={result['current']:.2f}, 百分位={result['percentile']:.2f}%")
return result return result
def get_industry_crowding_index(self, industry_name: str, start_date: str = None, end_date: str = None) -> pd.DataFrame: def get_industry_crowding_index(self, industry_name: str, start_date: str = None, end_date: str = None, use_cache: bool = True) -> pd.DataFrame:
""" """
计算行业交易拥挤度指标并使用Redis缓存结果 计算行业交易拥挤度指标并使用Redis缓存结果
@ -285,6 +285,7 @@ class IndustryAnalyzer:
industry_name: 行业名称 industry_name: 行业名称
start_date: 不再使用此参数保留是为了兼容性 start_date: 不再使用此参数保留是为了兼容性
end_date: 结束日期默认为当前日期 end_date: 结束日期默认为当前日期
use_cache: 是否使用缓存默认为True
Returns: Returns:
包含行业拥挤度指标的DataFrame 包含行业拥挤度指标的DataFrame
@ -297,25 +298,26 @@ class IndustryAnalyzer:
end_date = datetime.datetime.now().strftime('%Y-%m-%d') end_date = datetime.datetime.now().strftime('%Y-%m-%d')
# 检查缓存 # 检查缓存
cache_key = f"industry_crowding:{industry_name}" if use_cache:
cached_data = redis_client.get(cache_key) cache_key = f"industry_crowding:{industry_name}"
cached_data = redis_client.get(cache_key)
if cached_data:
try:
# 尝试解析缓存的JSON数据
cached_df_dict = json.loads(cached_data)
logger.info(f"从缓存获取行业 {industry_name} 的拥挤度数据")
# 将缓存的字典转换回DataFrame
df = pd.DataFrame(cached_df_dict)
# 确保trade_date列是日期类型
df['trade_date'] = pd.to_datetime(df['trade_date'])
return df
except Exception as cache_error:
logger.warning(f"解析缓存的拥挤度数据失败,将重新查询: {cache_error}")
if cached_data:
try:
# 尝试解析缓存的JSON数据
cached_df_dict = json.loads(cached_data)
logger.info(f"从缓存获取行业 {industry_name} 的拥挤度数据")
# 将缓存的字典转换回DataFrame
df = pd.DataFrame(cached_df_dict)
# 确保trade_date列是日期类型
df['trade_date'] = pd.to_datetime(df['trade_date'])
return df
except Exception as cache_error:
logger.warning(f"解析缓存的拥挤度数据失败,将重新查询: {cache_error}")
# 获取行业所有股票 # 获取行业所有股票
stock_codes = self.get_industry_stocks(industry_name) stock_codes = self.get_industry_stocks(industry_name)
if not stock_codes: if not stock_codes:
@ -397,15 +399,16 @@ class IndustryAnalyzer:
df_dict = df.to_dict(orient='records') df_dict = df.to_dict(orient='records')
# 缓存结果有效期1天86400秒 # 缓存结果有效期1天86400秒
try: if use_cache:
redis_client.set( try:
cache_key, redis_client.set(
json.dumps(df_dict, default=str), # 使用default=str处理日期等特殊类型 cache_key,
ex=86400 # 1天的秒数 json.dumps(df_dict, default=str), # 使用default=str处理日期等特殊类型
) ex=86400 # 1天的秒数
logger.info(f"已缓存行业 {industry_name} 的拥挤度数据有效期为1天") )
except Exception as cache_error: logger.info(f"已缓存行业 {industry_name} 的拥挤度数据有效期为1天")
logger.warning(f"缓存行业拥挤度数据失败: {cache_error}") except Exception as cache_error:
logger.warning(f"缓存行业拥挤度数据失败: {cache_error}")
logger.info(f"成功计算行业 {industry_name} 的拥挤度指标,共 {len(df)} 条记录") logger.info(f"成功计算行业 {industry_name} 的拥挤度指标,共 {len(df)} 条记录")
return df return df
@ -509,4 +512,54 @@ class IndustryAnalyzer:
except Exception as e: except Exception as e:
logger.error(f"获取行业综合分析失败: {e}") logger.error(f"获取行业综合分析失败: {e}")
return {"success": False, "message": f"获取行业综合分析失败: {e}"} return {"success": False, "message": f"获取行业综合分析失败: {e}"}
def get_stock_concepts(self, stock_code: str) -> List[str]:
"""
获取指定股票所属的概念板块列表
Args:
stock_code: 股票代码
Returns:
概念板块名称列表
"""
try:
# 转换股票代码格式
formatted_code = self._convert_stock_code_format(stock_code)
query = text("""
SELECT DISTINCT bk_name
FROM gp_gnbk
WHERE gp_code = :stock_code
""")
with self.engine.connect() as conn:
result = conn.execute(query, {"stock_code": formatted_code}).fetchall()
if result:
return [row[0] for row in result]
else:
logger.warning(f"未找到股票 {stock_code} 的概念板块数据")
return []
except Exception as e:
logger.error(f"获取股票概念板块失败: {e}")
return []
def _convert_stock_code_format(self, stock_code: str) -> str:
"""
转换股票代码格式
Args:
stock_code: 原始股票代码格式如 "600519.SH"
Returns:
转换后的股票代码格式如 "SH600519"
"""
try:
code, market = stock_code.split('.')
return f"{market}{code}"
except Exception as e:
logger.error(f"转换股票代码格式失败: {str(e)}")
return stock_code

View File

@ -0,0 +1,463 @@
"""
东方财富实时股价数据采集模块
提供从东方财富网站采集实时股价数据并存储到数据库的功能
功能包括
1. 采集实时股价数据
2. 存储数据到数据库
3. 定时自动更新数据
"""
import requests
import pandas as pd
import datetime
import logging
import time
import os
import sys
from pathlib import Path
from sqlalchemy import create_engine, text
from typing import Dict
# 添加项目根目录到Python路径
current_file = Path(__file__)
project_root = current_file.parent.parent.parent
sys.path.append(str(project_root))
from src.valuation_analysis.config import DB_URL, LOG_FILE
# 获取项目根目录
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 确保日志目录存在
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(LOG_FILE),
logging.StreamHandler()
]
)
logger = logging.getLogger("stock_price_collector")
def get_create_table_sql() -> str:
"""
获取创建实时股价数据表的SQL语句
Returns:
创建表的SQL语句
"""
return """
CREATE TABLE IF NOT EXISTS stock_price_data (
stock_code VARCHAR(10) PRIMARY KEY COMMENT '股票代码',
stock_name VARCHAR(50) COMMENT '股票名称',
latest_price DECIMAL(10,2) COMMENT '最新价',
change_percent DECIMAL(10,2) COMMENT '涨跌幅',
change_amount DECIMAL(10,2) COMMENT '涨跌额',
volume BIGINT COMMENT '成交量(手)',
amount DECIMAL(20,2) COMMENT '成交额',
amplitude DECIMAL(10,2) COMMENT '振幅',
turnover_rate DECIMAL(10,2) COMMENT '换手率',
pe_ratio DECIMAL(10,2) COMMENT '市盈率',
high_price DECIMAL(10,2) COMMENT '最高价',
low_price DECIMAL(10,2) COMMENT '最低价',
open_price DECIMAL(10,2) COMMENT '开盘价',
pre_close DECIMAL(10,2) COMMENT '昨收价',
total_market_value DECIMAL(20,2) COMMENT '总市值',
float_market_value DECIMAL(20,2) COMMENT '流通市值',
pb_ratio DECIMAL(10,2) COMMENT '市净率',
list_date DATE COMMENT '上市日期',
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='实时股价数据表';
"""
class StockPriceCollector:
"""东方财富实时股价数据采集器类"""
def __init__(self, db_url: str = DB_URL):
"""
初始化东方财富实时股价数据采集器
Args:
db_url: 数据库连接URL
"""
self.engine = create_engine(
db_url,
pool_size=5,
max_overflow=10,
pool_recycle=3600
)
self.base_url = "https://push2.eastmoney.com/api/qt/clist/get"
self.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Accept": "application/json, text/plain, */*",
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
"Origin": "https://quote.eastmoney.com",
"Referer": "https://quote.eastmoney.com/",
}
logger.info("东方财富实时股价数据采集器初始化完成")
def _ensure_table_exists(self) -> bool:
"""
确保数据表存在如果不存在则创建
Returns:
是否成功确保表存在
"""
try:
create_table_query = text(get_create_table_sql())
with self.engine.connect() as conn:
conn.execute(create_table_query)
conn.commit()
logger.info("实时股价数据表创建成功")
return True
except Exception as e:
logger.error(f"确保数据表存在失败: {e}")
return False
def _convert_stock_code(self, code: str) -> str:
"""
转换股票代码格式
Args:
code: 原始股票代码
Returns:
转换后的股票代码
"""
if code.startswith(('0', '3')):
return f"{code}.SZ"
else:
return f"{code}.SH"
def _parse_list_date(self, date_str: str) -> datetime.date:
"""
解析上市日期
Args:
date_str: 日期字符串
Returns:
日期对象
"""
if not date_str or date_str == '-':
return None
try:
# 如果输入是整数,先转换为字符串
if isinstance(date_str, int):
date_str = str(date_str)
return datetime.datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
logger.warning(f"无法解析日期: {date_str}")
return None
def fetch_data(self, page: int = 1) -> pd.DataFrame:
"""
获取指定页码的实时股价数据
Args:
page: 页码
Returns:
包含实时股价数据的DataFrame
"""
try:
params = {
"np": 1,
"fltt": 2,
"invt": 2,
"fs": "m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048",
"fid": "f12",
"pn": page,
"pz": 100,
"po": 0,
"dect": 1
}
logger.info(f"开始获取第 {page} 页数据")
response = requests.get(self.base_url, params=params, headers=self.headers)
if response.status_code != 200:
logger.error(f"获取第 {page} 页数据失败: HTTP {response.status_code}")
return pd.DataFrame()
data = response.json()
if not data.get("rc") == 0:
logger.error(f"获取数据失败: {data.get('message', '未知错误')}")
return pd.DataFrame()
# 提取数据列表
items = data.get("data", {}).get("diff", [])
if not items:
logger.warning(f"{page} 页未找到有效数据")
return pd.DataFrame()
# 转换为DataFrame
df = pd.DataFrame(items)
# 重命名列
column_mapping = {
"f12": "stock_code",
"f14": "stock_name",
"f2": "latest_price",
"f3": "change_percent",
"f4": "change_amount",
"f5": "volume",
"f6": "amount",
"f7": "amplitude",
"f8": "turnover_rate",
"f9": "pe_ratio",
"f15": "high_price",
"f16": "low_price",
"f17": "open_price",
"f18": "pre_close",
"f20": "total_market_value",
"f21": "float_market_value",
"f23": "pb_ratio",
"f26": "list_date"
}
df = df.rename(columns=column_mapping)
# 转换股票代码格式
df['stock_code'] = df['stock_code'].apply(self._convert_stock_code)
# 转换上市日期
df['list_date'] = df['list_date'].apply(self._parse_list_date)
logger.info(f"{page} 页数据获取成功,包含 {len(df)} 条记录")
return df
except Exception as e:
logger.error(f"获取第 {page} 页数据失败: {e}")
return pd.DataFrame()
def fetch_all_data(self) -> pd.DataFrame:
"""
获取所有页的实时股价数据
Returns:
包含所有实时股价数据的DataFrame
"""
all_data = []
page = 1
while True:
page_data = self.fetch_data(page)
if page_data.empty:
logger.info(f"{page} 页数据为空,停止采集")
break
all_data.append(page_data)
# 如果返回的数据少于100条说明是最后一页
if len(page_data) < 100:
break
page += 1
# 添加延迟,避免请求过于频繁
time.sleep(1)
if all_data:
combined_df = pd.concat(all_data, ignore_index=True)
logger.info(f"数据采集完成,共采集 {len(combined_df)} 条记录")
return combined_df
else:
logger.warning("未获取到任何有效数据")
return pd.DataFrame()
def save_to_database(self, data: pd.DataFrame) -> bool:
"""
将数据保存到数据库
Args:
data: 要保存的数据DataFrame
Returns:
是否成功保存数据
"""
if data.empty:
logger.warning("没有数据需要保存")
return False
try:
# 确保数据表存在
if not self._ensure_table_exists():
return False
data = data.replace('-', None)
# 将nan值转换为None在SQL中会变成NULL
data = data.replace({pd.NA: None, pd.NaT: None})
data = data.where(pd.notnull(data), None)
# 添加数据或更新已有数据
inserted_count = 0
updated_count = 0
with self.engine.connect() as conn:
for _, row in data.iterrows():
# 将Series转换为dict并处理nan值
row_dict = {k: (None if pd.isna(v) else v) for k, v in row.items()}
# 检查该股票的数据是否已存在
check_query = text("""
SELECT COUNT(*) FROM stock_price_data WHERE stock_code = :stock_code
""")
result = conn.execute(check_query, {"stock_code": row_dict['stock_code']}).scalar()
if result > 0: # 数据已存在,执行更新
update_query = text("""
UPDATE stock_price_data SET
stock_name = :stock_name,
latest_price = :latest_price,
change_percent = :change_percent,
change_amount = :change_amount,
volume = :volume,
amount = :amount,
amplitude = :amplitude,
turnover_rate = :turnover_rate,
pe_ratio = :pe_ratio,
high_price = :high_price,
low_price = :low_price,
open_price = :open_price,
pre_close = :pre_close,
total_market_value = :total_market_value,
float_market_value = :float_market_value,
pb_ratio = :pb_ratio,
list_date = :list_date
WHERE stock_code = :stock_code
""")
conn.execute(update_query, row_dict)
updated_count += 1
else: # 数据不存在,执行插入
insert_query = text("""
INSERT INTO stock_price_data (
stock_code, stock_name, latest_price, change_percent,
change_amount, volume, amount, amplitude, turnover_rate,
pe_ratio, high_price, low_price, open_price, pre_close,
total_market_value, float_market_value, pb_ratio, list_date
) VALUES (
:stock_code, :stock_name, :latest_price, :change_percent,
:change_amount, :volume, :amount, :amplitude, :turnover_rate,
:pe_ratio, :high_price, :low_price, :open_price, :pre_close,
:total_market_value, :float_market_value, :pb_ratio, :list_date
)
""")
conn.execute(insert_query, row_dict)
inserted_count += 1
conn.commit()
logger.info(f"数据保存成功:新增 {inserted_count} 条记录,更新 {updated_count} 条记录")
return True
except Exception as e:
logger.error(f"保存数据到数据库失败: {e}")
return False
def update_latest_data(self) -> bool:
"""
更新最新实时股价数据
Returns:
是否成功更新最新数据
"""
try:
logger.info("开始更新最新实时股价数据")
# 获取所有数据
df = self.fetch_all_data()
if df.empty:
logger.warning("未获取到最新数据")
return False
# 保存数据到数据库
result = self.save_to_database(df)
if result:
logger.info(f"最新数据更新成功,共更新 {len(df)} 条记录")
else:
logger.warning("最新数据更新失败")
return result
except Exception as e:
logger.error(f"更新最新数据失败: {e}")
return False
def get_stock_price_data(self, stock_code: str, convert_code: bool = False) -> Dict:
"""
获取指定股票的最新价格数据
Args:
stock_code: 股票代码
convert_code: 是否需要转换股票代码格式默认为False
Returns:
包含股票价格数据的字典
"""
try:
# 转换股票代码格式
formatted_code = self._convert_stock_code(stock_code) if convert_code else stock_code
query = text("""
SELECT
stock_code,
stock_name,
latest_price,
change_percent,
change_amount,
volume,
amount,
amplitude,
turnover_rate,
pe_ratio,
high_price,
low_price,
total_market_value,
float_market_value,
pb_ratio,
list_date,
update_time
FROM
stock_price_data
WHERE
stock_code = :stock_code
""")
with self.engine.connect() as conn:
result = conn.execute(query, {"stock_code": formatted_code}).fetchone()
if result:
# 将结果转换为字典
data = dict(result._mapping)
# 处理日期类型
if data['list_date']:
data['list_date'] = data['list_date'].strftime('%Y-%m-%d')
if data['update_time']:
data['update_time'] = data['update_time'].strftime('%Y-%m-%d %H:%M:%S')
return data
else:
logger.warning(f"未找到股票 {stock_code} 的价格数据")
return None
except Exception as e:
logger.error(f"获取股票价格数据失败: {e}")
return None
# 示例使用方式
if __name__ == "__main__":
# 创建实时股价数据采集器
collector = StockPriceCollector()
# 更新最新数据
logger.info("开始更新最新实时股价数据...")
collector.update_latest_data()