diff --git a/README.md b/README.md index a3cd868..c3bf6f5 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,127 @@ pip install -r requirements.txt 筛选逻辑:选择行业环境稳定、有高质量新业务合作、财务状况良好、券商看好且上涨空间大的企业。 +## Docker 部署说明 + +本系统支持通过 Docker 进行部署,可以单实例部署或多实例部署。 + +### 1. 单实例部署 + +```bash +# 构建镜像 +docker-compose build + +# 启动服务 +docker-compose up -d +``` + +### 2. 多实例部署 + +提供了三个脚本用于管理多实例部署,每个实例将在不同端口上运行(从5088开始递增): + +#### 使用 docker run 部署多实例(推荐) + +```bash +# 赋予脚本执行权限 +chmod +x deploy-multiple.sh + +# 部署3个实例,端口将为5088、5089、5090 +./deploy-multiple.sh 3 +``` + +每个实例将有自己独立的数据目录:`./instances/instance-N/`,包含独立的日志和报告文件。 + +#### 使用 docker-compose 部署多实例 + +```bash +# 赋予脚本执行权限 +chmod +x deploy-compose.sh + +# 部署3个实例,端口将为5088、5089、5090 +./deploy-compose.sh 3 +``` + +每个实例将使用独立的项目名称 `stock-app-N`。 + +#### 管理多实例 + +使用 `manage-instances.sh` 脚本可以方便地管理已部署的实例: + +```bash +# 赋予脚本执行权限 +chmod +x manage-instances.sh + +# 查看所有实例状态 +./manage-instances.sh status + +# 列出正在运行的实例 +./manage-instances.sh list + +# 启动指定实例 +./manage-instances.sh start 2 + +# 启动所有实例 +./manage-instances.sh start all + +# 停止指定实例 +./manage-instances.sh stop 1 + +# 停止所有实例 +./manage-instances.sh stop all + +# 重启指定实例 +./manage-instances.sh restart 3 + +# 查看指定实例的日志 +./manage-instances.sh logs 1 + +# 删除指定实例 +./manage-instances.sh remove 2 + +# 删除所有实例 +./manage-instances.sh remove all +``` + +### 3. 中文字体安装 + +系统需要中文字体以正常生成PDF报告。如果部署时遇到 `找不到中文字体文件` 错误,可以使用以下方法解决: + +#### 自动安装字体(推荐) + +```bash +# 进入容器内部 +docker exec -it [容器ID或名称] bash + +# 在容器内执行字体安装脚本 +cd /app +python src/fundamentals_llm/setup_fonts.py +``` + +#### 手动安装字体 + +1. 在宿主机上下载中文字体文件(如simhei.ttf、wqy-microhei.ttc等) +2. 创建字体目录并复制字体文件: +```bash +mkdir -p src/fundamentals_llm/fonts +cp 你下载的字体文件.ttf src/fundamentals_llm/fonts/simhei.ttf +``` +3. 重新构建镜像: +```bash +docker-compose build +``` + +### 4. 
访问服务 + +部署完成后,可以通过以下URL访问服务: + +- 单实例:`http://服务器IP:5000` +- 多实例:`http://服务器IP:508X`(X为实例序号,从8开始) + +例如: +- 实例1:`http://服务器IP:5088` +- 实例2:`http://服务器IP:5089` +- 实例3:`http://服务器IP:5090` + ## 数据库查询示例 以下SQL查询示例展示了如何从数据库中筛选特定类型的企业: @@ -250,4 +371,12 @@ AND stock_code IN ( -- ... 其他条件 ``` -注意:实际查询时可能需要根据数据库表结构调整SQL语句。建议根据具体需求组合使用不同的筛选条件。 \ No newline at end of file +注意:实际查询时可能需要根据数据库表结构调整SQL语句。建议根据具体需求组合使用不同的筛选条件。 + + + +# 重新构建镜像 +docker-compose build + +# 重启所有实例 +./manage-instances.sh restart all \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8e855f4..3470caa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,5 @@ volcengine-python-sdk[ark] openai>=1.0 reportlab>=4.3.1 markdown2>=2.5.3 -# 1. 按下 Win+R ,输入 regedit 打开注册表编辑器。 -# 2. 设置 \HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem 路径下的变量 LongPathsEnabled 为 1 即可。 \ No newline at end of file +google-genai +redis==5.2.1 \ No newline at end of file diff --git a/src/app.py b/src/app.py index 840d279..c746420 100644 --- a/src/app.py +++ b/src/app.py @@ -1,27 +1,47 @@ import sys import os +from datetime import datetime, timedelta +import pandas as pd +import uuid +import json +from threading import Thread from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db # 添加项目根目录到 Python 路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from flask import Flask, jsonify, request +from flask import Flask, jsonify, request, send_from_directory from flask_cors import CORS import logging # 导入企业筛选器 from src.fundamentals_llm.enterprise_screener import EnterpriseScreener +# 导入股票回测器 +from src.stock_analysis_v2 import run_backtest, StockBacktester + # 设置日志 logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), # 输出到控制台 + 
logging.FileHandler(f'logs/app_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8') # 输出到文件 + ] ) + +# 确保logs和results目录存在 +os.makedirs('logs', exist_ok=True) +os.makedirs('results', exist_ok=True) +os.makedirs('results/tasks', exist_ok=True) +os.makedirs('static/results', exist_ok=True) + logger = logging.getLogger(__name__) +logger.info("Flask应用启动") # 创建 Flask 应用 -app = Flask(__name__) +app = Flask(__name__, static_folder='static') CORS(app) # 启用跨域请求支持 # 创建企业筛选器实例 @@ -30,6 +50,329 @@ screener = EnterpriseScreener() # 获取数据库连接 db = next(get_db()) +# 获取项目根目录 +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports') + +# 确保reports目录存在 +os.makedirs(REPORTS_DIR, exist_ok=True) +logger.info(f"报告目录路径: {REPORTS_DIR}") + +# 存储回测任务状态的字典 +backtest_tasks = {} + +def run_backtest_task(task_id, stocks_buy_dates, end_date): + """ + 在后台运行回测任务 + """ + try: + logger.info(f"开始执行回测任务 {task_id}") + # 更新任务状态为进行中 + backtest_tasks[task_id]['status'] = 'running' + + # 运行回测 + results, stats_list = run_backtest(stocks_buy_dates, end_date) + + # 如果回测成功 + if results and stats_list: + # 计算总体统计 + stats_df = pd.DataFrame(stats_list) + total_profit = stats_df['final_profit'].sum() + avg_win_rate = stats_df['win_rate'].mean() + avg_holding_days = stats_df['avg_holding_days'].mean() + + # 找出最佳止盈比例(假设为0.15,实际应从回测结果中分析得出) + best_take_profit_pct = 0.15 + + # 获取图表URL + chart_urls = { + "all_stocks": f"/static/results/{task_id}/all_stocks_analysis.png", + "profit_matrix": f"/static/results/{task_id}/profit_matrix_analysis.png" + } + + # 保存股票详细统计 + stock_stats = [] + for stat in stats_list: + stock_stats.append({ + "symbol": stat['symbol'], + "total_trades": int(stat['total_trades']), + "profitable_trades": int(stat['profitable_trades']), + "loss_trades": int(stat['loss_trades']), + "win_rate": float(stat['win_rate']), + "avg_holding_days": float(stat['avg_holding_days']), + "final_profit": float(stat['final_profit']), + 
@app.route('/api/backtest/run', methods=['POST'])
def start_backtest():
    """Validate a backtest request and start it on a background thread.

    Request body (JSON)::

        {
            "stocks_buy_dates": {
                "SH600522": ["2022-05-10", "2022-06-10"],
                "SZ002340": ["2022-06-15"]
            },
            "end_date": "2022-10-20"
        }

    Returns:
        ``{"task_id": "backtask-<16 hex chars>"}`` on success. Invalid input
        yields ``{"status": "error", "message": ...}`` with HTTP 400; any
        unexpected failure yields the same shape with HTTP 500.
    """
    def bad_request(message):
        # All validation failures share one response shape and status code.
        return jsonify({"status": "error", "message": message}), 400

    try:
        # silent=True: a missing, non-JSON or malformed body yields None
        # instead of raising, so the client receives the documented 400
        # rather than a 500 from the generic handler below.
        data = request.get_json(silent=True)

        # A JSON array/scalar has no .get(); reject it as a format error too.
        if not data or not isinstance(data, dict):
            return bad_request("请求格式错误: 需要提供JSON数据")

        stocks_buy_dates = data.get('stocks_buy_dates')
        end_date = data.get('end_date')

        if not stocks_buy_dates or not isinstance(stocks_buy_dates, dict):
            return bad_request("请求格式错误: 需要提供stocks_buy_dates字典")

        if not end_date or not isinstance(end_date, str):
            return bad_request("请求格式错误: 需要提供有效的end_date")

        # Every date must be a 'YYYY-MM-DD' string.
        try:
            datetime.strptime(end_date, '%Y-%m-%d')
            for stock_code, buy_dates in stocks_buy_dates.items():
                if not isinstance(buy_dates, list):
                    return bad_request(
                        f"请求格式错误: 股票 {stock_code} 的买入日期必须是列表"
                    )
                for buy_date in buy_dates:
                    datetime.strptime(buy_date, '%Y-%m-%d')
        except (ValueError, TypeError) as e:
            # TypeError covers non-string entries (e.g. numbers or null),
            # which strptime rejects before it ever looks at the format.
            return bad_request(f"日期格式错误: {str(e)}")

        task_id = f"backtask-{uuid.uuid4().hex[:16]}"

        # Per-task directory under static/ (the result chart URLs point here).
        task_dir = os.path.join('static', 'results', task_id)
        os.makedirs(task_dir, exist_ok=True)

        # Register the task before the worker starts so the worker can
        # update its status entry immediately.
        backtest_tasks[task_id] = {
            'status': 'pending',
            'created_at': datetime.now().isoformat(),
            'stocks_buy_dates': stocks_buy_dates,
            'end_date': end_date
        }

        worker = Thread(
            target=run_backtest_task,
            args=(task_id, stocks_buy_dates, end_date),
            daemon=True,
        )
        worker.start()

        logger.info(f"已创建回测任务: {task_id}")

        return jsonify({"task_id": task_id})

    except Exception as e:
        # Top-level boundary: log with traceback and report a generic 500.
        logger.exception(f"创建回测任务失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"创建回测任务失败: {str(e)}"
        }), 500
@app.route('/api/backtest/result', methods=['GET'])
def get_backtest_result():
    """Return the results of a completed backtest task.

    Query parameters:
        task_id: ID of the backtest task.

    The in-memory ``backtest_tasks`` registry is consulted first. When the
    task is not found there, the JSON file ``results/tasks/<task_id>.json``
    is served if it exists.

    Returns:
        ``{"task_id", "status": "completed", "results": {...}}`` on success.
        HTTP 400 when ``task_id`` is missing or the task is not completed,
        404 when the task is unknown, 500 on unexpected failure.
    """
    try:
        task_id = request.args.get('task_id')

        if not task_id:
            return jsonify({
                "status": "error",
                "message": "请求格式错误: 需要提供task_id参数"
            }), 400

        if task_id not in backtest_tasks:
            # Not in memory: fall back to the saved result file.
            tasks_dir = os.path.abspath(os.path.join('results', 'tasks'))
            task_result_path = os.path.abspath(
                os.path.join(tasks_dir, f"{task_id}.json")
            )
            # task_id is client-controlled; refuse anything (e.g. "../x")
            # that resolves outside the tasks directory.
            inside_tasks_dir = os.path.dirname(task_result_path) == tasks_dir

            if inside_tasks_dir and os.path.isfile(task_result_path):
                with open(task_result_path, 'r', encoding='utf-8') as f:
                    result_data = json.load(f)
                return jsonify(result_data)

            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        task_info = backtest_tasks[task_id]

        # Results are only available once the worker has finished.
        if task_info['status'] != 'completed':
            return jsonify({
                "status": "error",
                "message": f"任务尚未完成或已失败: {task_info['status']}"
            }), 400

        return jsonify({
            "task_id": task_id,
            "status": "completed",
            "results": task_info['results']
        })

    except Exception as e:
        # Top-level boundary: log with traceback and report a generic 500.
        logger.exception(f"获取任务结果失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取任务结果失败: {str(e)}"
        }), 500
logger.info(f"从目录: {REPORTS_DIR} 提供文件") + return send_from_directory(REPORTS_DIR, filename, as_attachment=True) + except Exception as e: + logger.error(f"提供报告文件失败: {str(e)}") + return jsonify({"error": "文件不存在"}), 404 + @app.route('/api/health', methods=['GET']) def health_check(): """健康检查接口""" @@ -185,10 +528,10 @@ def generate_reports(): # 导入 PDF 生成器模块 try: - from src.fundamentals_llm.pdf_generator import generate_investment_report + from src.fundamentals_llm.pdf_generator import PDFGenerator except ImportError: try: - from fundamentals_llm.pdf_generator import generate_investment_report + from fundamentals_llm.pdf_generator import PDFGenerator except ImportError as e: logger.error(f"无法导入 PDF 生成器模块: {str(e)}") return jsonify({ @@ -200,8 +543,14 @@ def generate_reports(): generated_reports = [] for stock_code, stock_name in stocks: try: + # 创建 PDF 生成器实例 + generator = PDFGenerator() # 调用 PDF 生成器 - report_path = generate_investment_report(stock_code, stock_name) + report_path = generator.generate_pdf( + title=f"{stock_name}({stock_code}) 基本面分析报告", + content_dict={}, # 这里需要传入实际的内容字典 + filename=f"{stock_name}_{stock_code}_analysis.pdf" + ) generated_reports.append({ "code": stock_code, "name": stock_name, @@ -505,9 +854,10 @@ def analyze_and_recommend(): "message": f"分析和推荐股票失败: {str(e)}" }), 500 + @app.route('/api/comprehensive_analysis', methods=['POST']) def comprehensive_analysis(): - """综合分析接口 - 组合多种功能和参数 + """综合分析接口 - 使用队列方式处理被锁定的股票 请求体格式: { @@ -558,36 +908,9 @@ def comprehensive_analysis(): if not isinstance(limit, int) or limit <= 0: limit = 10 - # 导入必要的聊天机器人模块 + # 导入必要的模块 try: - # 首先尝试导入聊天机器人模块 - try: - from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot - logger.info("成功从 src.fundamentals_llm.chat_bot 导入 ChatBot") - except ImportError as e1: - try: - from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot - logger.info("成功从 fundamentals_llm.chat_bot 导入 ChatBot") - except ImportError as e2: - logger.error(f"无法导入在线聊天机器人模块: {str(e1)}, 
{str(e2)}") - return jsonify({ - "status": "error", - "message": f"服务器配置错误: 聊天机器人模块不可用,错误详情: {str(e2)}" - }), 500 - - # 然后尝试导入离线聊天机器人模块 - try: - from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot - logger.info("成功从 src.fundamentals_llm.chat_bot_with_offline 导入 ChatBot") - except ImportError as e1: - try: - from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot - logger.info("成功从 fundamentals_llm.chat_bot_with_offline 导入 ChatBot") - except ImportError as e2: - logger.warning(f"无法导入离线聊天机器人模块: {str(e1)}, {str(e2)}") - # 这里可以继续执行,因为某些功能可能不需要离线模型 - - # 最后导入基本面分析器 + # 先导入基本面分析器 try: from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer") @@ -601,6 +924,19 @@ def comprehensive_analysis(): "status": "error", "message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}" }), 500 + + # 再导入其他可能需要的模块 + try: + from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot + from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot + except ImportError: + try: + from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot + from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot + except ImportError: + # 这些模块不是必须的,所以继续执行 + logger.warning("无法导入聊天机器人模块,但这不会影响基本功能") + except Exception as e: logger.error(f"导入必要模块时出错: {str(e)}") return jsonify({ @@ -611,79 +947,164 @@ def comprehensive_analysis(): # 创建基本面分析器实例 analyzer = FundamentalAnalyzer() - # 为每个股票生成投资建议 - investment_advices = [] - for stock_code, stock_name in stocks: - try: - # 生成投资建议 - success, advice, reasoning, references = analyzer.query_analysis( - stock_code, stock_name, "investment_advice" - ) - - if success: - investment_advices.append({ - "code": stock_code, - "name": stock_name, - "advice": advice, - "reasoning": reasoning, - "references": references, - "status": "success" - }) - logger.info(f"成功生成 {stock_name}({stock_code}) 
的投资建议") - else: - investment_advices.append({ + # 准备结果容器 + investment_advices = {} # 使用字典,股票代码作为键 + processing_queue = list(stocks) # 初始处理队列 + max_attempts = 5 # 最大重试次数 + total_attempts = 0 + + # 导入数据库模块 + from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db + + # 开始处理队列 + while processing_queue and total_attempts < max_attempts: + total_attempts += 1 + logger.info(f"开始第 {total_attempts} 轮处理,队列中有 {len(processing_queue)} 只股票") + + # 暂存下一轮需要处理的股票 + next_round_queue = [] + + # 处理当前队列中的所有股票 + for stock_code, stock_name in processing_queue: + try: + # 检查是否已有分析结果 + db = next(get_db()) + existing_result = get_analysis_result(db, stock_code, "investment_advice") + + # 如果已有近期结果,直接使用 + if existing_result and existing_result.update_time > datetime.now() - timedelta(hours=12): + investment_advices[stock_code] = { + "code": stock_code, + "name": stock_name, + "advice": existing_result.ai_response, + "reasoning": existing_result.reasoning_process, + "references": existing_result.references, + "status": "success", + "from_cache": True + } + logger.info(f"使用缓存的 {stock_name}({stock_code}) 分析结果") + continue + + # 检查是否被锁定 + if analyzer.is_stock_locked(stock_code, "investment_advice"): + # 已被锁定,放到下一轮队列 + next_round_queue.append([stock_code, stock_name]) + # 记录状态 + if stock_code not in investment_advices: + investment_advices[stock_code] = { + "code": stock_code, + "name": stock_name, + "status": "pending", + "message": f"股票 {stock_code} 正在被其他请求分析中,已加入等待队列" + } + logger.info(f"股票 {stock_name}({stock_code}) 已被锁定,放入下一轮队列") + continue + + # 尝试锁定并分析 + analyzer.lock_stock(stock_code, "investment_advice") + try: + # 执行分析 + success, advice, reasoning, references = analyzer.query_analysis( + stock_code, stock_name, "investment_advice" + ) + + # 记录结果 + if success: + investment_advices[stock_code] = { + "code": stock_code, + "name": stock_name, + "advice": advice, + "reasoning": reasoning, + "references": references, + "status": "success" + } + 
logger.info(f"成功分析 {stock_name}({stock_code})") + else: + investment_advices[stock_code] = { + "code": stock_code, + "name": stock_name, + "status": "error", + "error": advice or "分析失败,无详细信息" + } + logger.error(f"分析 {stock_name}({stock_code}) 失败: {advice}") + finally: + # 确保释放锁 + analyzer.unlock_stock(stock_code, "investment_advice") + + except Exception as e: + # 处理异常 + investment_advices[stock_code] = { "code": stock_code, "name": stock_name, "status": "error", - "error": advice - }) - logger.error(f"生成 {stock_name}({stock_code}) 的投资建议失败: {advice}") - except Exception as e: - logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}") - investment_advices.append({ - "code": stock_code, - "name": stock_name, - "status": "error", - "error": str(e) + "error": str(e) + } + logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}") + # 确保释放锁 + try: + analyzer.unlock_stock(stock_code, "investment_advice") + except: + pass + + # 如果还有下一轮要处理的股票,等待一段时间后继续 + if next_round_queue: + logger.info(f"本轮结束,还有 {len(next_round_queue)} 只股票等待下一轮处理") + # 等待30秒再处理下一轮 + import time + time.sleep(30) + processing_queue = next_round_queue + else: + # 所有股票都已处理完,退出循环 + logger.info("所有股票处理完毕") + processing_queue = [] + + # 处理仍在队列中的股票(达到最大重试次数但仍未处理的) + for stock_code, stock_name in processing_queue: + if stock_code in investment_advices and investment_advices[stock_code]["status"] == "pending": + investment_advices[stock_code].update({ + "status": "timeout", + "message": f"等待超时,股票 {stock_code} 可能正在被长时间分析" }) + logger.warning(f"股票 {stock_name}({stock_code}) 分析超时") + + # 将字典转换为列表 + investment_advice_list = list(investment_advices.values()) # 生成PDF报告(如果需要) pdf_results = [] if generate_pdf: try: - # 导入PDF生成器模块 - try: - from src.fundamentals_llm.pdf_generator import generate_investment_report - except ImportError: - try: - from fundamentals_llm.pdf_generator import generate_investment_report - except ImportError as e: - logger.error(f"无法导入 PDF 生成器模块: {str(e)}") - return jsonify({ - "status": 
"error", - "message": f"服务器配置错误: PDF 生成器模块不可用, 错误详情: {str(e)}" - }), 500 - - # 生成报告 - for stock_code, stock_name in stocks: - try: - # 调用 PDF 生成器 - report_path = generate_investment_report(stock_code, stock_name) - pdf_results.append({ - "code": stock_code, - "name": stock_name, - "report_path": report_path, - "status": "success" - }) - logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}") - except Exception as e: - logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}") - pdf_results.append({ - "code": stock_code, - "name": stock_name, - "status": "error", - "error": str(e) - }) + # 针对已成功分析的股票生成PDF + for stock_info in investment_advice_list: + if stock_info["status"] == "success": + stock_code = stock_info["code"] + stock_name = stock_info["name"] + try: + report_path = analyzer.generate_pdf_report(stock_code, stock_name) + if report_path: + pdf_results.append({ + "code": stock_code, + "name": stock_name, + "report_path": report_path, + "status": "success" + }) + logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}") + else: + pdf_results.append({ + "code": stock_code, + "name": stock_name, + "status": "error", + "error": "生成报告失败" + }) + logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败") + except Exception as e: + logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}") + pdf_results.append({ + "code": stock_code, + "name": stock_name, + "status": "error", + "error": str(e) + }) except Exception as e: logger.error(f"处理PDF生成请求失败: {str(e)}") pdf_results = [] @@ -784,11 +1205,24 @@ def comprehensive_analysis(): logger.error(f"应用企业画像筛选失败: {str(e)}") filtered_stocks = [] + # 统计各种状态的股票数量 + success_count = sum(1 for item in investment_advice_list if item["status"] == "success") + pending_count = sum(1 for item in investment_advice_list if item["status"] == "pending") + timeout_count = sum(1 for item in investment_advice_list if item["status"] == "timeout") + error_count = sum(1 for item in investment_advice_list if 
item["status"] == "error") + # 返回结果 response = { - "status": "success", + "status": "success" if success_count > 0 else "partial_success" if success_count + pending_count > 0 else "failed", "total_input_stocks": len(stocks), - "investment_advices": investment_advices + "stats": { + "success": success_count, + "pending": pending_count, + "timeout": timeout_count, + "error": error_count + }, + "rounds_attempted": total_attempts, + "investment_advices": investment_advice_list } if profile_filter: diff --git a/src/fundamentals_llm/chat_bot.py b/src/fundamentals_llm/chat_bot.py index 5ea60ab..2ced445 100644 --- a/src/fundamentals_llm/chat_bot.py +++ b/src/fundamentals_llm/chat_bot.py @@ -83,37 +83,37 @@ class ChatBot: "role": "system", "content": """你是一个专业的股票分析助手,擅长进行深入的基本面分析。你的分析应该: -1. 专业严谨 -- 使用准确的专业术语 -- 引用可靠的数据来源 -- 分析逻辑清晰 -- 结论有理有据 - -2. 全面细致 -- 深入分析问题的各个方面 -- 关注细节和关键信息 -- 考虑多个影响因素 -- 提供详实的论据支持 - -3. 客观中立 -- 保持独立判断 -- 不夸大或贬低 -- 平衡利弊分析 -- 指出潜在风险 - -4. 实用性强 -- 分析结论具体明确 -- 建议具有可操作性 -- 关注实际投资价值 -- 提供清晰的决策参考 - -5. 及时更新 -- 关注最新信息 -- 反映市场变化 -- 动态调整分析 -- 保持信息时效性 - -请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。""" + 1. 专业严谨 + - 使用准确的专业术语 + - 引用可靠的数据来源 + - 分析逻辑清晰 + - 结论有理有据 + + 2. 全面细致 + - 深入分析问题的各个方面 + - 关注细节和关键信息 + - 考虑多个影响因素 + - 提供详实的论据支持 + + 3. 客观中立 + - 保持独立判断 + - 不夸大或贬低 + - 平衡利弊分析 + - 指出潜在风险 + + 4. 实用性强 + - 分析结论具体明确 + - 建议具有可操作性 + - 关注实际投资价值 + - 提供清晰的决策参考 + + 5. 
及时更新 + - 关注最新信息 + - 反映市场变化 + - 动态调整分析 + - 保持信息时效性 + + 请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。""" } # 对话历史 @@ -150,11 +150,15 @@ class ChatBot: logger.error(f"格式化参考资料时出错: {str(e)}") return str(ref) - def chat(self, user_input: str) -> Dict[str, Any]: + def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]: """与AI进行对话 Args: user_input: 用户输入的问题 + temperature: 控制输出的随机性,范围0-2,默认1.0 + top_p: 控制输出的多样性,范围0-1,默认0.7 + max_tokens: 控制输出的最大长度,默认4096 + frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 Returns: Dict[str, Any]: 包含以下字段的字典: @@ -173,6 +177,10 @@ class ChatBot: stream = self.client.chat.completions.create( model=self.model, messages=self.conversation_history, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + frequency_penalty=frequency_penalty, stream=True ) diff --git a/src/fundamentals_llm/chat_bot_with_gemini.py b/src/fundamentals_llm/chat_bot_with_gemini.py new file mode 100644 index 0000000..039daf4 --- /dev/null +++ b/src/fundamentals_llm/chat_bot_with_gemini.py @@ -0,0 +1,420 @@ +import logging +import os +import sys +import json +from typing import Dict, List, Optional, Any, Union +from pydantic import BaseModel, Field +from google import genai +from google.genai.types import Tool, GenerateContentConfig, GoogleSearch + +# 设置日志记录 +logger = logging.getLogger(__name__) + +# 定义基础数据结构 +class TextAnalysisResult(BaseModel): + """文本分析结果,包含分析正文、推理过程和引用URL""" + analysis_text: str = Field(description="详细的分析文本") + reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None) + references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None) + +class NumericalAnalysisResult(BaseModel): + """数值分析结果,包含数值和分析描述""" + value: str = Field(description="评估值") + description: str = Field(description="评估描述") + +# 分析类型到模型类的映射 +ANALYSIS_TYPE_MODEL_MAP = { + # 文本分析类型 + "text_analysis_result": TextAnalysisResult, + + # 数值分析类型 + 
class JsonOutputParser:
    """Turn an LLM text reply into an instance of a Pydantic model.

    Works with both Pydantic v2 (``model_json_schema`` / ``model_validate``)
    and Pydantic v1 (``schema`` / ``parse_obj``) model classes.
    """

    def __init__(self, pydantic_object):
        """Initialise the parser.

        Args:
            pydantic_object: The Pydantic model class replies are parsed into.
        """
        self.pydantic_object = pydantic_object

    def _schema(self) -> dict:
        """Return the model's JSON schema, preferring the Pydantic v2 API."""
        build = getattr(self.pydantic_object, "model_json_schema", None)
        if build is None:
            build = self.pydantic_object.schema
        return build()

    def _validate(self, json_obj: Any) -> Any:
        """Build a model instance from decoded JSON, preferring the v2 API."""
        validate = getattr(self.pydantic_object, "model_validate", None)
        if validate is None:
            validate = self.pydantic_object.parse_obj
        return validate(json_obj)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return the content between the first and last ``` fence.

        An optional ``json`` language tag after the opening fence is dropped.
        Text without a matched pair of fences is returned stripped, unchanged.
        """
        text = text.strip()
        start = text.find("```")
        end = text.rfind("```")
        if start == -1 or end <= start:
            return text
        inner = text[start + 3:end]
        if inner[:4].lower() == "json":
            inner = inner[4:]
        return inner.strip()

    def get_format_instructions(self) -> str:
        """Return prompt text telling the model to answer with schema-shaped JSON."""
        schema_str = json.dumps(self._schema(), ensure_ascii=False, indent=2)

        return f"""请将你的回答格式化为符合以下JSON结构的格式:
{schema_str}

你的回答应该只包含一个有效的JSON对象,而不包含其他内容。请不要包含任何解释或前缀,仅输出JSON本身。
"""

    def parse(self, text: str) -> Any:
        """Parse *text* into an instance of the configured model.

        The text is first decoded as-is, so JSON whose string values happen
        to contain ``` is left untouched. Only if that fails is a fenced
        block (optionally surrounded by prose) extracted and decoded.

        Raises:
            ValueError: If no valid JSON is found or it does not match the
                model; the underlying error is attached as ``__cause__``.
        """
        try:
            text = text.strip()
            try:
                json_obj = json.loads(text)
            except ValueError:
                json_obj = json.loads(self._strip_code_fence(text))
            return self._validate(json_obj)
        except Exception as e:
            raise ValueError(f"无法解析为JSON或无法匹配模式: {e}") from e
model_type: str = "gemini-2.0-flash"): + """初始化聊天机器人 + + Args: + platform: 平台名称,默认为Gemini + model_type: 要使用的模型类型,默认为gemini-2.0-flash + """ + try: + # 从配置获取API配置 + config = get_model_config(platform, model_type) + self.api_key = config["api_key"] + self.model = config["model"] # 直接使用配置中的模型名称 + + logger.info(f"初始化ChatBot,使用平台: {platform}, 模型: {self.model}") + + google_search_tool = Tool( + google_search=GoogleSearch() + ) + + self.client = genai.Client(api_key=self.api_key) + + # 系统提示语 + self.system_instruction = """你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下: + + 1. 分析风格: + - 专业、客观、理性 + - 注重数据支撑 + - 关注风险控制 + - 重视投资性价比 + + 2. 分析框架: + - 公司基本面分析 + - 行业竞争格局 + - 估值水平评估 + - 风险因素识别 + - 投资机会判断 + + 3. 输出要求: + - 简明扼要,重点突出 + - 逻辑清晰,层次分明 + - 数据准确,论据充分 + - 结论明确,建议具体 + + 请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。""" + + # 对话历史 + self.conversation_history = [] + except Exception as e: + logger.error(f"初始化ChatBot时出错: {str(e)}") + raise + + def chat(self, user_input: str, analysis_type: Optional[str] = None, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096) -> dict: + """处理用户输入并返回AI回复 + + Args: + user_input: 用户输入的问题 + analysis_type: 分析类型,如company_profile或business_scale等 + temperature: 控制输出的随机性,范围0-2,默认0.7 + top_p: 控制输出的多样性,范围0-1,默认0.7 + max_tokens: 控制输出的最大长度,默认4096 + + Returns: + dict: 包含AI回复的字典 + """ + try: + from google.genai import types + google_search_tool = Tool( + google_search=GoogleSearch() + ) + + # 确定分析类型和模型 + model_class = None + if analysis_type and analysis_type in ANALYSIS_TYPE_MODEL_MAP: + model_class = ANALYSIS_TYPE_MODEL_MAP[analysis_type] + parser = JsonOutputParser(pydantic_object=model_class) + format_instructions = parser.get_format_instructions() + + # 在原有基础上添加格式指令 + combined_instruction = f"{self.system_instruction}\n\n{format_instructions}" + else: + combined_instruction = self.system_instruction + parser = None + + response = self.client.models.generate_content_stream( + model=self.model, + config=GenerateContentConfig( + 
system_instruction=combined_instruction, + tools=[google_search_tool], + response_modalities=["TEXT"], + temperature=temperature, + top_p=top_p, + max_output_tokens=max_tokens + ), + contents=user_input + ) + + full_response = "" + for chunk in response: + full_response += chunk.text + print(chunk.text, end="") + + # 如果指定了分析类型,尝试解析输出 + if parser: + try: + # 检查分析类型是文本分析还是数值分析 + is_text_analysis = model_class == TextAnalysisResult + is_numerical_analysis = model_class == NumericalAnalysisResult + + # 使用解析方法尝试处理JSON响应 + json_result = self.parse_json_response(full_response, model_class) + if json_result: + return json_result + + # 如果使用解析方法失败,尝试使用parser解析 + parsed_result = parser.parse(full_response) + structured_result = parsed_result.dict() + + if isinstance(parsed_result, TextAnalysisResult): + # 文本分析结果 + return { + "response": structured_result["analysis_text"], + "reasoning_process": structured_result.get("reasoning_process"), + "references": structured_result.get("references") + } + elif isinstance(parsed_result, NumericalAnalysisResult): + # 数值分析结果 + return { + "response": structured_result["value"], + "description": structured_result.get("description", "") + } + except Exception as e: + logger.warning(f"解析响应失败: {str(e)},返回原始响应") + + # 最后尝试根据分析类型直接构建结果 + if is_text_analysis: + return { + "response": full_response, + "reasoning_process": None, + "references": None + } + elif is_numerical_analysis: + try: + # 尝试从文本中提取数字 + import re + number_match = re.search(r'\b(\d+)\b', full_response) + value = int(number_match.group(1)) if number_match else 0 + return { + "response": value, + "description": full_response + } + except: + return { + "response": 0, + "description": full_response + } + else: + return { + "response": full_response, + "reasoning_process": None, + "references": None + } + else: + # 普通响应 + return { + "response": full_response, + "reasoning_process": None, + "references": None + } + + except Exception as e: + logger.error(f"聊天失败: {str(e)}") + error_msg = 
def parse_json_response(self, full_response: str, model_class) -> Optional[dict]:
    """Extract a structured result from a model reply containing JSON.

    Two layouts are recognised: the expected fields at the top level of the
    JSON object, or nested under a ``"properties"`` key.

    Args:
        full_response: Complete reply text, optionally wrapping the JSON in
            a ``` or ```json fence.
        model_class: ``TextAnalysisResult`` or ``NumericalAnalysisResult``;
            selects which fields are looked for.

    Returns:
        For text analysis: ``{"response", "reasoning_process", "references"}``.
        For numerical analysis: ``{"response": int, "description": str}``.
        ``None`` when the reply cannot be decoded or has neither layout, so
        the caller can try another parsing strategy.
    """
    def extract_json_text(text):
        # Take the content of the first fenced block, preferring ```json.
        for fence in ("```json", "```"):
            if fence in text and "```" in text.split(fence, 1)[1]:
                return text.split(fence, 1)[1].split("```", 1)[0].strip()
        return text

    def to_int(value):
        # Numbers are truncated to int. Numeric strings may carry a sign,
        # whitespace or decimals ("-1", " 85 ", "85.5"); str.isdigit()
        # rejected all of those and silently produced 0.
        if isinstance(value, (int, float)):
            return int(value)
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                try:
                    return int(float(value))
                except (ValueError, OverflowError):
                    return 0
        return 0

    is_text_analysis = model_class == TextAnalysisResult
    is_numerical_analysis = model_class == NumericalAnalysisResult

    try:
        json_obj = json.loads(extract_json_text(full_response))
        if not isinstance(json_obj, dict):
            raise ValueError("无法识别的JSON结构,尝试使用parser解析")

        # Top-level layout first, then the nested "properties" layout.
        candidates = [json_obj]
        properties = json_obj.get("properties")
        if isinstance(properties, dict):
            candidates.append(properties)

        for fields in candidates:
            if is_text_analysis and "analysis_text" in fields:
                return {
                    "response": fields["analysis_text"],
                    "reasoning_process": fields.get("reasoning_process"),
                    "references": fields.get("references", [])
                }
            if is_numerical_analysis and "value" in fields:
                return {
                    "response": to_int(fields["value"]),
                    "description": fields.get("description", "")
                }

        raise ValueError("无法识别的JSON结构,尝试使用parser解析")

    except Exception as e:
        # Deliberately best-effort: log and signal failure with None.
        logger.warning(f"JSON解析失败: {str(e)}")
        return None
"""你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下: - -1. 分析风格: -- 专业、客观、理性 -- 注重数据支撑 -- 关注风险控制 -- 重视投资性价比 - -2. 分析框架: -- 公司基本面分析 -- 行业竞争格局 -- 估值水平评估 -- 风险因素识别 -- 投资机会判断 - -3. 输出要求: -- 简明扼要,重点突出 -- 逻辑清晰,层次分明 -- 数据准确,论据充分 -- 结论明确,建议具体 - -请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。""" + + 1. 分析风格: + - 专业、客观、理性 + - 注重数据支撑 + - 关注风险控制 + - 重视投资性价比 + + 2. 分析框架: + - 公司基本面分析 + - 行业竞争格局 + - 估值水平评估 + - 风险因素识别 + - 投资机会判断 + + 3. 输出要求: + - 简明扼要,重点突出 + - 逻辑清晰,层次分明 + - 数据准确,论据充分 + - 结论明确,建议具体 + + 请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。""" } # 对话历史 @@ -102,8 +102,19 @@ class ChatBot: logger.error(f"初始化ChatBot时出错: {str(e)}") raise - def chat(self, user_input: str) -> str: - """处理用户输入并返回AI回复""" + def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> str: + """处理用户输入并返回AI回复 + + Args: + user_input: 用户输入的问题 + temperature: 控制输出的随机性,范围0-2,默认1.0 + top_p: 控制输出的多样性,范围0-1,默认0.7 + max_tokens: 控制输出的最大长度,默认4096 + frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 + + Returns: + str: AI的回答内容 + """ try: # # 添加用户消息到对话历史-多轮 self.conversation_history.append({ @@ -116,7 +127,10 @@ class ChatBot: stream = self.client.chat.completions.create( model=self.model, messages=self.conversation_history, - temperature=0, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + frequency_penalty=frequency_penalty, stream=True, timeout=600 ) diff --git a/src/fundamentals_llm/enterprise_screener.py b/src/fundamentals_llm/enterprise_screener.py index 5474440..062153d 100644 --- a/src/fundamentals_llm/enterprise_screener.py +++ b/src/fundamentals_llm/enterprise_screener.py @@ -130,9 +130,15 @@ class EnterpriseScreener: { 'dimension': 'investment_advice', 'field': 'investment_advice_type', - 'operator': '!=', - 'value': '不建议' - } + 'operator': 'in', + 'value': ["'短期'","'中期'","'长期'"] + }, + { + 'dimension': 'financial_report', + 'field': 'financial_report_level', + 'operator': '>=', + 'value': -1 + }, ] return 
self._screen_stocks_by_conditions(conditions, limit) diff --git a/src/fundamentals_llm/fundamental_analysis.py b/src/fundamentals_llm/fundamental_analysis.py index 845917b..ce109f7 100644 --- a/src/fundamentals_llm/fundamental_analysis.py +++ b/src/fundamentals_llm/fundamental_analysis.py @@ -1,7 +1,11 @@ import logging import os -from datetime import datetime -from typing import Dict, List, Optional, Tuple, Callable +import sys +from datetime import datetime, timedelta +import time +import redis +from typing import Dict, List, Optional, Tuple, Callable, Any + # 修改导入路径,使用相对导入 try: # 尝试相对导入 @@ -9,6 +13,7 @@ try: from .chat_bot_with_offline import ChatBot as OfflineChatBot from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result from .pdf_generator import PDFGenerator + from .text_processor import TextProcessor except ImportError: # 如果相对导入失败,尝试尝试绝对导入 try: @@ -16,18 +21,17 @@ except ImportError: from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result from src.fundamentals_llm.pdf_generator import PDFGenerator + from src.fundamentals_llm.text_processor import TextProcessor except ImportError: # 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况) from chat_bot import ChatBot from chat_bot_with_offline import ChatBot as OfflineChatBot from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result from pdf_generator import PDFGenerator + from text_processor import TextProcessor import json import re -# 设置日志记录 -logger = logging.getLogger(__name__) - # 获取项目根目录的绝对路径 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -58,6 +62,34 @@ logging.basicConfig( datefmt=date_format ) +logger = logging.getLogger(__name__) +logger.info("测试日志输出 - 程序启动") + +from typing import Dict, List, Optional, Any, Union +from 
pydantic import BaseModel, Field + +# 定义基础数据结构 +class TextAnalysisResult(BaseModel): + """文本分析结果,包含分析正文、推理过程和引用URL""" + analysis_text: str = Field(description="详细的分析文本") + reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None) + references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None) + +class NumericalAnalysisResult(BaseModel): + """数值分析结果,包含数值和分析描述""" + value: str = Field(description="评估值") + description: str = Field(description="评估描述") + +# 添加Redis客户端 +redis_client = redis.Redis( + host='192.168.18.208', # Redis服务器地址,根据实际情况调整 + port=6379, + password='wlkj2018', + db=14, + socket_timeout=5, + decode_responses=True +) + class FundamentalAnalyzer: """基本面分析器""" @@ -66,9 +98,10 @@ class FundamentalAnalyzer: # 使用联网模型进行基本面分析 self.chat_bot = ChatBot(model_type="online_bot") # 使用离线模型进行其他分析 - self.offline_bot = OfflineChatBot(platform="tl_private", model_type="ds-v1") + self.offline_bot = OfflineChatBot(platform="volc", model_type="offline_model") # 千问打杂 - self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") + # self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") + self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="GLM") self.db = next(get_db()) # 定义维度映射 @@ -161,6 +194,8 @@ class FundamentalAnalyzer: - 关键战略决策和转型 请提供专业、客观的分析,突出关键信息,避免冗长描述。""" + #开头清上下文缓存 + self.chat_bot.clear_history() # 获取AI分析结果 result = self.chat_bot.chat(prompt) @@ -554,8 +589,8 @@ class FundamentalAnalyzer: 请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。""" self.offline_bot_tl_qw.clear_history() # 使用离线模型进行分析 - space_value_str = self.offline_bot_tl_qw.chat(prompt) - space_value_str = self._clean_model_output(space_value_str) + space_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) + space_value_str = TextProcessor.clean_thought_process(space_value_str) # 提取数值 space_value = 0 # 默认值 @@ -658,9 +693,9 @@ class FundamentalAnalyzer: 请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" 
self.offline_bot_tl_qw.clear_history() # 使用离线模型进行分析 - events_value_str = self.offline_bot_tl_qw.chat(prompt) + events_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) # 数据清洗 - events_value_str = self._clean_model_output(events_value_str) + events_value_str = TextProcessor.clean_thought_process(events_value_str) # 提取数值 events_value = 0 # 默认值 @@ -700,14 +735,14 @@ class FundamentalAnalyzer: def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool: """分析股吧讨论内容""" try: - prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出: + prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在400字以内(主要讨论话题200字,重要信息汇总200字),请严格按照以下格式输出: - 1. 主要讨论话题(150字左右): + 1. 主要讨论话题: - 近期热点事件 - 投资者关注焦点 - 市场情绪倾向 - 2. 重要信息汇总(150字左右): + 2. 重要信息汇总: - 公司相关动态 - 行业政策变化 - 市场预期变化 @@ -762,10 +797,10 @@ class FundamentalAnalyzer: 请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" self.offline_bot_tl_qw.clear_history() # 使用离线模型进行分析 - emotion_value_str = self.offline_bot_tl_qw.chat(prompt) + emotion_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) # 数据清洗 - emotion_value_str = self._clean_model_output(emotion_value_str) + emotion_value_str = TextProcessor.clean_thought_process(emotion_value_str) # 提取数值 emotion_value = 0 # 默认值 @@ -1015,10 +1050,10 @@ class FundamentalAnalyzer: 只需要输出一个数值,不要输出任何说明或解释。只输出:2,1,0,-1或-2。""" - response = self.chat_bot.chat(prompt) + response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) # 提取数值 - rating_str = self._extract_numeric_value_from_response(response) + rating_str = TextProcessor.extract_numeric_value_from_response(response) # 尝试将响应转换为整数 try: @@ -1047,20 +1082,20 @@ class FundamentalAnalyzer: """ try: prompt = f"""请仔细分析以下目标股价文本,判断上涨空间和下跌空间的关系,并返回对应的数值: -- 如果上涨空间大于下跌空间,返回数值"1" -- 如果上涨空间和下跌空间差不多,返回数值"0" -- 如果下跌空间大于上涨空间,返回数值"-1" -- 如果文本中没有相关信息,返回数值"0" - -目标股价文本: -{price_text} - -只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" + - 如果上涨空间大于下跌空间,返回数值"1" + - 如果上涨空间和下跌空间差不多,返回数值"0" + - 如果下跌空间大于上涨空间,返回数值"-1" + - 
如果文本中没有相关信息,返回数值"0" + + 目标股价文本: + {price_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" - response = self.chat_bot.chat(prompt) + response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) # 提取数值 - odds_str = self._extract_numeric_value_from_response(response) + odds_str = TextProcessor.extract_numeric_value_from_response(response) # 尝试将响应转换为整数 try: @@ -1200,8 +1235,8 @@ class FundamentalAnalyzer: 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" self.offline_bot_tl_qw.clear_history() - response = self.offline_bot_tl_qw.chat(prompt) - pe_hist_str = self._clean_model_output(response) + response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) + pe_hist_str = TextProcessor.clean_thought_process(response) try: pe_hist = int(pe_hist_str) @@ -1240,8 +1275,8 @@ class FundamentalAnalyzer: 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" self.offline_bot_tl_qw.clear_history() - response = self.offline_bot_tl_qw.chat(prompt) - pb_hist_str = self._clean_model_output(response) + response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) + pb_hist_str = TextProcessor.clean_thought_process(response) try: pb_hist = int(pb_hist_str) @@ -1280,8 +1315,8 @@ class FundamentalAnalyzer: 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" self.offline_bot_tl_qw.clear_history() - response = self.offline_bot_tl_qw.chat(prompt) - pe_ind_str = self._clean_model_output(response) + response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) + pe_ind_str = TextProcessor.clean_thought_process(response) try: pe_ind = int(pe_ind_str) @@ -1320,8 +1355,8 @@ class FundamentalAnalyzer: 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" self.offline_bot_tl_qw.clear_history() - response = self.offline_bot_tl_qw.chat(prompt) - pb_ind_str = self._clean_model_output(response) + response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) + pb_ind_str = TextProcessor.clean_thought_process(response) try: pb_ind = int(pb_ind_str) @@ -1373,9 +1408,9 @@ class FundamentalAnalyzer: {json.dumps(all_results, ensure_ascii=False, 
indent=2)}""" self.offline_bot.clear_history() # 使用离线模型生成建议 - result = self.offline_bot.chat(prompt) + result = self.offline_bot.chat(prompt,max_tokens=20000) # 清理模型输出 - result = self._clean_model_output(result) + result = TextProcessor.clean_thought_process(result) # 保存到数据库 success = save_analysis_result( self.db, @@ -1440,94 +1475,7 @@ class FundamentalAnalyzer: except Exception as e: logger.error(f"提取投资建议类型失败: {str(e)}") return None - - def _clean_model_output(self, output: str) -> str: - """清理模型输出,移除推理过程,只保留最终结果 - Args: - output: 模型原始输出文本 - - Returns: - str: 清理后的输出文本 - """ - try: - # 找到标签的位置 - think_end = output.find('') - if think_end != -1: - # 移除标签及其之前的所有内容 - output = output[think_end + len(''):] - - # 处理可能存在的空行 - lines = output.split('\n') - cleaned_lines = [] - for line in lines: - line = line.strip() - if line: # 只保留非空行 - cleaned_lines.append(line) - - # 重新组合文本 - output = '\n'.join(cleaned_lines) - - return output.strip() - - except Exception as e: - logger.error(f"清理模型输出失败: {str(e)}") - return output.strip() - - def _extract_numeric_value_from_response(self, response: str) -> str: - """从模型响应中提取数值,移除参考资料和推理过程 - - Args: - response: 模型原始响应文本或响应对象 - - Returns: - str: 提取的数值字符串 - """ - try: - # 处理响应对象(包含response字段的字典) - if isinstance(response, dict) and "response" in response: - response = response["response"] - - # 确保响应是字符串 - if not isinstance(response, str): - logger.warning(f"响应不是字符串类型: {type(response)}") - return "0" - - # 移除推理过程部分 - reasoning_start = response.find("推理过程:") - if reasoning_start != -1: - response = response[:reasoning_start].strip() - - # 移除参考资料部分(通常以 [数字] 开头的行) - lines = response.split("\n") - cleaned_lines = [] - - for line in lines: - # 跳过参考资料行(通常以 [数字] 开头) - if re.match(r'\[\d+\]', line.strip()): - continue - cleaned_lines.append(line) - - response = "\n".join(cleaned_lines).strip() - - # 提取数值 - # 先尝试直接将整个响应转换为数值 - if response.strip() in ["-2", "-1", "0", "1", "2"]: - return response.strip() - - # 如果整个响应不是数值,尝试匹配第一个数值 - match = 
re.search(r'([-]?[0-9])', response) - if match: - return match.group(1) - - # 如果没有找到数值,返回默认值 - logger.warning(f"未能从响应中提取数值: {response}") - return "0" - - except Exception as e: - logger.error(f"从响应中提取数值失败: {str(e)}") - return "0" - def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]: """尝试多次从投资建议中提取建议类型 @@ -1549,7 +1497,7 @@ class FundamentalAnalyzer: # 使用千问离线模型提取建议类型 - result = self.offline_bot_tl_qw.chat(prompt) + result = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) # 检查是否是错误响应 if isinstance(result, str) and "抱歉,发生错误" in result: @@ -1557,7 +1505,7 @@ class FundamentalAnalyzer: continue # 清理模型输出 - cleaned_result = self._clean_model_output(result) + cleaned_result = TextProcessor.clean_thought_process(result) # 检查结果是否为有效类型 if cleaned_result in valid_types: @@ -1644,6 +1592,24 @@ class FundamentalAnalyzer: Optional[str]: 生成的PDF文件路径,如果失败则返回None """ try: + # 检查是否已存在PDF文件 + reports_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'reports') + os.makedirs(reports_dir, exist_ok=True) + + # 构建可能的文件名格式 + possible_filenames = [ + f"{stock_name}_{stock_code}_analysis.pdf", + f"{stock_name}_{stock_code}.SZ_analysis.pdf", + f"{stock_name}_{stock_code}.SH_analysis.pdf" + ] + + # 检查是否存在已生成的PDF文件 + for filename in possible_filenames: + filepath = os.path.join(reports_dir, filename) + if os.path.exists(filepath): + logger.info(f"找到已存在的PDF报告: {filepath}") + return filepath + # 维度名称映射 dimension_names = { "company_profile": "公司简介", @@ -1669,27 +1635,156 @@ class FundamentalAnalyzer: logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果") return None + # 确保字体目录存在 + fonts_dir = os.path.join(os.path.dirname(__file__), "fonts") + os.makedirs(fonts_dir, exist_ok=True) + + # 检查是否存在字体文件,如果不存在则创建简单的默认字体标记文件 + font_path = os.path.join(fonts_dir, "simhei.ttf") + if not os.path.exists(font_path): + # 尝试从系统字体目录复制 + try: + import shutil + # 尝试常见的系统字体位置 + system_fonts = [ + "C:/Windows/Fonts/simhei.ttf", # Windows + 
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # Linux + "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", # 其他Linux + "/System/Library/Fonts/PingFang.ttc" # macOS + ] + + for system_font in system_fonts: + if os.path.exists(system_font): + shutil.copy2(system_font, font_path) + logger.info(f"已复制字体文件: {system_font} -> {font_path}") + break + except Exception as font_error: + logger.warning(f"复制字体文件失败: {str(font_error)}") + # 创建PDF生成器实例 generator = PDFGenerator() # 生成PDF报告 - filepath = generator.generate_pdf( - title=f"{stock_name}({stock_code}) 基本面分析报告", - content_dict=content_dict, - filename=f"{stock_name}_{stock_code}_analysis.pdf" - ) + try: + # 第一次尝试生成 + filepath = generator.generate_pdf( + title=f"{stock_name}({stock_code}) 基本面分析报告", + content_dict=content_dict, + output_dir=reports_dir, + filename=f"{stock_name}_{stock_code}_analysis.pdf" + ) + + if filepath: + logger.info(f"PDF报告已生成: {filepath}") + return filepath + except Exception as pdf_error: + logger.error(f"生成PDF报告第一次尝试失败: {str(pdf_error)}") + + # 如果是字体问题,可能需要使用备选方案 + if "找不到中文字体文件" in str(pdf_error): + # 导出为文本文件作为备选 + try: + txt_filename = f"{stock_name}_{stock_code}_analysis.txt" + txt_filepath = os.path.join(reports_dir, txt_filename) + + with open(txt_filepath, 'w', encoding='utf-8') as f: + f.write(f"{stock_name}({stock_code}) 基本面分析报告\n") + f.write(f"生成时间:{datetime.now().strftime('%Y年%m月%d日 %H:%M:%S')}\n\n") + + for section_title, content in content_dict.items(): + if content: + f.write(f"## {section_title}\n\n") + f.write(f"{content}\n\n") + + logger.info(f"由于PDF生成失败,已生成文本报告: {txt_filepath}") + return txt_filepath + except Exception as txt_error: + logger.error(f"生成文本报告失败: {str(txt_error)}") - if filepath: - logger.info(f"PDF报告已生成: {filepath}") - else: - logger.error("PDF报告生成失败") - - return filepath + logger.error("PDF报告生成失败") + return None except Exception as e: logger.error(f"生成PDF报告失败: {str(e)}") return None + def is_stock_locked(self, stock_code: str, dimension: str) -> bool: + 
"""检查股票是否已被锁定 + + Args: + stock_code: 股票代码 + dimension: 分析维度 + + Returns: + bool: 是否被锁定 + """ + try: + lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" + + # 检查是否存在锁 + existing_lock = redis_client.get(lock_key) + if existing_lock: + lock_time = int(existing_lock) + current_time = int(time.time()) + + # 锁超过30分钟(1800秒)视为过期 + if current_time - lock_time > 1800: + # 锁已过期,可以释放 + redis_client.delete(lock_key) + logger.info(f"股票 {stock_code} 维度 {dimension} 的过期锁已释放") + return False + else: + # 锁未过期,被锁定 + return True + + # 不存在锁 + return False + except Exception as e: + logger.error(f"检查股票 {stock_code} 锁状态时出错: {str(e)}") + # 出错时保守处理,返回未锁定 + return False + + def lock_stock(self, stock_code: str, dimension: str) -> bool: + """锁定股票 + + Args: + stock_code: 股票代码 + dimension: 分析维度 + + Returns: + bool: 是否成功锁定 + """ + try: + lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" + current_time = int(time.time()) + + # 设置锁,过期时间1小时 + redis_client.set(lock_key, current_time, ex=3600) + logger.info(f"股票 {stock_code} 维度 {dimension} 已锁定") + return True + except Exception as e: + logger.error(f"锁定股票 {stock_code} 时出错: {str(e)}") + return False + + def unlock_stock(self, stock_code: str, dimension: str) -> bool: + """解锁股票 + + Args: + stock_code: 股票代码 + dimension: 分析维度 + + Returns: + bool: 是否成功解锁 + """ + try: + lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" + redis_client.delete(lock_key) + logger.info(f"股票 {stock_code} 维度 {dimension} 已解锁") + return True + except Exception as e: + logger.error(f"解锁股票 {stock_code} 时出错: {str(e)}") + return False + def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool: """测试单个分析方法""" try: @@ -1713,12 +1808,7 @@ def test_single_stock(analyzer: FundamentalAnalyzer, stock_code: str, stock_name def main(): """主函数""" - # 设置日志级别 - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) - + # 测试股票列表 test_stocks = [ ("603690", "至纯科技"), diff --git 
a/src/fundamentals_llm/fundamental_analysis_gemini.py b/src/fundamentals_llm/fundamental_analysis_gemini.py new file mode 100644 index 0000000..a5fb766 --- /dev/null +++ b/src/fundamentals_llm/fundamental_analysis_gemini.py @@ -0,0 +1,1704 @@ +import logging +import os +import sys +from datetime import datetime +from typing import Dict, List, Optional, Tuple, Callable, Any + +from google.genai._common import BaseModel + +# 修改导入路径,使用相对导入 +try: + # 尝试相对导入 + from .chat_bot import ChatBot + from .chat_bot_with_offline import ChatBot as OfflineChatBot + from .chat_bot_with_gemini import ChatBot as GeminiChatBot + from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result + from .pdf_generator import PDFGenerator + from .text_processor import TextProcessor +except ImportError: + # 如果相对导入失败,尝试尝试绝对导入 + try: + from src.fundamentals_llm.chat_bot import ChatBot + from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot + from src.fundamentals_llm.chat_bot_with_gemini import ChatBot as GeminiChatBot + from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result + from src.fundamentals_llm.pdf_generator import PDFGenerator + from src.fundamentals_llm.text_processor import TextProcessor + except ImportError: + # 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况) + from chat_bot import ChatBot + from chat_bot_with_offline import ChatBot as OfflineChatBot + from chat_bot_with_gemini import ChatBot as GeminiChatBot + from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result + from pdf_generator import PDFGenerator + from text_processor import TextProcessor +import json +import re + +# 获取项目根目录的绝对路径 +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +# 创建logs目录(如果不存在) +LOGS_DIR = os.path.join(ROOT_DIR, "logs") +os.makedirs(LOGS_DIR, exist_ok=True) + +# 配置日志格式 
+log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +date_format = '%Y-%m-%d %H:%M:%S' + +# 创建文件处理器 +log_file = os.path.join(LOGS_DIR, f"fundamental_analysis_{datetime.now().strftime('%Y%m%d')}.log") +file_handler = logging.FileHandler(log_file, encoding='utf-8') +file_handler.setLevel(logging.INFO) +file_handler.setFormatter(logging.Formatter(log_format, date_format)) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +console_handler.setFormatter(logging.Formatter(log_format, date_format)) + +# 配置根日志记录器 +logging.basicConfig( + level=logging.INFO, + handlers=[file_handler, console_handler], + format=log_format, + datefmt=date_format +) + +logger = logging.getLogger(__name__) +logger.info("测试日志输出 - 程序启动") + +class FundamentalAnalyzer: + """基本面分析器""" + + def __init__(self): + """初始化分析器""" + # 使用联网模型进行基本面分析 + self.chat_bot = ChatBot(model_type="online_bot") + # 使用离线模型进行其他分析 + self.offline_bot = OfflineChatBot(platform="tl_private", model_type="ds_v1") + # 谷歌模型测试 + self.gemini_bot = GeminiChatBot(platform="Gemini", model_type="offline_model") + # 千问打杂 + # self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") + self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") + self.db = next(get_db()) + + # 定义维度映射 + self.dimension_methods = { + "company_profile": self.analyze_company_profile, + "management_ownership": self.analyze_management_ownership, + "financial_report": self.analyze_financial_report, + "industry_competition": self.analyze_industry_competition, + "recent_projects": self.analyze_recent_projects, + "stock_discussion": self.analyze_stock_discussion, + "industry_cooperation": self.analyze_industry_cooperation, + "target_price": self.analyze_target_price, + "valuation_level": self.analyze_valuation_level, + "investment_advice": self.generate_investment_advice + } + + class Recipe(BaseModel): + recipe_name: str + ingredients: list[str] + + def 
query_analysis(self, stock_code: str, stock_name: str, dimension: str) -> Tuple[bool, str, Optional[str], Optional[list]]: + """查询分析结果,如果不存在则生成新的分析 + + Args: + stock_code: 股票代码 + stock_name: 股票名称 + dimension: 分析维度 + + Returns: + Tuple[bool, str, Optional[str], Optional[list]]: + - 是否成功 + - 分析结果 + - 推理过程(如果有) + - 参考资料(如果有) + """ + try: + # 检查维度是否有效 + if dimension not in self.dimension_methods: + return False, f"无效的分析维度: {dimension}", None, None + + # 查询数据库 + result = get_analysis_result(self.db, stock_code, dimension) + + if result: + # 如果存在结果,直接返回 + logger.info(f"从数据库获取到 {stock_name}({stock_code}) 的 {dimension} 分析结果") + return True, result.ai_response, result.reasoning_process, result.references + + # 如果不存在,生成新的分析 + logger.info(f"数据库中未找到 {stock_name}({stock_code}) 的 {dimension} 分析结果,开始生成") + success = self.dimension_methods[dimension](stock_code, stock_name) + + if success: + # 重新查询数据库获取结果 + result = get_analysis_result(self.db, stock_code, dimension) + if result: + return True, result.ai_response, result.reasoning_process, result.references + + return False, f"生成 {dimension} 分析失败", None, None + + except Exception as e: + logger.error(f"查询分析结果失败: {str(e)}") + return False, f"查询失败: {str(e)}", None, None + + def _remove_references_from_response(self, response: str) -> str: + """从响应中移除参考资料部分 + + Args: + response: 原始响应文本 + + Returns: + str: 移除参考资料后的响应文本 + """ + # # 查找"参考资料:"的位置 + # ref_start = response.find("参考资料:") + # if ref_start != -1: + # # 如果找到参考资料,只保留前面的部分,并移除末尾的换行符 + # return response[:ref_start].rstrip() + return response.strip() + + def analyze_company_profile(self, stock_code: str, stock_name: str) -> bool: + """分析公司简介""" + try: + # 构建提示词 + prompt = f"""请对{stock_name}({stock_code})进行公司简介分析,内容包括: + + 1. 主营业务:详细说明公司的核心业务领域、产品或服务,包括技术特点和产业链布局,以及企业的核心竞争优势。 + + 2. 
发展历程:概述公司的成立背景、重要发展节点、关键战略决策和转型。 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。确保分析内容充分体现公司特点和行业地位。""" + + # 调用Gemini模型进行分析 + result = self.gemini_bot.chat( + prompt, + analysis_type="text_analysis_result" + ) + + # 保存到数据库 + return save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="company_profile", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + except Exception as e: + logger.error(f"分析公司简介失败: {str(e)}") + return False + + def analyze_management_ownership(self, stock_code: str, stock_name: str) -> bool: + """分析实控人和管理层持股情况""" + try: + prompt = f"""请对{stock_name}({stock_code})的实控人和管理层持股情况进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出: + +1. 实控人情况: +- 实控人姓名 +- 行业地位(如有) +- 持股比例(冻结/解禁) + +2. 管理层持股: +- 主要管理层持股比例 +- 近3年增减持情况 + +请提供专业、客观的分析,突出关键信息,避免冗长描述。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt, analysis_type="text_analysis_result") + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="management_ownership", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取实控人和管理层信息并更新结果 + if success: + self.extract_management_info(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析实控人和管理层持股情况失败: {str(e)}") + return False + + def extract_management_info(self, ownership_text: str, stock_code: str, stock_name: str) -> Dict[str, int]: + """从实控人和管理层持股分析中提取持股情况和能力评价并更新数据库 + + Args: + ownership_text: 完整的实控人和管理层持股分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + Dict[str, int]: 包含shareholding和ability的字典 + """ + try: + # 提取持股情况 + shareholding = self._extract_shareholding_status(ownership_text, stock_code, stock_name) + + # 提取领军人物和处罚情况 + ability = 
self._extract_management_ability(ownership_text, stock_code, stock_name) + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "management_ownership") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="management_ownership", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "shareholding": shareholding, + "ability": ability + } + ) + + logger.info(f"已更新实控人和管理层信息到数据库: shareholding={shareholding}, ability={ability}") + + return {"shareholding": shareholding, "ability": ability} + + except Exception as e: + logger.error(f"提取实控人和管理层信息失败: {str(e)}") + return {"shareholding": 0, "ability": 0} + + def _extract_shareholding_status(self, ownership_text: str, stock_code: str, stock_name: str) -> int: + """从实控人和管理层持股分析中提取持股减持情况 + + Args: + ownership_text: 完整的实控人和管理层持股分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 减持情况评价 (1:减持低, 0:减持适中, -1:减持高) + """ + try: + # 使用在线模型分析减持情况 + prompt = f"""请对{stock_name}({stock_code})的大股东和管理层近三年减持情况进行专业分析,并返回对应的数值评级: + +- 如果大股东或管理层近三年减持比例很低或次数很低(减持比例低于2%,减持次数少于5,减持后持股比例仍然高于5%),返回数值"1" +- 如果大股东减持比例适中(减持比例在2%-5%之间,减持次数不超过10次,减持后持股比例仍然高于5%),返回数值"0" +- 如果近三年减持次数较多或比例较高,且大股东持股比例已经减持到5%以下,返回数值"-1" + +持股分析内容: +{ownership_text} + +请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" + + # 使用在线模型进行分析 + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response.get("value") + + except Exception as e: + logger.error(f"提取持股情况失败: {str(e)}") + return 0 + + def _extract_management_ability(self, ownership_text: str, stock_code: str, stock_name: str) -> int: + """从实控人和管理层持股分析中提取管理层能力评价 + + Args: + ownership_text: 完整的实控人和管理层持股分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 管理层能力评价 (2:行业顶尖, 1:能力较强, 0:能力一般, -1:能力较弱) + """ + try: + # 使用在线模型分析管理层能力 + prompt = f"""请对{stock_name}({stock_code})的管理层能力进行专业分析,并返回对应的数值评级: + +- 如果公司管理层在行业内处于顶尖水平(如具有行业公认的领导力、创新能力,过往业绩优秀),返回数值"2" +- 
如果管理层能力较强(有较好的领导力、执行力,业绩表现良好),返回数值"1" +- 如果管理层能力一般(无特别突出表现,业绩中规中矩),返回数值"0" +- 如果管理层能力较弱(存在明显问题,如决策失误、执行不力等),返回数值"-1" + +管理层分析内容: +{ownership_text} + +请分析管理层的能力,包括业绩表现、战略规划、执行力等方面,并给出客观评价。""" + + # 使用在线模型进行分析 + result = self.gemini_bot.chat( + prompt, + analysis_type="management_ability", + temperature=0.7, + top_p=0.7, + max_tokens=4096 + ) + + # 如果返回的是整数或能转换为整数的值 + if isinstance(result["response"], int): + ability_value = result["response"] + else: + try: + ability_value = int(result["response"]) + except ValueError: + # 如果无法转换为整数,使用默认值 + logger.warning(f"管理层能力评价值无法转换为整数: {result['response']},使用默认值0") + ability_value = 0 + + # 确保值在有效范围内 + if ability_value < -1 or ability_value > 2: + logger.warning(f"管理层能力评价值超出范围: {ability_value},调整到有效范围") + ability_value = max(-1, min(ability_value, 2)) + + # 更新数据库中的管理层能力评价 + self.db.update_numerical_analysis( + stock_code=stock_code, + dimension="management_ownership", + field="ability", + value=ability_value + ) + + logger.info(f"提取{stock_name}管理层能力评价: {ability_value}") + return ability_value + + except Exception as e: + logger.error(f"提取管理层能力评价失败: {str(e)}") + return 0 + + def analyze_financial_report(self, stock_code: str, stock_name: str) -> bool: + """分析企业财报情况""" + try: + prompt = f"""请对{stock_name}({stock_code})的财报情况进行简要分析,严格要求最新财报情况200字以内,最新业绩预告情况100字以内,近三年变化趋势150字以内,请严格按照以下格式输出: + + 1. 最新财报情况 + - 营业收入及同比变化 + - 主要成本构成及变化 + - 净利润及同比变化 + - 毛利率和净利率变化 + - 其他重要财务指标(如ROE、资产负债率等) + + 2. 最新业绩预告情况(没有就不要提) + - 预告类型(预增/预减/扭亏/续亏) + - 预计业绩区间 + - 变动原因 + + 3. 
近三年变化趋势 + - 收入增长趋势 + - 利润变化趋势 + - 盈利能力变化 + - 经营质量变化 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="financial_report", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取财报水平并更新结果 + if success: + self.extract_financial_report_level(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析财报情况失败: {str(e)}") + return False + + def extract_financial_report_level(self, report_text: str, stock_code: str, stock_name: str) -> int: + """从财报分析中提取财报水平评级并更新数据库 + + Args: + report_text: 完整的财报分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 财报水平评级 (2:边际向好无风险, 1:稳定风险小, 0:稳定有隐患, -1:波动大有隐患, -2:波动大隐患大) + """ + try: + # 使用在线模型分析财报水平 + prompt = f"""请对{stock_name}({stock_code})的财报水平进行专业分析,并返回对应的数值评级: + + - 如果财报水平边际向好,最新财报没有任何风险,返回数值"2" + - 如果边际变化波动不高较为稳定,并且风险很小,返回数值"1" + - 如果波动不高较为稳定,但其中存在一定财报隐患,返回数值"0" + - 如果财报波动较大(亏损或者盈利),并且存在一定财报隐患,返回数值"-1" + - 如果财报波动较大(亏损或者盈利),并且存在较大财报隐患,返回数值"-2" + + 财报分析内容: + {report_text} + + 请仅返回一个数值:2、1、0、-1或-2,不要包含任何解释或说明。""" + + # 使用在线模型进行分析 + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + report_level = response["value"] + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "financial_report") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="financial_report", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "financial_report_level": report_level + } + ) + + logger.info(f"已更新财报水平评级到数据库: financial_report_level={report_level}") + + return report_level + + except Exception as e: + logger.error(f"提取财报水平评级失败: {str(e)}") + return 0 + + def 
analyze_industry_competition(self, stock_code: str, stock_name: str) -> bool: + """分析行业发展趋势和竞争格局""" + try: + prompt = f"""请对{stock_name}({stock_code})所在行业的发展趋势和竞争格局进行简要分析,要求输出控制在400字以内,请严格按照以下格式输出: + + 1. 市场需求: + - 主要下游应用领域 + - 需求增长驱动因素 + - 市场规模和增速 + + 2. 竞争格局: + - 主要竞争对手及特点 + - 行业集中度 + - 竞争壁垒 + + 3. 行业环境: + - 行业平均利润率 + - 政策环境影响 + - 技术发展趋势 + - 市场阶段结论:新兴市场、成熟市场、衰退市场 + + 4. 小结: + - 简要说明当前市场是否有利于企业经营 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="industry_competition", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取行业发展空间并更新结果 + if success: + self.extract_industry_space(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析行业发展趋势和竞争格局失败: {str(e)}") + return False + + def extract_industry_space(self, industry_text: str, stock_code: str, stock_name: str) -> int: + """从行业发展趋势和竞争格局中提取行业发展空间并更新数据库 + + Args: + industry_text: 完整的行业发展趋势和竞争格局文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 行业发展空间值 (2:高速增长, 1:稳定经营, 0:不确定性大, -1:不利经营) + """ + try: + # 使用离线模型分析行业发展空间 + prompt = f"""请分析以下{stock_name}({stock_code})的行业发展趋势和竞争格局文本,评估当前市场环境、阶段和竞争格局对企业未来的影响,并返回对应的数值: + - 如果当前市场环境、阶段和竞争格局符合未来企业高速增长,返回数值"2" + - 如果当前市场环境、阶段和竞争格局符合未来企业稳定经营,返回数值"1" + - 如果当前市场环境、阶段和竞争格局存在较大不确定性,返回数值"0" + - 如果当前市场环境、阶段和竞争格局不利于企业正常经营,返回数值"-1" + + 行业发展趋势和竞争格局文本: + {industry_text} + + 请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。""" + # 使用离线模型进行分析 + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + space_value = response["value"] + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "industry_competition") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="industry_competition", + 
ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "industry_space": space_value + } + ) + + logger.info(f"已更新行业发展空间到数据库: industry_space={space_value}") + + return space_value + + except Exception as e: + logger.error(f"提取行业发展空间失败: {str(e)}") + return 0 + + def analyze_recent_projects(self, stock_code: str, stock_name: str) -> bool: + """分析近期重大订单和项目进展""" + try: + prompt = f"""请对{stock_name}({stock_code})的近期重大订单和项目进展进行简要分析,要求输出控制在500字以内,请严格按照以下格式输出: + +1. 主要业务领域进展(300字左右): +- 各业务领域的重要订单情况 +- 项目投产和建设进展 +- 产能扩张计划 +- 技术突破和产品创新 + +2. 海外布局(如有)(200字左右): +- 海外生产基地建设 +- 国际合作项目 +- 市场拓展计划 + +请提供专业、客观的分析,突出关键信息,避免冗长描述。如果企业近期没有重大订单或项目进展,请直接说明。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="recent_projects", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取重大事件评价并更新结果 + if success: + self.extract_major_events(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析近期重大订单和项目进展失败: {str(e)}") + return False + + def extract_major_events(self, projects_text: str, stock_code: str, stock_name: str) -> int: + """从重大订单和项目进展中提取进展情况并更新数据库 + + Args: + projects_text: 完整的重大订单和项目进展文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 项目进展评价 (1:超预期, 0:顺利但未超预期, -1:不顺利或在验证中) + """ + try: + # 使用离线模型分析项目进展情况 + prompt = f"""请分析以下{stock_name}({stock_code})的重大订单和项目进展情况,并返回对应的数值: + - 如果项目进展顺利,且订单交付/建厂等超预期,返回数值"1" + - 如果进展顺利,但没有超预期,返回数值"0" + - 如果进展不顺利,或者按照进度进行但仍然在验证中,返回数值"-1" + + 重大订单和项目进展内容: + {projects_text} + + 请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" + # 使用离线模型进行分析 + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + events_value = response["value"] + + # 更新数据库中的记录 + result = 
get_analysis_result(self.db, stock_code, "recent_projects") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="recent_projects", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "major_events": events_value + } + ) + + logger.info(f"已更新重大项目进展评价到数据库: major_events={events_value}") + + return events_value + + except Exception as e: + logger.error(f"提取重大项目进展评价失败: {str(e)}") + return 0 + + def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool: + """分析股吧讨论内容""" + try: + prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在400字以内(主要讨论话题200字,重要信息汇总200字),请严格按照以下格式输出: + + 1. 主要讨论话题: + - 近期热点事件 + - 投资者关注焦点 + - 市场情绪倾向 + + 2. 重要信息汇总: + - 公司相关动态 + - 行业政策变化 + - 市场预期变化 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。重点关注投资者普遍关注的话题和重要市场信息。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="stock_discussion", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取股吧情绪并更新结果 + if success: + self.extract_stock_discussion_emotion(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析股吧讨论内容失败: {str(e)}") + return False + + def extract_stock_discussion_emotion(self, discussion_text: str, stock_code: str, stock_name: str) -> int: + """从股吧讨论内容中提取市场情绪并更新数据库 + + Args: + discussion_text: 完整的股吧讨论内容分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 市场情绪值 (1:乐观, 0:中性, -1:悲观) + """ + try: + # 使用离线模型分析市场情绪 + prompt = f"""请分析以下{stock_name}({stock_code})的股吧讨论内容分析,判断整体市场情绪倾向,并返回对应的数值: + - 如果股吧讨论情绪偏乐观,返回数值"1" + - 如果股吧讨论情绪偏中性,返回数值"0" + - 如果股吧讨论情绪偏悲观,返回数值"-1" + + 股吧讨论内容分析: + {discussion_text} + + 请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" + # 使用离线模型进行分析 + response = 
self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + emotion_value = response["value"] + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "stock_discussion") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="stock_discussion", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "emotion": emotion_value + } + ) + + logger.info(f"已更新股吧情绪值到数据库: emotion={emotion_value}") + + return emotion_value + + except Exception as e: + logger.error(f"提取股吧情绪值失败: {str(e)}") + return 0 + + def analyze_industry_cooperation(self, stock_code: str, stock_name: str) -> bool: + """分析产业链上下游合作动态""" + try: + prompt = f"""请对{stock_name}({stock_code})最近半年内的产业链上下游合作动态进行简要分析,要求输出控制在400字以内,请严格按照以下格式输出: + + 1. 重要客户合作(200字左右): + - 主要客户合作进展 + - 产品供应情况 + - 合作深度和规模 + + 2. 产业链布局(200字左右): + - 上下游合作动态 + - 新业务领域拓展 + - 战略合作项目 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。重点关注最近半年内的合作动态,如果没有相关动态,请直接说明。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="industry_cooperation", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取产业链合作动态质量并更新结果 + if success: + self.extract_collaboration_dynamics(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析产业链上下游合作动态失败: {str(e)}") + return False + + def extract_collaboration_dynamics(self, cooperation_text: str, stock_code: str, stock_name: str) -> int: + """从产业链上下游合作动态中提取合作动态质量评级并更新数据库 + + Args: + cooperation_text: 完整的产业链上下游合作动态文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + int: 合作动态质量值 (2:质量高, 1:一般, 0:无/质量低, -1:负面) + """ + try: + # 使用在线模型分析合作动态质量 + prompt = f"""请评估{stock_name}({stock_code})的产业链上下游合作动态质量,并返回相应数值: + + - 
如果企业近期有较多且质量高的新合作动态(具备新业务拓展能力,以及可以体现在近一年财报中),返回数值"2" + - 如果企业半年内合作动态频率低或质量一般(在原有业务上合作关系的衍生,对财报影响一般),返回数值"1" + - 如果企业没有合作动态或质量低,返回数值"0" + - 如果企业有负面合作关系(解除合作,或业务被其他厂商瓜分),返回数值"-1" + + 以下是合作动态相关信息: + {cooperation_text} + + 请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。""" + + # 使用在线模型进行分析 + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + dynamics_value = response["value"] + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "industry_cooperation") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="industry_cooperation", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "collaboration_dynamics": dynamics_value + } + ) + + logger.info(f"已更新产业链合作动态质量到数据库: collaboration_dynamics={dynamics_value}") + + return dynamics_value + + except Exception as e: + logger.error(f"提取产业链合作动态质量失败: {str(e)}") + return 0 + + def analyze_target_price(self, stock_code: str, stock_name: str) -> bool: + """分析券商和研究机构目标股价""" + try: + prompt = f"""请对{stock_name}({stock_code})的券商和研究机构目标股价进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出: + + 1. 目标股价情况(150字左右): + - 半年内所有券商/研究机构的目标价 + - 半年内评级情况(买入/增持/中性/减持/卖出) + + 2. 
当前最新股价对比(150字左右): + - 当前最新股价与目标价对比 + - 上涨/下跌空间 + - 评级建议 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。如果没有券商或研究机构的目标价,请直接说明。""" + # 获取联网AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="target_price", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取券商评级和上涨/下跌空间并更新结果 + if success: + self.extract_target_price_info(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析目标股价失败: {str(e)}") + return False + + def extract_target_price_info(self, price_text: str, stock_code: str, stock_name: str) -> Dict[str, int]: + """从目标股价分析中提取券商评级和上涨/下跌空间并更新数据库 + + Args: + price_text: 完整的目标股价分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + Dict[str, int]: 包含securities_rating和odds的字典 + """ + try: + # 提取券商评级和上涨/下跌空间 + securities_rating = self._extract_securities_rating(price_text) + odds = self._extract_odds(price_text) + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "target_price") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="target_price", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "securities_rating": securities_rating, + "odds": odds + } + ) + + logger.info(f"已更新目标股价信息到数据库: securities_rating={securities_rating}, odds={odds}") + + return {"securities_rating": securities_rating, "odds": odds} + + except Exception as e: + logger.error(f"提取目标股价信息失败: {str(e)}") + return {"securities_rating": 0, "odds": 0} + + def _extract_securities_rating(self, price_text: str) -> int: + """从目标股价分析中提取券商评级 + + Args: + price_text: 完整的目标股价分析文本 + + Returns: + int: 券商评级值 (买入:2, 增持:1, 中性:0, 减持:-1, 卖出:-2, 无评级:0) + """ + try: + # 使用离线模型提取评级 + prompt = 
f"""请仔细分析以下目标股价文本,判断半年内券商评级的主要倾向,并返回对应的数值: + - 如果半年内券商评级以"买入"居多,返回数值"2" + - 如果半年内券商评级以"增持"居多,返回数值"1" + - 如果半年内券商评级以"中性"居多,返回数值"0" + - 如果半年内券商评级以"减持"居多,返回数值"-1" + - 如果半年内券商评级以"卖出"居多,返回数值"-2" + - 如果半年内没有券商评级信息,返回数值"0" + + 目标股价文本: + {price_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:2,1,0,-1或-2。""" + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response["value"] + + except Exception as e: + logger.error(f"提取券商评级失败: {str(e)}") + return 0 + + def _extract_odds(self, price_text: str) -> int: + """从目标最新股价分析中提取上涨/下跌空间 + + Args: + price_text: 完整的目标股价分析文本 + + Returns: + int: 上涨/下跌空间值 (上涨空间大:1, 差不多:0, 下跌空间大:-1) + """ + try: + prompt = f"""请仔细分析以下目标股价文本,判断上涨空间和下跌空间的关系,并返回对应的数值: + - 如果上涨空间大于下跌空间,返回数值"1" + - 如果上涨空间和下跌空间差不多,返回数值"0" + - 如果下跌空间大于上涨空间,返回数值"-1" + - 如果文本中没有相关信息,返回数值"0" + + 目标股价文本: + {price_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response["value"] + + except Exception as e: + logger.error(f"提取上涨/下跌空间失败: {str(e)}") + return 0 + + def analyze_valuation_level(self, stock_code: str, stock_name: str) -> bool: + """分析企业PE和PB在历史分位水平和行业平均水平的对比情况""" + try: + prompt = f"""请对{stock_name}({stock_code})的估值水平进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出: + +1. 历史估值水平(150字左右): +- 当前PE值及其在历史分位水平的位置(高于/接近/低于历史平均分位) +- 当前PB值及其在历史分位水平的位置(高于/接近/低于历史平均分位) +- 历史估值变化趋势简要分析 + +2. 
行业估值对比(150字左右): +- 当前PE值与行业平均水平的比较(高于/接近/低于行业平均) +- 当前PB值与行业平均水平的比较(高于/接近/低于行业平均) +- 与可比公司估值的简要对比 + +请提供专业、客观的分析,突出关键信息,避免冗长描述。如果无法获取某项数据,请直接说明。""" + # 获取AI分析结果 + result = self.gemini_bot.chat(prompt) + + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="valuation_level", + ai_response=self._remove_references_from_response(result["response"]), + reasoning_process=result["reasoning_process"], + references=result["references"] + ) + + # 提取估值水平分类并更新结果 + if success: + self.extract_valuation_classification(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"分析估值水平失败: {str(e)}") + return False + + def extract_valuation_classification(self, valuation_text: str, stock_code: str, stock_name: str) -> Dict[str, int]: + """从估值水平分析中提取历史和行业估值分类并更新数据库 + + Args: + valuation_text: 完整的估值水平分析文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + Dict[str, int]: 包含四个分类值的字典: + - pe_historical: PE历史分位分类 (-1:高于历史, 0:接近历史, 1:低于历史) + - pb_historical: PB历史分位分类 (-1:高于历史, 0:接近历史, 1:低于历史) + - pe_industry: PE行业对比分类 (-1:高于行业, 0:接近行业, 1:低于行业) + - pb_industry: PB行业对比分类 (-1:高于行业, 0:接近行业, 1:低于行业) + """ + try: + # 直接提取四个分类值 + pe_historical = self._extract_pe_historical(valuation_text) + pb_historical = self._extract_pb_historical(valuation_text) + pe_industry = self._extract_pe_industry(valuation_text) + pb_industry = self._extract_pb_industry(valuation_text) + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "valuation_level") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="valuation_level", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={ + "pe_historical": pe_historical, + "pb_historical": pb_historical, + "pe_industry": pe_industry, + "pb_industry": pb_industry + } + ) + + logger.info(f"已更新估值分类到数据库: pe_historical={pe_historical}, 
pb_historical={pb_historical}, pe_industry={pe_industry}, pb_industry={pb_industry}") + + return { + "pe_historical": pe_historical, + "pb_historical": pb_historical, + "pe_industry": pe_industry, + "pb_industry": pb_industry + } + + except Exception as e: + logger.error(f"提取估值分类失败: {str(e)}") + return { + "pe_historical": 0, + "pb_historical": 0, + "pe_industry": 0, + "pb_industry": 0 + } + + def _extract_pe_historical(self, valuation_text: str) -> int: + """从估值水平分析中提取PE历史分位分类 + + Args: + valuation_text: 完整的估值水平分析文本 + + Returns: + int: PE历史分位分类值 (-1:高于历史, 0:接近历史, 1:低于历史) + """ + try: + prompt = f"""请仔细分析以下估值水平文本,判断当前PE在历史分位的位置,并返回对应的数值: + - 如果当前PE为负数,返回数值"-1" + - 如果当前PE明显高于历史平均水平,返回数值"-1" + - 如果当前PE接近历史平均水平,返回数值"0" + - 如果当前PE明显低于历史平均水平,返回数值"1" + - 如果文本中没有相关信息,返回数值"0" + + 估值水平文本: + {valuation_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response["value"] + + except Exception as e: + logger.error(f"提取PE历史分位分类失败: {str(e)}") + return 0 + + def _extract_pb_historical(self, valuation_text: str) -> int: + """从估值水平分析中提取PB历史分位分类 + + Args: + valuation_text: 完整的估值水平分析文本 + + Returns: + int: PB历史分位分类值 (-1:高于历史, 0:接近历史, 1:低于历史) + """ + try: + prompt = f"""请仔细分析以下估值水平文本,判断当前PB在历史分位的位置,并返回对应的数值: + - 如果当前PB为负数,返回数值"-1" + - 如果当前PB明显高于历史平均水平,返回数值"-1" + - 如果当前PB接近历史平均水平,返回数值"0" + - 如果当前PB明显低于历史平均水平,返回数值"1" + - 如果文本中没有相关信息,返回数值"0" + + 估值水平文本: + {valuation_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response["value"] + + except Exception as e: + logger.error(f"提取PB历史分位分类失败: {str(e)}") + return 0 + + def _extract_pe_industry(self, valuation_text: str) -> int: + """从估值水平分析中提取PE行业对比分类 + + Args: + valuation_text: 完整的估值水平分析文本 + + Returns: + int: PE行业对比分类值 (-1:高于行业, 0:接近行业, 1:低于行业) + """ + try: + prompt = f"""请仔细分析以下估值水平文本,判断当前PE与行业平均水平的对比情况,并返回对应的数值: + - 如果当前PE为负数,返回数值"-1" + - 
如果当前PE明显高于行业平均水平,返回数值"-1" + - 如果当前PE接近行业平均水平,返回数值"0" + - 如果当前PE明显低于行业平均水平,返回数值"1" + - 如果文本中没有相关信息,返回数值"0" + + 估值水平文本: + {valuation_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response["value"] + + except Exception as e: + logger.error(f"提取PE行业对比分类失败: {str(e)}") + return 0 + + def _extract_pb_industry(self, valuation_text: str) -> int: + """从估值水平分析中提取PB行业对比分类 + + Args: + valuation_text: 完整的估值水平分析文本 + + Returns: + int: PB行业对比分类值 (-1:高于行业, 0:接近行业, 1:低于行业) + """ + try: + prompt = f"""请仔细分析以下估值水平文本,判断当前PB与行业平均水平的对比情况,并返回对应的数值: + - 如果当前PB为负数,返回数值"-1" + - 如果当前PB明显高于行业平均水平,返回数值"-1" + - 如果当前PB接近行业平均水平,返回数值"0" + - 如果当前PB明显低于行业平均水平,返回数值"1" + - 如果文本中没有相关信息,返回数值"0" + + 估值水平文本: + {valuation_text} + + 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + return response["value"] + except Exception as e: + logger.error(f"提取PB行业对比分类失败: {str(e)}") + return 0 + + def generate_investment_advice(self, stock_code: str, stock_name: str) -> bool: + """生成最终投资建议""" + try: + # 收集所有维度的分析结果(排除investment_advice) + all_results = {} + analysis_dimensions = [dim for dim in self.dimension_methods.keys() if dim != "investment_advice"] + + for dimension in analysis_dimensions: + # 查询数据库 + result = get_analysis_result(self.db, stock_code, dimension) + if not result: + # 如果数据库中没有结果,生成新的分析 + self.dimension_methods[dimension](stock_code, stock_name) + result = get_analysis_result(self.db, stock_code, dimension) + if result: + all_results[dimension] = result.ai_response + + # 构建提示词 + prompt = f"""请根据以下{stock_name}({stock_code})的各个维度分析结果,生成最终的投资建议,要求输出控制在300字以内,请严格按照以下格式输出: + + 投资建议:请从以下几个方面进行总结: + 1. 业绩表现和增长预期 + 2. 当前估值水平和市场预期 + 3. 行业竞争环境和风险 + 4. 
投资建议和理由(请根据以下标准明确给出投资建议): + - 短期持有:近期(1-3个月内)有明确利好因素、催化事件或阶段性业绩改善 + - 中期持有:短期无明确利好,但中期(3-12个月)业绩面临向上拐点或行业处于上升周期 + - 长期持有:公司具备长期稳定的盈利能力、行业地位稳固、长期成长性好 + - 不建议投资:存在明显风险因素、基本面恶化、估值过高、行业前景不佳或者存在退市风险 + + 请提供专业、客观的分析,突出关键信息,避免冗长描述。重点关注投资价值和风险。在输出投资建议时,请明确指出是短期持有、中期持有、长期持有还是不建议投资。 + + 各维度分析结果: + {json.dumps(all_results, ensure_ascii=False, indent=2)}""" + # 使用离线模型生成建议 + result = self.gemini_bot.chat(prompt) + # 保存到数据库 + success = save_analysis_result( + self.db, + stock_code=stock_code, + stock_name=stock_name, + dimension="investment_advice", + ai_response=result["response"], + reasoning_process=None, + references=None + ) + + # 提取投资建议类型并更新结果 + if success: + self.offline_bot.clear_history() + investment_type = self.extract_investment_advice_type(result["response"], stock_code, stock_name) + return True + + return success + + except Exception as e: + logger.error(f"生成投资建议失败: {str(e)}") + return False + + def extract_investment_advice_type(self, advice_text: str, stock_code: str, stock_name: str) -> str: + """从投资建议中提取建议类型并更新数据库 + + Args: + advice_text: 完整的投资建议文本 + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + str: 提取的投资建议类型(短期、中期、长期、不建议)或 None + """ + try: + valid_types = ["短期", "中期", "长期", "不建议"] + max_attempts = 3 + + # 调用辅助函数尝试多次提取 + found_type = self._try_extract_advice_type(advice_text, max_attempts) + + # 更新数据库中的记录 + result = get_analysis_result(self.db, stock_code, "investment_advice") + if result: + update_analysis_result( + self.db, + stock_code=stock_code, + dimension="investment_advice", + ai_response=result.ai_response, + reasoning_process=result.reasoning_process, + references=result.references, + extra_info={"investment_advice_type": found_type} + ) + + if found_type: + logger.info(f"已更新投资建议类型到数据库: {found_type}") + else: + logger.info("已将投资建议类型更新为 null") + + return found_type + + except Exception as e: + logger.error(f"提取投资建议类型失败: {str(e)}") + return None + + def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]: +
"""尝试多次从投资建议中提取建议类型 + + Args: + advice_text: 完整的投资建议文本 + max_attempts: 最大尝试次数 + + Returns: + Optional[str]: 提取的投资建议类型,如果所有尝试都失败则返回None + """ + valid_types = ["短期", "中期", "长期", "不建议"] + + for attempt in range(1, max_attempts + 1): + try: + logger.info(f"尝试提取投资建议类型 (尝试 {attempt}/{max_attempts})") + + # 根据尝试次数获取不同的提示词 + prompt = self._get_extract_prompt_by_attempt(advice_text, attempt) + + # 使用千问离线模型提取建议类型 + + response = self.gemini_bot.chat(prompt, analysis_type="numerical_analysis_result") + + cleaned_result = response["value"] + + + # 检查结果是否为有效类型 + if cleaned_result in valid_types: + logger.info(f"成功提取投资建议类型: {cleaned_result}(尝试 {attempt}/{max_attempts})") + return cleaned_result + else: + logger.warning(f"未能提取有效的投资建议类型(尝试 {attempt}/{max_attempts}),获取到: '{cleaned_result}'") + + except Exception as e: + logger.error(f"提取投资建议类型失败(尝试 {attempt}/{max_attempts}): {str(e)}") + + # 所有尝试都失败 + logger.warning(f"经过 {max_attempts} 次尝试后未能提取有效的投资建议类型,设置为 null") + return None + + def _get_extract_prompt_by_attempt(self, advice_text: str, attempt: int) -> str: + """根据尝试次数获取不同的提取提示词 + + Args: + advice_text: 完整的投资建议文本 + attempt: 当前尝试次数 + + Returns: + str: 提取提示词 + """ + if attempt == 1: + return f"""请从以下投资建议中提取明确的投资建议类型。仅输出以下四种类型之一: +- 短期:如果建议短期持有 +- 中期:如果建议中期持有 +- 长期:如果建议长期持有 +- 不建议:如果不建议投资 + +投资建议文本: +{advice_text} + +只需要输出一个词,不要输出其他任何内容。""" + elif attempt == 2: + return f"""请仔细分析以下投资建议文本,并严格按照要求输出结果。 +请仅输出以下四个词之一:短期、中期、长期、不建议。 +不要输出其他任何内容,不要加任何解释。 + +投资建议文本: +{advice_text} + +如果建议短期持有,输出"短期" +如果建议中期持有,输出"中期" +如果建议长期持有,输出"长期" +如果不建议投资,输出"不建议" + +请再次确认你只输出一个词,没有任何额外内容。""" + else: + return f"""请判断以下投资建议文本最符合哪种情况,只输出对应的一个词: + +短期:近期(1-3个月内)有明确利好因素、催化事件或阶段性业绩改善 +中期:无明确短期利好,但中期(3-12个月)业绩面临向上拐点或行业处于上升周期 +长期:公司具备长期稳定的盈利能力、行业地位稳固、长期成长性好 +不建议:存在风险因素、基本面恶化、估值过高或行业前景不佳 + +投资建议文本: +{advice_text} + +只输出:短期、中期、长期、不建议中的一个,不要有任何其他内容。""" + + def analyze_all_dimensions(self, stock_code: str, stock_name: str) -> Dict[str, bool]: + """分析所有维度""" + results = {} + + # 逐个分析每个维度 + for dimension, method in 
self.dimension_methods.items(): + logger.info(f"开始分析 {stock_name}({stock_code}) 的 {dimension}") + results[dimension] = method(stock_code, stock_name) + logger.info(f"{dimension} 分析完成: {'成功' if results[dimension] else '失败'}") + + return results + + def generate_pdf_report(self, stock_code: str, stock_name: str) -> Optional[str]: + """生成PDF分析报告 + + Args: + stock_code: 股票代码 + stock_name: 股票名称 + + Returns: + Optional[str]: 生成的PDF文件路径,如果失败则返回None + """ + try: + # 检查是否已存在PDF文件 + reports_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'reports') + os.makedirs(reports_dir, exist_ok=True) + + # 构建可能的文件名格式 + possible_filenames = [ + f"{stock_name}_{stock_code}_analysis.pdf", + f"{stock_name}_{stock_code}.SZ_analysis.pdf", + f"{stock_name}_{stock_code}.SH_analysis.pdf" + ] + + # 检查是否存在已生成的PDF文件 + for filename in possible_filenames: + filepath = os.path.join(reports_dir, filename) + if os.path.exists(filepath): + logger.info(f"找到已存在的PDF报告: {filepath}") + return filepath + + # 维度名称映射 + dimension_names = { + "company_profile": "公司简介", + "management_ownership": "管理层持股", + "financial_report": "财务报告", + "industry_competition": "行业竞争", + "recent_projects": "近期项目", + "stock_discussion": "市场讨论", + "industry_cooperation": "产业合作", + "target_price": "目标股价", + "valuation_level": "估值水平", + "investment_advice": "投资建议" + } + + # 收集所有可用的分析结果 + content_dict = {} + for dimension in self.dimension_methods.keys(): + result = get_analysis_result(self.db, stock_code, dimension) + if result and result.ai_response: + content_dict[dimension_names[dimension]] = result.ai_response + + if not content_dict: + logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果") + return None + + # 确保字体目录存在 + fonts_dir = os.path.join(os.path.dirname(__file__), "fonts") + os.makedirs(fonts_dir, exist_ok=True) + + # 检查是否存在字体文件,如果不存在则创建简单的默认字体标记文件 + font_path = os.path.join(fonts_dir, "simhei.ttf") + if not os.path.exists(font_path): + # 尝试从系统字体目录复制 + try: + import shutil + # 尝试常见的系统字体位置 + system_fonts = [ 
+ "C:/Windows/Fonts/simhei.ttf", # Windows + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # Linux + "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", # 其他Linux + "/System/Library/Fonts/PingFang.ttc" # macOS + ] + + for system_font in system_fonts: + if os.path.exists(system_font): + shutil.copy2(system_font, font_path) + logger.info(f"已复制字体文件: {system_font} -> {font_path}") + break + except Exception as font_error: + logger.warning(f"复制字体文件失败: {str(font_error)}") + + # 创建PDF生成器实例 + generator = PDFGenerator() + + # 生成PDF报告 + try: + # 第一次尝试生成 + filepath = generator.generate_pdf( + title=f"{stock_name}({stock_code}) 基本面分析报告", + content_dict=content_dict, + output_dir=reports_dir, + filename=f"{stock_name}_{stock_code}_analysis.pdf" + ) + + if filepath: + logger.info(f"PDF报告已生成: {filepath}") + return filepath + except Exception as pdf_error: + logger.error(f"生成PDF报告第一次尝试失败: {str(pdf_error)}") + + # 如果是字体问题,可能需要使用备选方案 + if "找不到中文字体文件" in str(pdf_error): + # 导出为文本文件作为备选 + try: + txt_filename = f"{stock_name}_{stock_code}_analysis.txt" + txt_filepath = os.path.join(reports_dir, txt_filename) + + with open(txt_filepath, 'w', encoding='utf-8') as f: + f.write(f"{stock_name}({stock_code}) 基本面分析报告\n") + f.write(f"生成时间:{datetime.now().strftime('%Y年%m月%d日 %H:%M:%S')}\n\n") + + for section_title, content in content_dict.items(): + if content: + f.write(f"## {section_title}\n\n") + f.write(f"{content}\n\n") + + logger.info(f"由于PDF生成失败,已生成文本报告: {txt_filepath}") + return txt_filepath + except Exception as txt_error: + logger.error(f"生成文本报告失败: {str(txt_error)}") + + logger.error("PDF报告生成失败") + return None + + except Exception as e: + logger.error(f"生成PDF报告失败: {str(e)}") + return None + +def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool: + """测试单个分析方法""" + try: + analyzer = FundamentalAnalyzer() + logger.info(f"开始测试 {method.__name__} 方法") + logger.info(f"测试股票: {stock_name}({stock_code})") + result = method(stock_code, stock_name) + 
logger.info(f"测试结果: {'成功' if result else '失败'}") + return result + except Exception as e: + logger.error(f"测试失败: {str(e)}") + return False + +def test_single_stock(analyzer: FundamentalAnalyzer, stock_code: str, stock_name: str) -> Dict[str, bool]: + """测试单个股票的所有维度""" + logger.info(f"开始测试股票: {stock_name}({stock_code})") + results = analyzer.analyze_all_dimensions(stock_code, stock_name) + success_count = sum(1 for r in results.values() if r) + logger.info(f"测试完成: 成功维度数 {success_count}/{len(results)}") + return results + +def main(): + """主函数""" + + # 测试股票列表 + test_stocks = [ + ("603690", "至纯科技"), + ("300767", "震安科技"), + ("300750", "宁德时代") + ] + + # 创建分析器实例 + analyzer = FundamentalAnalyzer() + + # 测试选项 + print("\n请选择测试模式:") + print("1. 查询单个维度分析") + print("2. 测试单个方法") + print("3. 测试单个股票") + print("4. 测试所有股票") + print("5. 生成PDF报告") + print("6. 生成投资建议并生成PDF") + print("7. 退出") + + choice = input("\n请输入选项(1-7): ").strip() + + if choice == "1": + # 查询单个维度分析 + print("\n可用的分析维度:") + for i, dimension in enumerate(analyzer.dimension_methods.keys(), 1): + print(f"{i}. {dimension}") + + dimension_choice = input("\n请选择要查询的维度(1-8): ").strip() + if dimension_choice.isdigit() and 1 <= int(dimension_choice) <= len(analyzer.dimension_methods): + dimension = list(analyzer.dimension_methods.keys())[int(dimension_choice) - 1] + + print("\n可用的股票:") + for i, (code, name) in enumerate(test_stocks, 1): + print(f"{i}. 
{name}({code})") + + stock_choice = input("\n请选择要查询的股票(1-3): ").strip() + if stock_choice.isdigit() and 1 <= int(stock_choice) <= len(test_stocks): + stock_code, stock_name = test_stocks[int(stock_choice) - 1] + + # 查询分析结果 + success, response, reasoning, references = analyzer.query_analysis(stock_code, stock_name, dimension) + + if success: + print(f"\n分析结果:\n{response}") + if reasoning: + print(f"\n推理过程:\n{reasoning}") + if references: + print("\n参考资料:") + for ref in references: + print(f"\n{ref}") + else: + print(f"\n查询失败:{response}") + + elif choice == "2": + # 测试单个方法 + print("\n可用的分析方法:") + for i, (dimension, method) in enumerate(analyzer.dimension_methods.items(), 1): + print(f"{i}. {dimension}") + + method_choice = input("\n请选择要测试的方法(1-8): ").strip() + if method_choice.isdigit() and 1 <= int(method_choice) <= len(analyzer.dimension_methods): + method = list(analyzer.dimension_methods.values())[int(method_choice) - 1] + + print("\n可用的股票:") + for i, (code, name) in enumerate(test_stocks, 1): + print(f"{i}. {name}({code})") + + stock_choice = input("\n请选择要测试的股票(1-3): ").strip() + if stock_choice.isdigit() and 1 <= int(stock_choice) <= len(test_stocks): + stock_code, stock_name = test_stocks[int(stock_choice) - 1] + test_single_method(method, stock_code, stock_name) + + elif choice == "3": + # 测试单个股票 + print("\n可用的股票:") + for i, (code, name) in enumerate(test_stocks, 1): + print(f"{i}. {name}({code})") + + stock_choice = input("\n请选择要测试的股票(1-3): ").strip() + if stock_choice.isdigit() and 1 <= int(stock_choice) <= len(test_stocks): + stock_code, stock_name = test_stocks[int(stock_choice) - 1] + test_single_stock(analyzer, stock_code, stock_name) + + elif choice == "4": + # 测试所有股票 + for stock_code, stock_name in test_stocks: + test_single_stock(analyzer, stock_code, stock_name) + + elif choice == "5": + # 生成PDF报告 + print("\n可用的股票:") + for i, (code, name) in enumerate(test_stocks, 1): + print(f"{i}. 
{name}({code})") + + stock_choice = input("\n请选择要生成报告的股票(1-3): ").strip() + if stock_choice.isdigit() and 1 <= int(stock_choice) <= len(test_stocks): + stock_code, stock_name = test_stocks[int(stock_choice) - 1] + filepath = analyzer.generate_pdf_report(stock_code, stock_name) + if filepath: + print(f"\n报告已生成: {filepath}") + else: + print("\n报告生成失败") + + elif choice == "6": + # 生成投资建议并生成PDF + print("\n可用的股票:") + for i, (code, name) in enumerate(test_stocks, 1): + print(f"{i}. {name}({code})") + + stock_choice = input("\n请选择要生成投资建议的股票(1-3): ").strip() + if stock_choice.isdigit() and 1 <= int(stock_choice) <= len(test_stocks): + stock_code, stock_name = test_stocks[int(stock_choice) - 1] + + # 先生成投资建议 + print("\n正在生成投资建议...") + success = analyzer.generate_investment_advice(stock_code, stock_name) + + if success: + print("投资建议生成成功") + # 然后生成PDF报告 + filepath = analyzer.generate_pdf_report(stock_code, stock_name) + if filepath: + print(f"\n报告已生成: {filepath}") + else: + print("\n报告生成失败") + else: + print("\n投资建议生成失败") + + elif choice == "7": + print("程序退出") + else: + print("无效的选项") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/fundamentals_llm/pdf_generator.py b/src/fundamentals_llm/pdf_generator.py index cf3af24..cb51fd9 100644 --- a/src/fundamentals_llm/pdf_generator.py +++ b/src/fundamentals_llm/pdf_generator.py @@ -11,25 +11,25 @@ import markdown2 from bs4 import BeautifulSoup import os from datetime import datetime -from fpdf import FPDF -import matplotlib.pyplot as plt +import shutil + import matplotlib matplotlib.use('Agg') # 修改导入路径,使用相对导入 try: # 尝试相对导入 - from .chat_bot_with_offline import ChatBot - from .fundamental_analysis_database import get_db, get_analysis_result + from .chat_bot import ChatBot + from .chat_bot_with_offline import ChatBot as OfflineChatBot except ImportError: # 如果相对导入失败,尝试绝对导入 try: - from src.fundamentals_llm.chat_bot_with_offline import ChatBot - from src.fundamentals_llm.fundamental_analysis_database 
def _register_chinese_font(self):
    """Find a CJK-capable font file and register it with ReportLab as 'SimHei'.

    Candidate paths are probed in order. The first existing file is copied
    (best effort) to ``fonts/simhei.ttf`` next to this module so later runs
    find it immediately. If the copy cannot be made (for example on a
    read-only file system), the font is registered straight from where it
    was found instead of giving up. A candidate that exists but cannot be
    registered is skipped and the next one is tried.

    Returns:
        str: 'SimHei' when a font was registered, otherwise 'Helvetica'
        (the fallback used by the original code when nothing is found).
    """
    try:
        font_dir = os.path.join(os.path.dirname(__file__), "fonts")
        target_path = os.path.join(font_dir, "simhei.ttf")

        # Possible font file locations, probed in order.
        font_locations = [
            # Bundled fonts next to this module
            os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf"),
            os.path.join(os.path.dirname(__file__), "fonts", "wqy-microhei.ttc"),
            os.path.join(os.path.dirname(__file__), "fonts", "wqy-zenhei.ttc"),
            # Possible location inside the Docker image
            "/app/src/fundamentals_llm/fonts/simhei.ttf",
            # Windows system fonts
            "C:/Windows/Fonts/simhei.ttf",
            "C:/Windows/Fonts/simfang.ttf",
            "C:/Windows/Fonts/simsun.ttc",
            # Linux / macOS system fonts
            "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
            "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
            "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",
            "/usr/share/fonts/wqy-zenhei/wqy-zenhei.ttc",
            "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf",
            "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
            "/System/Library/Fonts/PingFang.ttc",  # macOS
        ]

        for font_path in font_locations:
            if not os.path.exists(font_path):
                continue
            logger.info(f"使用字体文件: {font_path}")

            # Best-effort copy to the canonical location; failure here must
            # not prevent registration of the font that was found.
            if font_path != target_path and not os.path.exists(target_path):
                try:
                    os.makedirs(font_dir, exist_ok=True)
                    shutil.copy2(font_path, target_path)
                    logger.info(f"已将字体文件复制到: {target_path}")
                except Exception as e:
                    logger.warning(f"复制字体文件失败: {str(e)}")

            # Prefer the canonical copy; fall back to the source file when
            # the copy could not be created.
            register_path = target_path if os.path.exists(target_path) else font_path
            try:
                pdfmetrics.registerFont(TTFont('SimHei', register_path))
                return 'SimHei'
            except Exception as e:
                # Unreadable or unsupported font file: try the next candidate.
                logger.error(f"注册中文字体失败: {str(e)}")

        # Nothing usable was found anywhere.
        logger.warning("找不到中文字体文件,将使用默认字体")
        return 'Helvetica'

    except Exception as e:
        logger.error(f"注册中文字体失败: {str(e)}")
        return 'Helvetica'  # fall back to the default font
def download_wqy_font():
    """Download the WenQuanYi Micro Hei font archive and extract the font file.

    The archive is fetched into a fresh temporary directory and the first
    regular ``.ttc``/``.ttf`` member is extracted there (flattened to its
    base name, so archive paths cannot escape the directory).

    Returns:
        str: Path of the extracted font file inside the temporary
        directory, or None on any failure. On failure the temporary
        directory is removed; on success it is left in place because the
        returned path lives inside it.
    """
    temp_dir = None
    try:
        import tarfile

        import requests

        temp_dir = tempfile.mkdtemp()

        # Download link for WenQuanYi Micro Hei
        url = "https://mirrors.tuna.tsinghua.edu.cn/osdn/wqy/wqy-microhei-0.2.0-beta.tar.gz"

        logger.info(f"开始下载文泉驿微米黑字体: {url}")
        tar_file = os.path.join(temp_dir, "wqy-microhei.tar.gz")

        # (connect, read) timeouts: without them a stalled mirror would
        # block this call indefinitely.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(tar_file, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        with tarfile.open(tar_file) as tar:
            for member in tar.getmembers():
                # Only regular files; skip links/devices with a font-like name.
                if member.isfile() and member.name.endswith(('.ttc', '.ttf')):
                    member.name = os.path.basename(member.name)
                    tar.extract(member, temp_dir)
                    extracted_file = os.path.join(temp_dir, member.name)
                    if os.path.exists(extracted_file):
                        return extracted_file

        logger.error("在压缩包中未找到字体文件")

    except Exception as e:
        logger.error(f"下载字体失败: {str(e)}")

    # Failure path: do not leave the partial download behind.
    if temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return None
"C:/Windows/Fonts/simsun.ttc", + "C:/Windows/Fonts/msyh.ttf", + "C:/Windows/Fonts/simfang.ttf", + ] + elif system == "Darwin": # macOS + font_locations = [ + "/System/Library/Fonts/PingFang.ttc", + "/Library/Fonts/Microsoft/SimHei.ttf", + "/Library/Fonts/Microsoft/SimSun.ttf", + ] + else: # Linux + font_locations = [ + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", + "/usr/share/fonts/wenquanyi/wqy-microhei/wqy-microhei.ttc", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf", + ] + + # 查找第一个存在的字体文件 + for font_path in font_locations: + if os.path.exists(font_path): + logger.info(f"在系统中找到中文字体: {font_path}") + return font_path + + logger.warning("在系统中未找到任何中文字体") + return None + +def install_font(target_dir=None): + """ + 安装中文字体 + + Args: + target_dir: 目标目录,如果不提供则使用默认目录 + + Returns: + str: 安装后的字体文件路径,如果失败则返回None + """ + # 如果没有指定目标目录,使用默认位置 + if target_dir is None: + # 获取当前脚本所在目录 + current_dir = os.path.dirname(os.path.abspath(__file__)) + target_dir = os.path.join(current_dir, "fonts") + + # 确保目标目录存在 + os.makedirs(target_dir, exist_ok=True) + + # 目标字体文件路径 + target_font = os.path.join(target_dir, "simhei.ttf") + + # 如果目标字体已存在,直接返回 + if os.path.exists(target_font): + logger.info(f"中文字体已存在: {target_font}") + return target_font + + # 尝试从系统中复制字体 + system_font = find_system_fonts() + if system_font: + try: + shutil.copy2(system_font, target_font) + logger.info(f"已从系统复制字体: {system_font} -> {target_font}") + return target_font + except Exception as e: + logger.error(f"复制系统字体失败: {str(e)}") + + # 如果系统中找不到字体,尝试下载 + logger.info("尝试下载中文字体...") + downloaded_font = download_wqy_font() + if downloaded_font: + try: + shutil.copy2(downloaded_font, target_font) + logger.info(f"已安装下载的字体: {downloaded_font} -> {target_font}") + return target_font + except Exception as e: + logger.error(f"复制下载的字体失败: {str(e)}") + + logger.error("安装中文字体失败") + return None + +def main(): + 
""" + 主函数 + """ + print("中文字体安装工具") + print("----------------") + + target_dir = None + if len(sys.argv) > 1: + target_dir = sys.argv[1] + print(f"使用指定的目标目录: {target_dir}") + + result = install_font(target_dir) + + if result: + print(f"\n✓ 成功安装中文字体: {result}") + print("\n现在可以正常生成PDF报告了。") + else: + print("\n✗ 安装中文字体失败") + print("\n请手动将中文字体(.ttf或.ttc文件)复制到以下目录:") + if target_dir: + print(f" {target_dir}") + else: + print(f" {os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fonts')}") + print("\n并将其重命名为'simhei.ttf'") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/fundamentals_llm/text_processor.py b/src/fundamentals_llm/text_processor.py new file mode 100644 index 0000000..72e7f6c --- /dev/null +++ b/src/fundamentals_llm/text_processor.py @@ -0,0 +1,172 @@ +import re +import logging +from typing import Optional, List, Dict, Any + +logger = logging.getLogger(__name__) + +class TextProcessor: + """大模型文本处理工具类""" + + @staticmethod + def clean_model_output(output: str) -> str: + """清理模型输出文本 + + Args: + output: 模型输出的原始文本 + + Returns: + str: 清理后的文本 + """ + try: + # 移除引用标记 + output = re.sub(r'\[[^\]]*\]', '', output) + + # 移除多余的空白字符 + output = re.sub(r'\s+', ' ', output) + + # 移除首尾的空白字符 + output = output.strip() + + # 移除可能存在的HTML标签 + output = re.sub(r'<[^>]+>', '', output) + + # 移除可能存在的Markdown格式 + output = re.sub(r'[*_`~]', '', output) + + return output + + except Exception as e: + logger.error(f"清理模型输出失败: {str(e)}") + return output + + @staticmethod + def clean_thought_process(output: str) -> str: + """清理模型输出,移除推理过程,只保留最终结果 + + Args: + output: 模型原始输出文本 + + Returns: + str: 清理后的输出文本 + """ + try: + # 找到标签的位置 + think_end = output.find('') + if think_end != -1: + # 移除标签及其之前的所有内容 + output = output[think_end + len(''):] + + # 处理可能存在的空行 + lines = output.split('\n') + cleaned_lines = [] + for line in lines: + line = line.strip() + if line: # 只保留非空行 + cleaned_lines.append(line) + + # 重新组合文本 + output = '\n'.join(cleaned_lines) + + return 
output.strip() + + except Exception as e: + logger.error(f"清理模型输出失败: {str(e)}") + return output.strip() + + @staticmethod + def extract_numeric_value_from_response(response: str) -> str: + """从模型响应中提取数值 + + Args: + response: 模型响应的文本 + + Returns: + str: 提取出的数值 + """ + try: + # 清理响应文本 + cleaned_response = TextProcessor.clean_model_output(response) + + # 尝试提取数值 + numeric_patterns = [ + r'[-+]?\d*\.\d+', # 浮点数 + r'[-+]?\d+', # 整数 + r'[-+]?\d+\.\d*' # 可能带小数点的数 + ] + + for pattern in numeric_patterns: + matches = re.findall(pattern, cleaned_response) + if matches: + # 返回第一个匹配的数值 + return matches[0] + + # 如果没有找到数值,返回原始响应 + return cleaned_response + + except Exception as e: + logger.error(f"提取数值失败: {str(e)}") + return response + + @staticmethod + def extract_list_from_response(response: str) -> List[str]: + """从模型响应中提取列表项 + + Args: + response: 模型响应的文本 + + Returns: + List[str]: 提取出的列表项 + """ + try: + # 清理响应文本 + cleaned_response = TextProcessor.clean_model_output(response) + + # 尝试匹配列表项 + # 匹配数字编号的列表项 + numbered_items = re.findall(r'\d+\.\s*([^\n]+)', cleaned_response) + if numbered_items: + return [item.strip() for item in numbered_items] + + # 匹配带符号的列表项 + bullet_items = re.findall(r'[-*•]\s*([^\n]+)', cleaned_response) + if bullet_items: + return [item.strip() for item in bullet_items] + + # 如果没有找到列表项,返回空列表 + return [] + + except Exception as e: + logger.error(f"提取列表项失败: {str(e)}") + return [] + + @staticmethod + def extract_key_value_pairs(response: str) -> Dict[str, str]: + """从模型响应中提取键值对 + + Args: + response: 模型响应的文本 + + Returns: + Dict[str, str]: 提取出的键值对 + """ + try: + # 清理响应文本 + cleaned_response = TextProcessor.clean_model_output(response) + + # 尝试匹配键值对 + # 匹配冒号分隔的键值对 + pairs = re.findall(r'([^:]+):\s*([^\n]+)', cleaned_response) + if pairs: + return {key.strip(): value.strip() for key, value in pairs} + + # 匹配等号分隔的键值对 + pairs = re.findall(r'([^=]+)=\s*([^\n]+)', cleaned_response) + if pairs: + return {key.strip(): value.strip() for key, value in pairs} + + # 
如果没有找到键值对,返回空字典 + return {} + + except Exception as e: + logger.error(f"提取键值对失败: {str(e)}") + return {} \ No newline at end of file diff --git a/src/scripts/config.py b/src/scripts/config.py index 9955094..54f631d 100644 --- a/src/scripts/config.py +++ b/src/scripts/config.py @@ -66,12 +66,22 @@ MODEL_CONFIGS = { "doubao": "doubao-1-5-pro-32k-250115" } }, + # 谷歌Gemini + "Gemini": { + "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + "api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU", + "models": { + "offline_model": "gemini-2.0-flash" + } + }, # 天链苹果 "tl_private": { - "base_url": "http://192.168.32.118:1234/v1/", + "base_url": "http://192.168.16.174:1234/v1/", "api_key": "none", "models": { - "ds-v1": "mlx-community/DeepSeek-R1-4bit" + "glm-z1": "glm-z1-rumination-32b-0414", + "glm-4": "glm-4-32b-0414-abliterated", + "ds_v1": "mlx-community/DeepSeek-R1-4bit", } }, # 天链-千问 @@ -79,7 +89,8 @@ MODEL_CONFIGS = { "base_url": "http://192.168.16.178:11434/v1", "api_key": "sk-WaVRJKkyhrFlH4ZV35B9Aa61759b400c9cA002D00f3f1019", "models": { - "qwq": "qwq:32b" + "qwq": "qwq:32b", + "GLM": "hf-mirror.com/Cobra4687/GLM-4-32B-0414-abliterated-Q4_K_M-GGUF:Q4_K_M" } }, # Deepseek配置 diff --git a/src/stock_analysis_v2 copy.py b/src/stock_analysis_v2 copy.py new file mode 100644 index 0000000..9df554d --- /dev/null +++ b/src/stock_analysis_v2 copy.py @@ -0,0 +1,986 @@ +import pandas as pd +import numpy as np +from sqlalchemy import create_engine +from datetime import datetime, timedelta +from tqdm import tqdm +import matplotlib.pyplot as plt +import os + +# v2版本只做表格 + +# 添加中文字体支持 +plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 +plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 + +class StockAnalyzer: + def __init__(self, db_connection_string): + """ + Initialize the stock analyzer + :param db_connection_string: Database connection string + """ + self.engine = create_engine(db_connection_string) + self.initial_capital = 300000 # 30万本金 + 
def get_stock_data(self, symbol, start_date, end_date, include_history=False):
    """Load daily bars for one stock from the ``gp_day_data`` table.

    Args:
        symbol: Stock code to load.
        start_date: First day of the analysis window.
        end_date: Last day of the analysis window.
        include_history: When True, up to 240 rows dated before
            ``start_date`` are appended with only ``close`` populated
            (the other price/volume columns are NULL).

    Returns:
        pandas.DataFrame: Rows ordered by timestamp with numeric price and
        volume columns. An empty DataFrame is returned when the query
        yields nothing or any error occurs (the error is printed).
    """
    try:
        # Imported locally so this block does not depend on a file-level
        # import; the file already depends on sqlalchemy.
        from sqlalchemy import text

        # Values are passed as bound parameters rather than interpolated
        # into the SQL text.
        params = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
        }

        if include_history:
            # Also fetch the 240 rows preceding start_date (close only).
            query = text("""
                WITH history_data AS (
                    SELECT
                        symbol,
                        timestamp,
                        CAST(close as DECIMAL(10,2)) as close
                    FROM gp_day_data
                    WHERE symbol = :symbol
                    AND timestamp < :start_date
                    ORDER BY timestamp DESC
                    LIMIT 240
                )
                SELECT
                    symbol,
                    timestamp,
                    CAST(open as DECIMAL(10,2)) as open,
                    CAST(high as DECIMAL(10,2)) as high,
                    CAST(low as DECIMAL(10,2)) as low,
                    CAST(close as DECIMAL(10,2)) as close,
                    CAST(volume as DECIMAL(20,0)) as volume
                FROM gp_day_data
                WHERE symbol = :symbol
                AND timestamp BETWEEN :start_date AND :end_date
                UNION ALL
                SELECT
                    symbol,
                    timestamp,
                    NULL as open,
                    NULL as high,
                    NULL as low,
                    close,
                    NULL as volume
                FROM history_data
                ORDER BY timestamp
            """)
        else:
            query = text("""
                SELECT
                    symbol,
                    timestamp,
                    CAST(open as DECIMAL(10,2)) as open,
                    CAST(high as DECIMAL(10,2)) as high,
                    CAST(low as DECIMAL(10,2)) as low,
                    CAST(close as DECIMAL(10,2)) as close,
                    CAST(volume as DECIMAL(20,0)) as volume
                FROM gp_day_data
                WHERE symbol = :symbol
                AND timestamp BETWEEN :start_date AND :end_date
                ORDER BY timestamp
            """)

        # Read in chunks of 1000 rows.
        chunks = list(pd.read_sql(query, self.engine, params=params, chunksize=1000))
        if not chunks:
            # pd.concat raises on an empty list; an empty result is not an error.
            return pd.DataFrame()
        df = pd.concat(chunks, ignore_index=True)

        # Make sure the numeric columns really are numeric.
        numeric_columns = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_columns:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')

        if include_history:
            # History rows only need a close; analysis rows need every column.
            historical_mask = (df['timestamp'] < start_date) & df['close'].notna()
            analysis_mask = (
                (df['timestamp'] >= start_date)
                & df[['open', 'high', 'low', 'close', 'volume']].notna().all(axis=1)
            )
            df = df[historical_mask | analysis_mask]
        else:
            df = df.dropna()

        return df

    except Exception as e:
        print(f"获取股票 {symbol} 数据时出错: {str(e)}")
        return pd.DataFrame()
row['timestamp'], + 'profit': cumulative_profit, # 记录当前的累计收益 + 'position': False, + 'action': 'hold' + } + + # 更新所有持仓的收益 + positions_to_remove = [] + + for pos in positions: + pos_days_held = (row['timestamp'] - pos['entry_date']).days + pos_size = self.total_position / pos['entry_price'] + + # 检查是否需要平仓 + should_close = False + close_type = None + close_price = current_price + + if pos_days_held >= self.min_holding_days: + # 检查止盈 + if float(row['high']) >= pos['entry_price'] * (1 + take_profit_pct): + should_close = True + close_type = 'sell_profit' + close_price = pos['entry_price'] * (1 + take_profit_pct) + # 检查最大亏损 + elif (close_price - pos['entry_price']) * pos_size <= max_loss_amount: + should_close = True + close_type = 'sell_loss' + # 检查持仓时间 + elif (row['timestamp'] - pos['entry_date']) >= max_holding_period: + should_close = True + close_type = 'time_limit' + + if should_close: + profit = (close_price - pos['entry_price']) * pos_size + commission = profit * self.commission_rate if profit > 0 else 0 + interest_cost = self.borrowed_capital * self.daily_interest_rate * pos_days_held + net_profit = profit - commission - interest_cost + + # 更新统计数据 + total_holding_days += pos_days_held + total_trades += 1 + if net_profit > 0: + profitable_trades += 1 + else: + loss_trades += 1 + + cumulative_profit += net_profit + positions_to_remove.append(pos) + day_result['action'] = close_type + day_result['profit'] = cumulative_profit + + # print(f"\n>>> {'止盈平仓' if close_type == 'sell_profit' else '止损平仓' if close_type == 'sell_loss' else '到期平仓'}") + # print(f"持仓天数: {pos_days_held}") + # print(f"持仓数量: {pos_size:.0f}股") + # print(f"平仓价格: {close_price:.2f}") + # print(f"{'盈利' if profit > 0 else '亏损'}: {profit:,.2f}") + # print(f"手续费: {commission:,.2f}") + # print(f"利息成本: {interest_cost:,.2f}") + # print(f"净{'盈利' if net_profit > 0 else '亏损'}: {net_profit:,.2f}") + + # 移除已平仓的持仓 + for pos in positions_to_remove: + positions.remove(pos) + + # 检查是否可以建仓 + if len(positions) < 
max_positions and prev_close is not None: + # 计算当日跌幅 + daily_drop = (prev_close - float(row['low'])) / prev_close + + # 检查价格是否在均线区间内 + price_in_range = ( + current_price >= ma240_base_price * 0.7 and + current_price <= ma240_base_price * 1.3 + ) + + # 如果跌幅超过5%且价格在均线区间内 + if daily_drop >= 0.05 and price_in_range: + positions.append({ + 'entry_price': current_price, + 'entry_date': row['timestamp'] + }) + trades_count += 1 + # print(f"\n>>> 建仓信号 #{trades_count}") + # print(f"日期: {row['timestamp'].strftime('%Y-%m-%d')}") + # print(f"建仓价格: {current_price:.2f}") + # print(f"建仓数量: {self.total_position/current_price:.0f}股") + # print(f"当日跌幅: {daily_drop*100:.2f}%") + # print(f"距离均线: {((current_price/ma240_base_price)-1)*100:.2f}%") + + # 更新前一天收盘价 + prev_close = float(row['close']) + + # 更新日结果 + day_result['position'] = len(positions) > 0 + results.append(day_result) + + # 计算当前资金占用 + current_capital_usage = sum(self.total_position for pos in positions) + max_capital_usage = max(max_capital_usage, current_capital_usage) + + # 在最后一天强制平仓所有持仓 + if positions: + final_price = (float(analysis_df.iloc[-1]['high']) + float(analysis_df.iloc[-1]['low'])) / 2 + final_total_profit = 0 + + for pos in positions: + position_size = self.total_position / pos['entry_price'] + days_held = (analysis_df.iloc[-1]['timestamp'] - pos['entry_date']).days + final_profit = (final_price - pos['entry_price']) * position_size + commission = final_profit * self.commission_rate if final_profit > 0 else 0 + interest_cost = self.borrowed_capital * self.daily_interest_rate * days_held + net_profit = final_profit - commission - interest_cost + + # 更新统计数据 + total_holding_days += days_held + total_trades += 1 + if net_profit > 0: + profitable_trades += 1 + else: + loss_trades += 1 + + final_total_profit += net_profit + + # print(f"\n>>> 到期强制平仓") + # print(f"持仓天数: {days_held}") + # print(f"持仓数量: {position_size:.0f}股") + # print(f"平仓价格: {final_price:.2f}") + # print(f"毛利润: {final_profit:,.2f}") + # print(f"手续费: 
def analyze_stock(self, symbol, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
    """Run the trade-signal calculation for a single stock.

    Loads the stock's data including pre-window history and, when at
    least 240 rows are available, hands it to ``calculate_trade_signals``.

    Returns:
        Whatever ``calculate_trade_signals`` returns, or None when fewer
        than 240 rows were loaded (a notice is printed in that case).
    """
    price_data = self.get_stock_data(symbol, start_date, end_date, include_history=True)

    # Enough rows overall: delegate to the signal calculation.
    if len(price_data) >= 240:
        signal_args = {
            'take_profit_pct': take_profit_pct,
            'stop_loss_pct': stop_loss_pct,
            'pullback_entry_pct': pullback_entry_pct,
            'start_date': start_date,
        }
        return self.calculate_trade_signals(price_data, **signal_args)

    print(f"股票 {symbol} 总数据不足240天,跳过分析")
    return None
self.get_stock_data(symbol, start_date, end_date, include_history=True) + + if df.empty: + print(f"\n股票 {symbol} 数据为空,跳过分析") + analysis_summary['skipped'].append({ + 'symbol': symbol, + 'reason': 'empty_data', + 'take_profit_pct': take_profit_pct + }) + continue + + # 检查历史数据 + historical_data = df[df['timestamp'] < start_date] + if len(historical_data) < 240: + # print(f"\n股票 {symbol} 历史数据不足240天({len(historical_data)}天),跳过分析") + analysis_summary['skipped'].append({ + 'symbol': symbol, + 'reason': f'insufficient_history_{len(historical_data)}days', + 'take_profit_pct': take_profit_pct + }) + continue + + # 检查分析期数据 + analysis_data = df[df['timestamp'] >= start_date] + if len(analysis_data) < 10: + print(f"\n股票 {symbol} 分析期数据不足10天({len(analysis_data)}天),跳过分析") + analysis_summary['skipped'].append({ + 'symbol': symbol, + 'reason': f'insufficient_analysis_data_{len(analysis_data)}days', + 'take_profit_pct': take_profit_pct + }) + continue + + # 计算交易信号 + result = self.calculate_trade_signals( + df, + take_profit_pct, + stop_loss_pct, + pullback_entry_pct, + start_date + ) + + if result is not None and not result.empty: + # 检查是否有交易记录 + has_trades = any(action in ['sell_profit', 'sell_loss', 'time_limit', 'final_sell'] + for action in result['action'] if pd.notna(action)) + + if has_trades: + all_results[symbol] = result + analysis_summary['processed'].append({ + 'symbol': symbol, + 'take_profit_pct': take_profit_pct, + 'trades': len([x for x in result['action'] if pd.notna(x) and x != 'hold']), + 'profit': result['profit'].iloc[-1] + }) + else: + # print(f"\n股票 {symbol} 没有产生有效的交易信号") + analysis_summary['skipped'].append({ + 'symbol': symbol, + 'reason': 'no_valid_signals', + 'take_profit_pct': take_profit_pct + }) + + except Exception as e: + print(f"\n处理股票 {symbol} 时出错: {str(e)}") + analysis_summary['error'].append({ + 'symbol': symbol, + 'error': str(e), + 'take_profit_pct': take_profit_pct + }) + continue + + # 打印分析总结 + print(f"\n分析完成:") + print(f"成功分析: 
def plot_results(self, results, symbol):
    """Plot the profit curve of a single stock.

    Args:
        results: DataFrame with ``timestamp`` and ``profit`` columns, as
            produced by ``calculate_trade_signals``.
        symbol: Stock code, used in the chart title.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current.
    """
    plt.figure(figsize=(12, 6))
    # ``profit`` is already the running cumulative profit (that is how
    # calculate_trade_signals records it, and plot_all_stocks plots it
    # as-is), so it must not be cumulatively summed again.
    plt.plot(results['timestamp'], results['profit'], label='Cumulative Profit')
    plt.title(f'Stock {symbol} Trading Results')
    plt.xlabel('Date')
    plt.ylabel('Profit (CNY)')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    return plt
创建一个包含所有日期的DataFrame + profits_df = pd.DataFrame(index=all_dates) + + # 对每支股票,填充所有日期的收益 + for symbol, results in all_results.items(): + daily_profits = pd.Series(index=all_dates, data=np.nan) + # 直接使用已经计算好的累计收益 + for date, profit in zip(results['timestamp'], results['profit']): + daily_profits[date] = profit + # 对于没有交易的日期,使用最后一个累计收益填充 + daily_profits.fillna(method='ffill', inplace=True) + daily_profits.fillna(0, inplace=True) # 对开始之前的日期填充0 + profits_df[symbol] = daily_profits + + # 计算并绘制中位数收益曲线 + median_line = profits_df.median(axis=1) + plt.plot(all_dates, median_line, 'r-', label='Median', linewidth=2) + + plt.title('All Stocks Trading Results') + plt.xlabel('Date') + plt.ylabel('Profit (CNY)') + plt.grid(True) + plt.xticks(rotation=45) + + # 添加利润区间标注 + plt.text(0.02, 0.98, + f'Profit Range:\nMax: {max_profit:,.0f}\nMin: {min_profit:,.0f}\nMedian: {median_line.iloc[-1]:,.0f}', + transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8)) + + # 计算所有股票的统计数据 + total_stats = { + 'total_trades': 0, + 'profitable_trades': 0, + 'loss_trades': 0, + 'total_holding_days': 0, + 'max_capital_usage': 0, + 'total_profit': 0, + 'profitable_stocks': 0, + 'loss_stocks': 0, + 'total_capital_usage': 0, # 添加总资金占用 + 'total_win_rate': 0, # 添加总胜率 + } + + # 获取前5名和后5名股票 + top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5] + bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5] + + # 添加排名信息 + rank_text = "Top 5 Stocks:\n" + for symbol, profit in top_5: + rank_text += f"{symbol}: {profit:,.0f}\n" + rank_text += "\nBottom 5 Stocks:\n" + for symbol, profit in bottom_5: + rank_text += f"{symbol}: {profit:,.0f}\n" + + plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top') + + # 记录每支股票的统计数据用于计算平均值 + stock_stats = [] + + for symbol, results in all_results.items(): + if 'stats' in results.iloc[-1]: + stats = results.iloc[-1]['stats'] + stock_stats.append(stats) # 
记录每支股票的统计数据 + + total_stats['total_trades'] += stats['total_trades'] + total_stats['profitable_trades'] += stats['profitable_trades'] + total_stats['loss_trades'] += stats['loss_trades'] + total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades'] + total_stats['max_capital_usage'] = max(total_stats['max_capital_usage'], stats['max_capital_usage']) + total_stats['total_capital_usage'] += stats['max_capital_usage'] + total_stats['total_profit'] += stats['final_profit'] + total_stats['total_win_rate'] += stats['win_rate'] + if stats['final_profit'] > 0: + total_stats['profitable_stocks'] += 1 + else: + total_stats['loss_stocks'] += 1 + + # 计算总体统计 + total_stocks = total_stats['profitable_stocks'] + total_stats['loss_stocks'] + avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 + win_rate = total_stats['profitable_trades'] / total_stats['total_trades'] * 100 if total_stats['total_trades'] > 0 else 0 + stock_win_rate = total_stats['profitable_stocks'] / total_stocks * 100 if total_stocks > 0 else 0 + + # 计算平均值统计 + avg_trades_per_stock = total_stats['total_trades'] / total_stocks if total_stocks > 0 else 0 + avg_profit_per_stock = total_stats['total_profit'] / total_stocks if total_stocks > 0 else 0 + avg_capital_usage = total_stats['total_capital_usage'] / total_stocks if total_stocks > 0 else 0 + avg_win_rate = total_stats['total_win_rate'] / total_stocks if total_stocks > 0 else 0 + + # 添加总体统计信息 + stats_text = ( + f"总体统计:\n" + f"总交易次数: {total_stats['total_trades']}\n" + # f"最大资金占用: {total_stats['max_capital_usage']:,.0f}\n" + f"平均持仓天数: {avg_holding_days:.1f}\n" + f"交易胜率: {win_rate:.1f}%\n" + f"股票胜率: {stock_win_rate:.1f}%\n" + f"盈利股票数: {total_stats['profitable_stocks']}\n" + f"亏损股票数: {total_stats['loss_stocks']}\n" + f"\n平均值统计:\n" + f"平均交易次数: {avg_trades_per_stock:.1f}\n" + f"平均收益: {avg_profit_per_stock:,.0f}\n" + f"平均资金占用: {avg_capital_usage:,.0f}\n" + f"平均持仓天数: 
{avg_holding_days:.1f}\n" + f"平均胜率: {avg_win_rate:.1f}%" + ) + + plt.text(0.02, 0.6, stats_text, + transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), + verticalalignment='top', + fontsize=10) + + # 获取前5名和后5名股票 + top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5] + bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5] + + # 添加排名信息 + rank_text = "Top 5 Stocks:\n" + for symbol, profit in top_5: + rank_text += f"{symbol}: {profit:,.0f}\n" + rank_text += "\nBottom 5 Stocks:\n" + for symbol, profit in bottom_5: + rank_text += f"{symbol}: {profit:,.0f}\n" + + plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top') + + plt.tight_layout() + return plt, total_profits + + def plot_profit_distribution(self, total_profits): + """绘制所有股票的收益分布小提琴图""" + plt.figure(figsize=(12, 8)) + + # 将收益数据转换为数组,确保使用最终累计收益 + profits = np.array([profit for profit in total_profits.values()]) + + # 计算统计数据 + mean_profit = np.mean(profits) + median_profit = np.median(profits) + max_profit = np.max(profits) + min_profit = np.min(profits) + std_profit = np.std(profits) + + # 绘制小提琴图 + violin_parts = plt.violinplot(profits, positions=[0], showmeans=True, showmedians=True) + + # 设置小提琴图的颜色 + violin_parts['bodies'][0].set_facecolor('lightblue') + violin_parts['bodies'][0].set_alpha(0.7) + violin_parts['cmeans'].set_color('red') + violin_parts['cmedians'].set_color('blue') + + # 添加统计信息 + stats_text = ( + f"统计信息:\n" + f"平均收益: {mean_profit:,.0f}\n" + f"中位数收益: {median_profit:,.0f}\n" + f"最大收益: {max_profit:,.0f}\n" + f"最小收益: {min_profit:,.0f}\n" + f"标准差: {std_profit:,.0f}" + ) + + plt.text(0.65, 0.95, stats_text, + transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), + verticalalignment='top', + fontsize=10) + + # 设置图表样式 + plt.title('股票收益分布图', fontsize=14, pad=20) + plt.ylabel('收益 (元)', fontsize=12) + plt.xticks([0], 
['所有股票'], fontsize=10) + plt.grid(True, axis='y', alpha=0.3) + + # 添加零线 + plt.axhline(y=0, color='r', linestyle='--', alpha=0.3) + + # 计算盈利和亏损的股票数量 + profit_count = np.sum(profits > 0) + loss_count = np.sum(profits < 0) + total_count = len(profits) + + # 添加盈亏比例信息 + ratio_text = ( + f"盈亏比例:\n" + f"盈利: {profit_count}支 ({profit_count/total_count*100:.1f}%)\n" + f"亏损: {loss_count}支 ({loss_count/total_count*100:.1f}%)\n" + f"总计: {total_count}支" + ) + + plt.text(0.65, 0.6, ratio_text, + transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), + verticalalignment='top', + fontsize=10) + + plt.tight_layout() + return plt + + def plot_profit_matrix(self, df): + """绘制止盈比例分析矩阵的结果图""" + plt.figure(figsize=(15, 10)) + + # 创建子图 + fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) + + # 1. 总收益曲线 + ax1.plot(df.index, df['total_profit'], 'b-', linewidth=2) + ax1.set_title('总收益 vs 止盈比例') + ax1.set_xlabel('止盈比例 (%)') + ax1.set_ylabel('总收益 (元)') + ax1.grid(True) + + # 标记最优止盈比例点 + best_profit_pct = df['total_profit'].idxmax() + best_profit = df['total_profit'].max() + ax1.plot(best_profit_pct, best_profit, 'ro') + ax1.annotate(f'最优: {best_profit_pct:.1f}%\n{best_profit:,.0f}元', + (best_profit_pct, best_profit), + xytext=(10, 10), textcoords='offset points') + + # 2. 平均每笔收益和中位数收益 + ax2.plot(df.index, df['avg_profit_per_trade'], 'g-', linewidth=2, label='平均收益') + ax2.plot(df.index, df['median_profit_per_trade'], 'r--', linewidth=2, label='中位数收益') + ax2.set_title('每笔交易收益 vs 止盈比例') + ax2.set_xlabel('止盈比例 (%)') + ax2.set_ylabel('收益 (元)') + ax2.legend() + ax2.grid(True) + + # 3. 平均持仓天数曲线 + ax3.plot(df.index, df['avg_holding_days'], 'r-', linewidth=2) + ax3.set_title('平均持仓天数 vs 止盈比例') + ax3.set_xlabel('止盈比例 (%)') + ax3.set_ylabel('持仓天数') + ax3.grid(True) + + # 4. 
胜率对比 + ax4.plot(df.index, df['trade_win_rate'], 'c-', linewidth=2, label='交易胜率') + ax4.plot(df.index, df['stock_win_rate'], 'm--', linewidth=2, label='股票胜率') + ax4.set_title('胜率 vs 止盈比例') + ax4.set_xlabel('止盈比例 (%)') + ax4.set_ylabel('胜率 (%)') + ax4.legend() + ax4.grid(True) + + # 调整布局 + plt.tight_layout() + return plt + + def analyze_profit_matrix(self, start_date, end_date, stop_loss_pct, pullback_entry_pct): + """分析不同止盈比例的表现矩阵""" + # 创建results目录和临时文件目录 + if not os.path.exists('results/temp'): + os.makedirs('results/temp') + + # 从30%到5%,每次减少1% + profit_percentages = list(np.arange(0.30, 0.04, -0.01)) + + for take_profit_pct in profit_percentages: + # 检查是否已经分析过这个止盈点 + temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv' + if os.path.exists(temp_file): + print(f"\n止盈比例 {take_profit_pct*100:.1f}% 已分析,跳过") + continue + + print(f"\n分析止盈比例: {take_profit_pct*100:.1f}%") + + try: + # 运行分析 + results = self.analyze_all_stocks( + start_date, + end_date, + take_profit_pct, + stop_loss_pct, + pullback_entry_pct + ) + + # 过滤掉空结果 + valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty} + + if valid_results: + # 计算统计数据 + total_stats = { + 'total_trades': 0, + 'profitable_trades': 0, + 'loss_trades': 0, + 'total_holding_days': 0, + 'total_profit': 0, + 'total_capital_usage': 0, + 'total_win_rate': 0, + 'stock_count': len(valid_results), + 'profitable_stocks': 0, + 'all_trade_profits': [] + } + + # 收集每支股票的统计数据 + for symbol, result in valid_results.items(): + if 'stats' in result.iloc[-1]: + stats = result.iloc[-1]['stats'] + total_stats['total_trades'] += stats['total_trades'] + total_stats['profitable_trades'] += stats['profitable_trades'] + total_stats['loss_trades'] += stats['loss_trades'] + total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades'] + total_stats['total_profit'] += stats['final_profit'] + total_stats['total_capital_usage'] += stats['max_capital_usage'] + 
total_stats['total_win_rate'] += stats['win_rate'] + + if stats['final_profit'] > 0: + total_stats['profitable_stocks'] += 1 + + if hasattr(stats, 'trade_profits'): + total_stats['all_trade_profits'].extend(stats['trade_profits']) + + # 计算统计数据 + avg_trades = total_stats['total_trades'] / total_stats['stock_count'] + avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 + avg_profit_per_trade = total_stats['total_profit'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 + median_profit_per_trade = np.median(total_stats['all_trade_profits']) if total_stats['all_trade_profits'] else 0 + total_profit = total_stats['total_profit'] + trade_win_rate = (total_stats['profitable_trades'] / total_stats['total_trades'] * 100) if total_stats['total_trades'] > 0 else 0 + stock_win_rate = (total_stats['profitable_stocks'] / total_stats['stock_count'] * 100) if total_stats['stock_count'] > 0 else 0 + + # 创建单个止盈点的结果DataFrame + result_df = pd.DataFrame([{ + 'take_profit_pct': take_profit_pct * 100, + 'total_profit': total_profit, + 'avg_profit_per_trade': avg_profit_per_trade, + 'median_profit_per_trade': median_profit_per_trade, + 'avg_holding_days': avg_holding_days, + 'trade_win_rate': trade_win_rate, + 'stock_win_rate': stock_win_rate, + 'total_trades': total_stats['total_trades'], + 'stock_count': total_stats['stock_count'] + }]) + + # 保存单个止盈点的结果 + result_df.to_csv(temp_file, index=False) + print(f"保存止盈点 {take_profit_pct*100:.1f}% 的分析结果") + + # 清理内存 + del valid_results + del results + if 'total_stats' in locals(): + del total_stats + + except Exception as e: + print(f"处理止盈点 {take_profit_pct*100:.1f}% 时出错: {str(e)}") + continue + + # 合并所有结果 + print("\n合并所有分析结果...") + all_results = [] + for take_profit_pct in profit_percentages: + temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv' + if os.path.exists(temp_file): + try: + df = pd.read_csv(temp_file) + all_results.append(df) 
+ except Exception as e: + print(f"读取文件 {temp_file} 时出错: {str(e)}") + + if all_results: + # 合并所有结果 + df = pd.concat(all_results, ignore_index=True) + df = df.round({ + 'take_profit_pct': 1, + 'total_profit': 0, + 'avg_profit_per_trade': 0, + 'median_profit_per_trade': 0, + 'avg_holding_days': 1, + 'trade_win_rate': 1, + 'stock_win_rate': 1, + 'total_trades': 0, + 'stock_count': 0 + }) + + # 设置止盈比例为索引并排序 + df.set_index('take_profit_pct', inplace=True) + df.sort_index(ascending=False, inplace=True) + + # 保存最终结果 + df.to_csv('results/profit_matrix_analysis.csv') + + # 打印结果 + print("\n止盈比例分析矩阵:") + print("=" * 120) + print(df.to_string()) + print("=" * 120) + + try: + # 绘制并保存矩阵分析图表 + plt = self.plot_profit_matrix(df) + plt.savefig('results/profit_matrix_analysis.png', bbox_inches='tight', dpi=300) + plt.close() + except Exception as e: + print(f"绘制图表时出错: {str(e)}") + + return df + else: + print("没有找到任何分析结果") + return pd.DataFrame() + +def main(): + # 数据库连接配置 + db_config = { + 'host': '192.168.18.199', + 'port': 3306, + 'user': 'root', + 'password': 'Chlry#$.8', + 'database': 'db_gp_cj' + } + + try: + connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}" + + # 创建results目录 + + if not os.path.exists('results'): + os.makedirs('results') + + # 创建分析器实例 + analyzer = StockAnalyzer(connection_string) + + # 设置分析参数 + start_date = '2022-05-05' + end_date = '2023-08-28' + stop_loss_pct = -0.99 # -99%止损 + pullback_entry_pct = -0.05 # -5%回调建仓 + + # 运行止盈比例分析 + profit_matrix = analyzer.analyze_profit_matrix( + start_date, + end_date, + stop_loss_pct, + pullback_entry_pct + ) + + if not profit_matrix.empty: + # 使用最优止盈比例进行完整分析 + take_profit_pct = profit_matrix['total_profit'].idxmax() / 100 + print(f"\n使用最优止盈比例 {take_profit_pct*100:.1f}% 运行详细分析") + + # 运行分析并只生成所需的图表 + results = analyzer.analyze_all_stocks( + start_date, + end_date, + take_profit_pct, + stop_loss_pct, + pullback_entry_pct + ) + + # 
过滤掉空结果 + valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty} + + if valid_results: + # 只生成所需的两个图表 + plt, _ = analyzer.plot_all_stocks(valid_results) + plt.savefig('results/all_stocks_analysis.png', bbox_inches='tight', dpi=300) + plt.close() + + except Exception as e: + print(f"程序运行出错: {str(e)}") + raise + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/stock_analysis_v2.py b/src/stock_analysis_v2.py index 13f10f9..760aab7 100644 --- a/src/stock_analysis_v2.py +++ b/src/stock_analysis_v2.py @@ -1,24 +1,37 @@ import pandas as pd import numpy as np -from sqlalchemy import create_engine +import pymysql from datetime import datetime, timedelta from tqdm import tqdm import matplotlib.pyplot as plt import os -# v2版本只做表格 - # 添加中文字体支持 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 -class StockAnalyzer: - def __init__(self, db_connection_string): +class StockBacktester: + def __init__(self, db_config): """ - Initialize the stock analyzer - :param db_connection_string: Database connection string + 初始化股票回测器 + :param db_config: 数据库配置字典 """ - self.engine = create_engine(db_connection_string) + self.db_config = db_config + # 测试连接 + try: + conn = pymysql.connect( + host=self.db_config['host'], + port=self.db_config['port'], + user=self.db_config['user'], + password=self.db_config['password'], + database=self.db_config['database'] + ) + conn.close() + print("数据库连接成功") + except Exception as e: + print(f"数据库连接错误: {e}") + raise + self.initial_capital = 300000 # 30万本金 self.total_position = 1000000 # 100万建仓规模 self.borrowed_capital = 700000 # 70万借入资金 @@ -26,573 +39,43 @@ class StockAnalyzer: self.daily_interest_rate = self.annual_interest_rate / 365 self.commission_rate = 0.05 # 5%手续费 self.min_holding_days = 7 # 最少持仓天数 + self.take_profit_pct = 0.20 # 固定20%止盈 - def get_stock_list(self): - """获取所有股票列表""" - # query = "SELECT 
DISTINCT gp_code as symbol FROM gp_code_all where mark1 = '1'" - # return pd.read_sql(query, self.engine)['symbol'].tolist() - return ['SH600522', - 'SZ002340', - 'SZ000733', - 'SH601615', - 'SH600157', - 'SH688005', - 'SZ002180', - 'SZ000723', - 'SZ002384', - 'SH600885', - 'SZ000009', - 'SZ300285', - 'SH605358', - 'SH601555', - 'SH601117', - 'SZ002385', - 'SH688099', - 'SH600884', - 'SH600298', - 'SH600399', - 'SZ002028', - 'SH603613', - 'SZ300037', - 'SH600765', - 'SH600256', - 'SH600487', - 'SH600563', - 'SH600754', - 'SZ002185', - 'SZ300363', - 'SH601636', - 'SZ300751', - 'SZ300088', - 'SH603589', - 'SZ000630', - 'SZ000988', - 'SZ002268', - 'SZ002353', - 'SH601689', - 'SH601077', - 'SZ002013', - 'SH601168', - 'SZ300763', - 'SH600141', - 'SH600392', - 'SZ000932', - 'SH603267', - 'SZ300699', - 'SH600233', - 'SZ000728', - 'SH600516', - 'SZ300058', - 'SH600486', - 'SZ002797', - 'SH603456', - 'SH603077', - 'SZ002138', - 'SH600862', - 'SZ300146', - 'SZ300223', - 'SZ300136', - 'SH600188', - 'SH600418', - 'SZ002673', - 'SH600705', - 'SH601128', - 'SZ000960', - 'SZ300308', - 'SH688065', - 'SZ300253', - 'SZ002080', - 'SZ000998', - 'SH688536', - 'SZ002092', - 'SH600875', - 'SZ000547', - 'SZ002409', - 'SH600482', - 'SH600521', - 'SZ000807', - 'SH600733', - 'SH600988', - 'SZ000039', - 'SZ300296', - 'SH600177', - 'SZ002439', - 'SZ002212', - 'SZ300418', - 'SZ002078', - 'SZ002223', - 'SH600739', - 'SH603893', - 'SZ002191', - 'SZ002030', - 'SH600699', - 'SZ002603', - 'SH600160', - 'SZ002444', - 'SH601198', - 'SZ300724', - 'SH600637', - 'SZ300474', - 'SZ002465', - 'SH688208', - 'SH603605', - 'SH688002', - 'SH600497', - 'SZ300383', - 'SZ002507', - 'SH601997', - 'SZ002557', - 'SH600549', - 'SZ000739', - 'SH600529', - 'SH600704', - 'SH603290', - 'SZ002508', - 'SH600038', - 'SZ002739', - 'SZ000636', - 'SH603198', - 'SZ002156', - 'SZ000830', - 'SH600201', - 'SH600153', - 'SH600369', - 'SZ000983', - 'SH600642', - 'SZ002299', - 'SZ000629', - 'SZ002936', - 'SZ002131', - 'SZ000738', - 
'SH603127', - 'SH601016', - 'SZ000825', - 'SZ000623', - 'SH600879', - 'SH600859', - 'SZ000519', - 'SH600039', - 'SZ300618', - 'SH688521', - 'SZ002273', - 'SH603444', - 'SH600166', - 'SZ002422', - 'SH600732', - 'SZ000750', - 'SZ002373', - 'SH600803', - 'SH600170', - 'SZ000050', - 'SZ000887', - 'SZ300357', - 'SH601699', - 'SH600867', - 'SZ000656', - 'SH600118', - 'SZ300070', - 'SZ002372', - 'SZ000400', - 'SH600271', - 'SH600258', - 'SZ002926', - 'SZ300017', - 'SZ002153', - 'SZ002152', - 'SH600909', - 'SZ002195', - 'SZ300168', - 'SH600008', - 'SZ002572', - 'SH600315', - 'SZ002065', - 'SH600066', - 'SH600580', - 'SZ000975', - 'SH600027', - 'SZ000999', - 'SZ000997', - 'SH603218', - 'SH600536', - 'SZ002203', - 'SH601005', - 'SZ002985', - 'SH600316', - 'SH600372', - 'SH603707', - 'SZ000027', - 'SZ000513', - 'SZ300182', - 'SZ002506', - 'SH603156', - 'SZ002002', - 'SZ002408', - 'SZ002056', - 'SH600236', - 'SZ002266', - 'SZ300001', - 'SZ002500', - 'SZ300244', - 'SZ002250', - 'SH600801', - 'SZ002019', - 'SH600208', - 'SZ300024', - 'SH601872', - 'SH600998', - 'SZ300630', - 'SZ002221', - 'SZ002249', - 'SZ000021', - 'SH600348', - 'SZ300251', - 'SH600415', - 'SH600535', - 'SZ000401', - 'SZ000878', - 'SZ002831', - 'SZ000012', - 'SZ300009', - 'SH600728', - 'SZ000686', - 'SZ002430', - 'SH600171', - 'SZ000537', - 'SZ300115', - 'SZ002683', - 'SH601866', - 'SZ002745', - 'SZ002511', - 'SZ000060', - 'SH600718', - 'SH603568', - 'SZ002127', - 'SZ002532', - 'SH600390', - 'SH601992', - 'SZ300482', - 'SZ000970', - 'SH603858', - 'SH600839', - 'SH600970', - 'SZ002966', - 'SH600155', - 'SZ002128', - 'SZ000690', - 'SH600060', - 'SH600282', - 'SH600517', - 'SH600985', - 'SZ300463', - 'SH600498', - 'SZ300026', - 'SH600380', - 'SH600673', - 'SZ000778', - 'SH603712', - 'SH600528', - 'SZ300212', - 'SZ300072', - 'SZ000930', - 'SH600325', - 'SH600667', - 'SH600782', - 'SZ002382', - 'SZ000729', - 'SZ002939', - 'SH600895', - 'SH600511', - 'SH603885', - 'SH688006', - 'SZ002294', - 'SZ000540', - 'SH601456', 
- 'SZ000581', - 'SH600500', - 'SH600863', - 'SH601717', - 'SZ002505', - 'SZ002958', - 'SH601179', - 'SZ300315', - 'SH600216', - 'SH600755', - 'SZ002396', - 'SZ000598', - 'SH600297', - 'SZ002701', - 'SH600820', - 'SH600567', - 'SZ000709', - 'SH603927', - 'SH603638', - 'SH603225', - 'SZ002670', - 'SH600507', - 'SH603866', - 'SZ002595', - 'SZ002004', - 'SZ002075', - 'SZ002368', - 'SZ300376', - 'SZ002146', - 'SZ002048', - 'SH600409', - 'SH600021', - 'SH600446', - 'SH601778', - 'SZ300271', - 'SH601577', - 'SH600827', - 'SH688188', - 'SH601880', - 'SH600131', - 'SH600598', - 'SH600808', - 'SZ002010', - 'SZ002925', - 'SH600663', - 'SZ300166', - 'SH600022', - 'SZ002124', - 'SH600546', - 'SH600764', - 'SZ002690', - 'SZ000987', - 'SZ002174', - 'SH600597', - 'SH600737', - 'SH600908', - 'SH603883', - 'SH603228', - 'SZ000883', - 'SZ002155', - 'SH600811', - 'SH601106', - 'SZ002233', - 'SZ002110', - 'SH600566', - 'SZ002281', - 'SH600158', - 'SH600967', - 'SZ000898', - 'SH601118', - 'SH600707', - 'SH600572', - 'SZ002440', - 'SZ000488', - 'SZ000961', - 'SH688029', - 'SZ000990', - 'SZ002183', - 'SH600435', - 'SH601000', - 'SZ000090', - 'SZ000877', - 'SH600906', - 'SH600064', - 'SH601333', - 'SZ002429', - 'SH688088', - 'SH600398', - 'SH601200', - 'SZ002081', - 'SZ000089', - 'SH600373', - 'SH600037', - 'SZ002867', - 'SH600376', - 'SZ001914', - 'SZ002085', - 'SH600167', - 'SZ000559', - 'SZ000959', - 'SZ000528', - 'SZ002416', - 'SZ002244', - 'SH600195', - 'SH600329', - 'SH600259', - 'SZ002434', - 'SH688321', - 'SZ000685', - 'SH601718', - 'SZ000671', - 'SH600026', - 'SH603000', - 'SZ002038', - 'SH600120', - 'SZ002242', - 'SZ001965', - 'SH601608', - 'SZ300133', - 'SH600126', - 'SZ002705', - 'SZ000402', - 'SZ000158', - 'SZ002390', - 'SZ002387', - 'SH601611', - 'SZ000758', - 'SH601098', - 'SZ000415', - 'SH600307', - 'SH601991', - 'SH601975', - 'SH601598', - 'SH603650', - 'SH600959', - 'SH688289', - 'SH600582', - 'SH600643', - 'SZ000563', - 'SZ000718', - 'SH600062', - 'SZ300257', - 
'SH601958', - 'SH600901', - 'SH600266', - 'SH600717', - 'SZ000967', - 'SZ002399', - 'SH600056', - 'SZ002458', - 'SH601860', - 'SZ002563', - 'SZ003022', - 'SH601928', - 'SH600928', - 'SH600729', - 'SZ002423', - 'SH600968', - 'SZ000937', - 'SZ002468', - 'SZ000717', - 'SH600556', - 'SZ003035', - 'SH600649', - 'SH600787', - 'SZ000028', - 'SH600095', - 'SH600776', - 'SZ000031', - 'SH600339', - 'SH600006', - 'SH603317', - 'SH600648', - 'SH600623', - 'SZ000062', - 'SH600835', - 'SZ002815', - 'SZ002424', - 'SH600377', - 'SH601139', - 'SH600299', - 'SZ002948', - 'SZ000156', - 'SZ002375', - 'SZ000869', - 'SH600871', - 'SH601828', - 'SZ300869', - 'SZ002302', - 'SH601228', - 'SH601969', - 'SH600639', - 'SH601298', - 'SH603379', - 'SZ000046', - 'SZ002945', - 'SZ002653', - 'SZ001872', - 'SH601156', - 'SH600823', - 'SZ300741', - 'SH603056', - 'SH601869', - 'SH603786', - 'SH603708', - 'SZ000553', - 'SH600350', - 'SZ001203', - 'SH600657', - 'SH601568', - 'SH600466', - 'SH603515', - 'SH603355', - 'SH603719', - 'SH601665', - 'SH601003', - 'SZ002946', - 'SH600917', - 'SH603868', - 'SZ002901', - 'SH600903', - 'SH600956', - 'SH601187', - 'SH603983'] # 你可以在这里添加更多股票代码 - - def get_stock_data(self, symbol, start_date, end_date, include_history=False): - """获取指定股票在时间范围内的数据""" + def get_stock_data(self, symbol, start_date, end_date): + """ + 获取指定股票从起始日期到结束日期的数据 + :param symbol: 股票代码 + :param start_date: 起始日期 + :param end_date: 结束日期 + :return: 股票数据DataFrame + """ try: - if include_history: - # 如果需要历史数据,获取start_date之前240天的数据 - query = f""" - WITH history_data AS ( - SELECT - symbol, - timestamp, - CAST(close as DECIMAL(10,2)) as close - FROM gp_day_data - WHERE symbol = '{symbol}' - AND timestamp < '{start_date}' - ORDER BY timestamp DESC - LIMIT 240 - ) - SELECT - symbol, - timestamp, - CAST(open as DECIMAL(10,2)) as open, - CAST(high as DECIMAL(10,2)) as high, - CAST(low as DECIMAL(10,2)) as low, - CAST(close as DECIMAL(10,2)) as close, - CAST(volume as DECIMAL(20,0)) as volume - FROM 
gp_day_data - WHERE symbol = '{symbol}' - AND timestamp BETWEEN '{start_date}' AND '{end_date}' - UNION ALL - SELECT - symbol, - timestamp, - NULL as open, - NULL as high, - NULL as low, - close, - NULL as volume - FROM history_data - ORDER BY timestamp - """ - else: - query = f""" - SELECT - symbol, - timestamp, - CAST(open as DECIMAL(10,2)) as open, - CAST(high as DECIMAL(10,2)) as high, - CAST(low as DECIMAL(10,2)) as low, - CAST(close as DECIMAL(10,2)) as close, - CAST(volume as DECIMAL(20,0)) as volume - FROM gp_day_data - WHERE symbol = '{symbol}' - AND timestamp BETWEEN '{start_date}' AND '{end_date}' - ORDER BY timestamp - """ + conn = pymysql.connect( + host=self.db_config['host'], + port=self.db_config['port'], + user=self.db_config['user'], + password=self.db_config['password'], + database=self.db_config['database'] + ) - # 使用chunksize分批读取数据 - chunks = [] - for chunk in pd.read_sql(query, self.engine, chunksize=1000): - chunks.append(chunk) - df = pd.concat(chunks, ignore_index=True) + query = f""" + SELECT + symbol, + timestamp, + CAST(open as DECIMAL(10,2)) as open, + CAST(high as DECIMAL(10,2)) as high, + CAST(low as DECIMAL(10,2)) as low, + CAST(close as DECIMAL(10,2)) as close, + CAST(volume as DECIMAL(20,0)) as volume + FROM gp_day_data + WHERE symbol = %s + AND timestamp BETWEEN %s AND %s + ORDER BY timestamp + """ + + # 使用pandas读取sql + df = pd.read_sql(query, conn, params=(symbol, start_date, end_date)) + conn.close() # 确保数值列的类型正确 numeric_columns = ['open', 'high', 'low', 'close', 'volume'] @@ -600,97 +83,148 @@ class StockAnalyzer: if col in df.columns: df[col] = pd.to_numeric(df[col], errors='coerce') - # 删除任何包含 NaN 的行,但只考虑分析期间需要的列 - if include_history: - historical_mask = (df['timestamp'] < start_date) & df['close'].notna() - analysis_mask = (df['timestamp'] >= start_date) & df[['open', 'high', 'low', 'close', 'volume']].notna().all(axis=1) - df = df[historical_mask | analysis_mask] - else: - df = df.dropna() + # 删除任何包含 NaN 的行 + df = 
df.dropna() + if df.empty: + print(f"股票 {symbol} 在指定时间范围内没有数据") + else: + print(f"成功获取股票 {symbol} 数据,共 {len(df)} 条记录") + return df except Exception as e: print(f"获取股票 {symbol} 数据时出错: {str(e)}") return pd.DataFrame() - def calculate_trade_signals(self, df, take_profit_pct, stop_loss_pct, pullback_entry_pct, start_date): - """计算交易信号和收益""" - """计算交易信号和收益""" - # 检查历史数据是否足够 - historical_data = df[df['timestamp'] < start_date] - if len(historical_data) < 240: - # print(f"历史数据不足240天,跳过分析") - return pd.DataFrame() - - # 计算240日均线基准价格(使用历史数据) - ma240_base_price = historical_data['close'].tail(240).mean() - # print(f"240日均线基准价格: {ma240_base_price:.2f}") - - # 只使用分析期间的数据进行交易信号计算 - analysis_df = df[df['timestamp'] >= start_date].copy() - if len(analysis_df) < 10: # 确保分析期间有足够的数据 - # print(f"分析期间数据不足10天,跳过分析") - return pd.DataFrame() - - results = [] - positions = [] # 记录所有持仓,每个元素是一个字典,包含入场价格和日期 - max_positions = 5 # 最大建仓次数 - cumulative_profit = 0 # 总累计收益 - trades_count = 0 - max_holding_period = timedelta(days=180) # 最长持仓6个月 - max_loss_amount = -300000 # 最大亏损金额为30万 + def backtest_stock(self, symbol, buy_dates, end_date): + """ + 回测单个股票的表现 + :param symbol: 股票代码 + :param buy_dates: 买入日期列表,格式为['YYYY-MM-DD', 'YYYY-MM-DD', ...] 
+ :param end_date: 结束日期字符串,格式为'YYYY-MM-DD' + :return: 回测结果 + """ + # 找到最早的买入日期 + buy_dates = sorted(buy_dates) + earliest_buy_date = buy_dates[0] - # 添加统计变量 - total_holding_days = 0 # 总持仓天数 - max_capital_usage = 0 # 最大资金占用 - total_trades = 0 # 总交易次数 - profitable_trades = 0 # 盈利交易次数 - loss_trades = 0 # 亏损交易次数 - - # print(f"\n{'='*50}") - # print(f"开始分析股票: {df.iloc[0]['symbol']}") - # print(f"{'='*50}") - - # 记录前一天的收盘价 - prev_close = None - - for index, row in analysis_df.iterrows(): - current_price = (float(row['high']) + float(row['low'])) / 2 + # 获取股票数据 + df = self.get_stock_data(symbol, earliest_buy_date, end_date) + + if df.empty: + print(f"股票 {symbol} 在指定时间范围内没有数据") + return None, None + + if len(df) < 2: + print(f"股票 {symbol} 在指定时间范围内数据不足") + return None, None + + # 将时间戳转换为datetime + df['timestamp'] = pd.to_datetime(df['timestamp']) + + # 转换买入日期和结束日期为datetime对象 + buy_dates_obj = [datetime.strptime(buy_date, '%Y-%m-%d') for buy_date in buy_dates] + end_date_obj = datetime.strptime(end_date, '%Y-%m-%d') + + # 初始化结果列表 + results = [] + positions = [] # 记录所有持仓 + cumulative_profit = 0 # 总累计收益 + + # 记录统计变量 + total_holding_days = 0 + total_trades = 0 + profitable_trades = 0 + loss_trades = 0 + entry_count = 0 # 成功建仓次数 + + # 创建时间戳到索引的映射,便于查找 + timestamp_to_idx = {row['timestamp']: idx for idx, row in df.iterrows()} + + # 为每个买入日期找到对应的交易日并建仓 + for buy_date_obj in buy_dates_obj: + # 找到买入日期当天或之后的第一个交易日 + buy_idx = -1 + nearest_date = None + + for date in df['timestamp']: + if date >= buy_date_obj: + buy_idx = timestamp_to_idx[date] + nearest_date = date + break + + if buy_idx == -1: + print(f"股票 {symbol} 在买入日期 {buy_date_obj.strftime('%Y-%m-%d')} 之后没有交易数据") + continue + + entry_row = df.iloc[buy_idx] + entry_price = float(entry_row['open']) + entry_date = entry_row['timestamp'] + + positions.append({ + 'entry_price': entry_price, + 'entry_date': entry_date, + 'closed': False + }) + entry_count += 1 + + print(f"建仓 #{entry_count}: {symbol} 在 {entry_date.strftime('%Y-%m-%d')} 
以 {entry_price:.2f} 价格买入") + + # 创建初始结果 + if len(df) > 0: + first_day = df.iloc[0] + results.append({ + 'timestamp': first_day['timestamp'], + 'profit': 0, + 'position': len(positions) > 0, + 'action': 'hold' + }) + + # 遍历所有交易日,检查平仓条件 + for i in range(1, len(df)): + current_row = df.iloc[i] + current_date = current_row['timestamp'] + current_price = (float(current_row['high']) + float(current_row['low'])) / 2 + day_result = { - 'timestamp': row['timestamp'], - 'profit': cumulative_profit, # 记录当前的累计收益 + 'timestamp': current_date, + 'profit': cumulative_profit, 'position': False, 'action': 'hold' } - - # 更新所有持仓的收益 - positions_to_remove = [] - - for pos in positions: - pos_days_held = (row['timestamp'] - pos['entry_date']).days + + # 检查是否是买入日期,设置持仓标记 + if any(not pos['closed'] for pos in positions): + day_result['position'] = True + + # 检查所有持仓是否需要平仓 + positions_to_close = [] + + for idx, pos in enumerate(positions): + if pos['closed']: + continue + + pos_days_held = (current_date - pos['entry_date']).days pos_size = self.total_position / pos['entry_price'] # 检查是否需要平仓 should_close = False close_type = None close_price = current_price - + if pos_days_held >= self.min_holding_days: # 检查止盈 - if float(row['high']) >= pos['entry_price'] * (1 + take_profit_pct): + if float(current_row['high']) >= pos['entry_price'] * (1 + self.take_profit_pct): should_close = True close_type = 'sell_profit' - close_price = pos['entry_price'] * (1 + take_profit_pct) - # 检查最大亏损 - elif (close_price - pos['entry_price']) * pos_size <= max_loss_amount: - should_close = True - close_type = 'sell_loss' - # 检查持仓时间 - elif (row['timestamp'] - pos['entry_date']) >= max_holding_period: - should_close = True - close_type = 'time_limit' - + close_price = pos['entry_price'] * (1 + self.take_profit_pct) + + # 强制在结束日期平仓 + if current_date >= end_date_obj and not should_close: + should_close = True + close_type = 'final_sell' + if should_close: profit = (close_price - pos['entry_price']) * pos_size commission 
= profit * self.commission_rate if profit > 0 else 0 @@ -706,711 +240,209 @@ class StockAnalyzer: loss_trades += 1 cumulative_profit += net_profit - positions_to_remove.append(pos) - day_result['action'] = close_type - day_result['profit'] = cumulative_profit - # print(f"\n>>> {'止盈平仓' if close_type == 'sell_profit' else '止损平仓' if close_type == 'sell_loss' else '到期平仓'}") - # print(f"持仓天数: {pos_days_held}") - # print(f"持仓数量: {pos_size:.0f}股") - # print(f"平仓价格: {close_price:.2f}") - # print(f"{'盈利' if profit > 0 else '亏损'}: {profit:,.2f}") - # print(f"手续费: {commission:,.2f}") - # print(f"利息成本: {interest_cost:,.2f}") - # print(f"净{'盈利' if net_profit > 0 else '亏损'}: {net_profit:,.2f}") - - # 移除已平仓的持仓 - for pos in positions_to_remove: - positions.remove(pos) - - # 检查是否可以建仓 - if len(positions) < max_positions and prev_close is not None: - # 计算当日跌幅 - daily_drop = (prev_close - float(row['low'])) / prev_close + positions[idx]['closed'] = True + positions_to_close.append((idx, pos, close_type, close_price, net_profit, pos_days_held)) + + # 处理平仓 + for idx, pos, close_type, close_price, net_profit, pos_days_held in positions_to_close: + day_result['action'] = close_type + day_result['profit'] = cumulative_profit - # 检查价格是否在均线区间内 - price_in_range = ( - current_price >= ma240_base_price * 0.7 and - current_price <= ma240_base_price * 1.3 - ) - - # 如果跌幅超过5%且价格在均线区间内 - if daily_drop >= 0.05 and price_in_range: - positions.append({ - 'entry_price': current_price, - 'entry_date': row['timestamp'] - }) - trades_count += 1 - # print(f"\n>>> 建仓信号 #{trades_count}") - # print(f"日期: {row['timestamp'].strftime('%Y-%m-%d')}") - # print(f"建仓价格: {current_price:.2f}") - # print(f"建仓数量: {self.total_position/current_price:.0f}股") - # print(f"当日跌幅: {daily_drop*100:.2f}%") - # print(f"距离均线: {((current_price/ma240_base_price)-1)*100:.2f}%") - - # 更新前一天收盘价 - prev_close = float(row['close']) - + print(f"平仓 #{idx+1}: {symbol} 在 {current_date.strftime('%Y-%m-%d')} 以 {close_price:.2f} 价格卖出") + 
print(f"持仓天数: {pos_days_held}, 净利润: {net_profit:.2f}") + # 更新日结果 - day_result['position'] = len(positions) > 0 + day_result['position'] = any(not pos['closed'] for pos in positions) results.append(day_result) - - # 计算当前资金占用 - current_capital_usage = sum(self.total_position for pos in positions) - max_capital_usage = max(max_capital_usage, current_capital_usage) - - # 在最后一天强制平仓所有持仓 - if positions: - final_price = (float(analysis_df.iloc[-1]['high']) + float(analysis_df.iloc[-1]['low'])) / 2 - final_total_profit = 0 - - for pos in positions: - position_size = self.total_position / pos['entry_price'] - days_held = (analysis_df.iloc[-1]['timestamp'] - pos['entry_date']).days - final_profit = (final_price - pos['entry_price']) * position_size - commission = final_profit * self.commission_rate if final_profit > 0 else 0 - interest_cost = self.borrowed_capital * self.daily_interest_rate * days_held - net_profit = final_profit - commission - interest_cost - - # 更新统计数据 - total_holding_days += days_held - total_trades += 1 - if net_profit > 0: - profitable_trades += 1 - else: - loss_trades += 1 - - final_total_profit += net_profit - - # print(f"\n>>> 到期强制平仓") - # print(f"持仓天数: {days_held}") - # print(f"持仓数量: {position_size:.0f}股") - # print(f"平仓价格: {final_price:.2f}") - # print(f"毛利润: {final_profit:,.2f}") - # print(f"手续费: {commission:,.2f}") - # print(f"利息成本: {interest_cost:,.2f}") - # print(f"净利润: {net_profit:,.2f}") - - results[-1]['action'] = 'final_sell' - cumulative_profit += final_total_profit # 更新最终的累计收益 - results[-1]['profit'] = cumulative_profit # 更新最后一天的累计收益 - + # 计算统计数据 avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0 win_rate = profitable_trades / total_trades * 100 if total_trades > 0 else 0 - - # print(f"\n{'='*50}") - # print(f"交易统计") - # print(f"总交易次数: {trades_count}") - # print(f"累计收益: {cumulative_profit:,.2f}") - # print(f"最大资金占用: {max_capital_usage:,.2f}") - # print(f"平均持仓天数: {avg_holding_days:.1f}") - # print(f"胜率: 
{win_rate:.1f}%") - # print(f"盈利交易: {profitable_trades}次") - # print(f"亏损交易: {loss_trades}次") - # print(f"{'='*50}\n") - - # 将统计数据添加到结果中 - if len(results) > 0: - results[-1]['stats'] = { - 'total_trades': total_trades, - 'profitable_trades': profitable_trades, - 'loss_trades': loss_trades, - 'win_rate': win_rate, - 'avg_holding_days': avg_holding_days, - 'max_capital_usage': max_capital_usage, - 'final_profit': cumulative_profit - } - - return pd.DataFrame(results) - - def analyze_stock(self, symbol, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct): - """分析单个股票""" - df = self.get_stock_data(symbol, start_date, end_date, include_history=True) - if len(df) < 240: # 确保总数据量足够 - print(f"股票 {symbol} 总数据不足240天,跳过分析") - return None - - return self.calculate_trade_signals( - df, - take_profit_pct=take_profit_pct, - stop_loss_pct=stop_loss_pct, - pullback_entry_pct=pullback_entry_pct, - start_date=start_date - ) - - def analyze_all_stocks(self, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct): - """分析所有股票""" - stocks = self.get_stock_list() - all_results = {} - analysis_summary = { - 'processed': [], - 'skipped': [], - 'error': [] + # 统计信息 + stats = { + 'symbol': symbol, + 'total_trades': total_trades, + 'profitable_trades': profitable_trades, + 'loss_trades': loss_trades, + 'win_rate': win_rate, + 'avg_holding_days': avg_holding_days, + 'final_profit': cumulative_profit, + 'entry_count': entry_count } - print(f"\n开始分析,共 {len(stocks)} 支股票") + print(f"\n交易统计 - {symbol}") + print(f"总交易次数: {total_trades}") + print(f"盈利交易: {profitable_trades}次") + print(f"亏损交易: {loss_trades}次") + print(f"胜率: {win_rate:.1f}%") + print(f"平均持仓天数: {avg_holding_days:.1f}") + print(f"累计收益: {cumulative_profit:.2f}") + print(f"建仓次数: {entry_count}\n") - for symbol in tqdm(stocks, desc="Analyzing stocks"): + return pd.DataFrame(results), stats + + def backtest_multiple_stocks(self, stocks_buy_dates, end_date): + """ + 回测多只股票 + :param stocks_buy_dates: 
字典,键为股票代码,值为买入日期列表 + :param end_date: 所有股票的共同结束日期 + :return: 回测结果字典 + """ + results = {} + stats_list = [] + + print(f"\n开始回测,共 {len(stocks_buy_dates)} 支股票,结束日期: {end_date}") + + for symbol, buy_dates in tqdm(stocks_buy_dates.items(), desc="回测进度"): try: - # 获取股票数据 - df = self.get_stock_data(symbol, start_date, end_date, include_history=True) + result_df, stats = self.backtest_stock(symbol, buy_dates, end_date) - if df.empty: - print(f"\n股票 {symbol} 数据为空,跳过分析") - analysis_summary['skipped'].append({ - 'symbol': symbol, - 'reason': 'empty_data', - 'take_profit_pct': take_profit_pct - }) - continue - - # 检查历史数据 - historical_data = df[df['timestamp'] < start_date] - if len(historical_data) < 240: - # print(f"\n股票 {symbol} 历史数据不足240天({len(historical_data)}天),跳过分析") - analysis_summary['skipped'].append({ - 'symbol': symbol, - 'reason': f'insufficient_history_{len(historical_data)}days', - 'take_profit_pct': take_profit_pct - }) - continue - - # 检查分析期数据 - analysis_data = df[df['timestamp'] >= start_date] - if len(analysis_data) < 10: - print(f"\n股票 {symbol} 分析期数据不足10天({len(analysis_data)}天),跳过分析") - analysis_summary['skipped'].append({ - 'symbol': symbol, - 'reason': f'insufficient_analysis_data_{len(analysis_data)}days', - 'take_profit_pct': take_profit_pct - }) - continue - - # 计算交易信号 - result = self.calculate_trade_signals( - df, - take_profit_pct, - stop_loss_pct, - pullback_entry_pct, - start_date - ) - - if result is not None and not result.empty: - # 检查是否有交易记录 - has_trades = any(action in ['sell_profit', 'sell_loss', 'time_limit', 'final_sell'] - for action in result['action'] if pd.notna(action)) + if result_df is not None and stats is not None: + results[symbol] = result_df + stats_list.append(stats) - if has_trades: - all_results[symbol] = result - analysis_summary['processed'].append({ - 'symbol': symbol, - 'take_profit_pct': take_profit_pct, - 'trades': len([x for x in result['action'] if pd.notna(x) and x != 'hold']), - 'profit': result['profit'].iloc[-1] - 
}) - else: - # print(f"\n股票 {symbol} 没有产生有效的交易信号") - analysis_summary['skipped'].append({ - 'symbol': symbol, - 'reason': 'no_valid_signals', - 'take_profit_pct': take_profit_pct - }) - except Exception as e: print(f"\n处理股票 {symbol} 时出错: {str(e)}") - analysis_summary['error'].append({ - 'symbol': symbol, - 'error': str(e), - 'take_profit_pct': take_profit_pct - }) + import traceback + traceback.print_exc() continue - # 打印分析总结 - print(f"\n分析完成:") - print(f"成功分析: {len(analysis_summary['processed'])} 支股票") - print(f"跳过分析: {len(analysis_summary['skipped'])} 支股票") - print(f"出错股票: {len(analysis_summary['error'])} 支股票") + # 保存统计结果 + if stats_list: + stats_df = pd.DataFrame(stats_list) + + if not os.path.exists('results'): + os.makedirs('results') + + stats_df.to_csv('results/backtest_summary.csv', index=False) + print(f"\n回测统计已保存至 results/backtest_summary.csv") + + # 计算总体统计 + total_profit = stats_df['final_profit'].sum() + avg_win_rate = stats_df['win_rate'].mean() + avg_holding_days = stats_df['avg_holding_days'].mean() + total_trades = stats_df['total_trades'].sum() + profitable_stocks = len(stats_df[stats_df['final_profit'] > 0]) + total_entries = stats_df['entry_count'].sum() + + print(f"\n总体回测统计:") + print(f"总收益: {total_profit:.2f}") + print(f"平均胜率: {avg_win_rate:.1f}%") + print(f"平均持仓天数: {avg_holding_days:.1f}") + print(f"总交易次数: {total_trades}") + print(f"总建仓次数: {total_entries}") + print(f"盈利股票数: {profitable_stocks}/{len(stats_df)}") - # 保存分析总结 - if not os.path.exists('results/analysis_summary'): - os.makedirs('results/analysis_summary') - - # 保存处理成功的股票信息 - if analysis_summary['processed']: - pd.DataFrame(analysis_summary['processed']).to_csv( - f'results/analysis_summary/processed_stocks_{take_profit_pct*100:.1f}.csv', - index=False - ) - - # 保存跳过的股票信息 - if analysis_summary['skipped']: - pd.DataFrame(analysis_summary['skipped']).to_csv( - f'results/analysis_summary/skipped_stocks_{take_profit_pct*100:.1f}.csv', - index=False - ) - - # 保存出错的股票信息 - if 
analysis_summary['error']: - pd.DataFrame(analysis_summary['error']).to_csv( - f'results/analysis_summary/error_stocks_{take_profit_pct*100:.1f}.csv', - index=False - ) - - return all_results + return results, stats_list def plot_results(self, results, symbol): - """绘制单个股票的收益走势图""" + """ + 绘制单个股票的收益走势图 + """ + if results is None or results.empty: + print(f"股票 {symbol} 没有可绘制的结果") + return None + plt.figure(figsize=(12, 6)) - plt.plot(results['timestamp'], results['profit'].cumsum(), label='Cumulative Profit') - plt.title(f'Stock {symbol} Trading Results') - plt.xlabel('Date') - plt.ylabel('Profit (CNY)') + plt.plot(results['timestamp'], results['profit'], label='累计收益') + plt.title(f'股票 {symbol} 回测结果') + plt.xlabel('日期') + plt.ylabel('收益 (元)') plt.legend() plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() + + # 保存图表 + if not os.path.exists('results/charts'): + os.makedirs('results/charts') + plt.savefig(f'results/charts/{symbol}_backtest.png') + return plt - def plot_all_stocks(self, all_results): - """绘制所有股票的收益走势图""" + def plot_all_stocks(self, all_results, stats_list): + """ + 绘制所有股票的收益走势图 + """ + if not all_results: + print("没有可绘制的结果") + return None, {} + plt.figure(figsize=(15, 8)) - # 记录所有股票的累计收益 + # 记录所有股票的收益 total_profits = {} max_profit = float('-inf') min_profit = float('inf') # 绘制每支股票的曲线 for symbol, results in all_results.items(): - # 直接使用已经计算好的累计收益 + if results is None or results.empty: + continue + profits = results['profit'] plt.plot(results['timestamp'], profits, label=symbol, alpha=0.5, linewidth=1) # 更新最大最小收益 - max_profit = max(max_profit, profits.max()) - min_profit = min(min_profit, profits.min()) + if not profits.empty: + max_profit = max(max_profit, profits.max()) + min_profit = min(min_profit, profits.min()) - # 记录总收益(使用最后一个累计收益值) - total_profits[symbol] = profits.iloc[-1] + # 记录总收益 + if not profits.empty: + total_profits[symbol] = profits.iloc[-1] - # 计算所有股票的中位数收益曲线 - # 首先找到共同的日期范围 - all_dates = set() - for results in 
all_results.values(): - all_dates.update(results['timestamp']) - all_dates = sorted(list(all_dates)) + # 计算统计信息 + if stats_list: + stats_df = pd.DataFrame(stats_list) + total_profit = stats_df['final_profit'].sum() + avg_win_rate = stats_df['win_rate'].mean() + avg_holding_days = stats_df['avg_holding_days'].mean() + total_trades = stats_df['total_trades'].sum() + profitable_stocks = len(stats_df[stats_df['final_profit'] > 0]) + total_entries = stats_df['entry_count'].sum() + + # 添加统计信息 + stats_text = ( + f"回测统计:\n" + f"总收益: {total_profit:.2f}\n" + f"平均胜率: {avg_win_rate:.1f}%\n" + f"平均持仓天数: {avg_holding_days:.1f}\n" + f"总交易次数: {total_trades}\n" + f"总建仓次数: {total_entries}\n" + f"盈利股票数: {profitable_stocks}/{len(stats_df)}" + ) + + plt.text(0.02, 0.95, stats_text, + transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), + verticalalignment='top', + fontsize=10) - # 创建一个包含所有日期的DataFrame - profits_df = pd.DataFrame(index=all_dates) - - # 对每支股票,填充所有日期的收益 - for symbol, results in all_results.items(): - daily_profits = pd.Series(index=all_dates, data=np.nan) - # 直接使用已经计算好的累计收益 - for date, profit in zip(results['timestamp'], results['profit']): - daily_profits[date] = profit - # 对于没有交易的日期,使用最后一个累计收益填充 - daily_profits.fillna(method='ffill', inplace=True) - daily_profits.fillna(0, inplace=True) # 对开始之前的日期填充0 - profits_df[symbol] = daily_profits - - # 计算并绘制中位数收益曲线 - median_line = profits_df.median(axis=1) - plt.plot(all_dates, median_line, 'r-', label='Median', linewidth=2) - - plt.title('All Stocks Trading Results') - plt.xlabel('Date') - plt.ylabel('Profit (CNY)') + plt.title('所有股票回测结果') + plt.xlabel('日期') + plt.ylabel('收益 (元)') plt.grid(True) plt.xticks(rotation=45) - # 添加利润区间标注 - plt.text(0.02, 0.98, - f'Profit Range:\nMax: {max_profit:,.0f}\nMin: {min_profit:,.0f}\nMedian: {median_line.iloc[-1]:,.0f}', - transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8)) - - # 计算所有股票的统计数据 - total_stats = { - 'total_trades': 0, - 
'profitable_trades': 0, - 'loss_trades': 0, - 'total_holding_days': 0, - 'max_capital_usage': 0, - 'total_profit': 0, - 'profitable_stocks': 0, - 'loss_stocks': 0, - 'total_capital_usage': 0, # 添加总资金占用 - 'total_win_rate': 0, # 添加总胜率 - } - - # 获取前5名和后5名股票 - top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5] - bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5] - - # 添加排名信息 - rank_text = "Top 5 Stocks:\n" - for symbol, profit in top_5: - rank_text += f"{symbol}: {profit:,.0f}\n" - rank_text += "\nBottom 5 Stocks:\n" - for symbol, profit in bottom_5: - rank_text += f"{symbol}: {profit:,.0f}\n" - - plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top') - - # 记录每支股票的统计数据用于计算平均值 - stock_stats = [] - - for symbol, results in all_results.items(): - if 'stats' in results.iloc[-1]: - stats = results.iloc[-1]['stats'] - stock_stats.append(stats) # 记录每支股票的统计数据 - - total_stats['total_trades'] += stats['total_trades'] - total_stats['profitable_trades'] += stats['profitable_trades'] - total_stats['loss_trades'] += stats['loss_trades'] - total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades'] - total_stats['max_capital_usage'] = max(total_stats['max_capital_usage'], stats['max_capital_usage']) - total_stats['total_capital_usage'] += stats['max_capital_usage'] - total_stats['total_profit'] += stats['final_profit'] - total_stats['total_win_rate'] += stats['win_rate'] - if stats['final_profit'] > 0: - total_stats['profitable_stocks'] += 1 - else: - total_stats['loss_stocks'] += 1 - - # 计算总体统计 - total_stocks = total_stats['profitable_stocks'] + total_stats['loss_stocks'] - avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 - win_rate = total_stats['profitable_trades'] / total_stats['total_trades'] * 100 if total_stats['total_trades'] > 0 else 0 - stock_win_rate = 
total_stats['profitable_stocks'] / total_stocks * 100 if total_stocks > 0 else 0 - - # 计算平均值统计 - avg_trades_per_stock = total_stats['total_trades'] / total_stocks if total_stocks > 0 else 0 - avg_profit_per_stock = total_stats['total_profit'] / total_stocks if total_stocks > 0 else 0 - avg_capital_usage = total_stats['total_capital_usage'] / total_stocks if total_stocks > 0 else 0 - avg_win_rate = total_stats['total_win_rate'] / total_stocks if total_stocks > 0 else 0 - - # 添加总体统计信息 - stats_text = ( - f"总体统计:\n" - f"总交易次数: {total_stats['total_trades']}\n" - # f"最大资金占用: {total_stats['max_capital_usage']:,.0f}\n" - f"平均持仓天数: {avg_holding_days:.1f}\n" - f"交易胜率: {win_rate:.1f}%\n" - f"股票胜率: {stock_win_rate:.1f}%\n" - f"盈利股票数: {total_stats['profitable_stocks']}\n" - f"亏损股票数: {total_stats['loss_stocks']}\n" - f"\n平均值统计:\n" - f"平均交易次数: {avg_trades_per_stock:.1f}\n" - f"平均收益: {avg_profit_per_stock:,.0f}\n" - f"平均资金占用: {avg_capital_usage:,.0f}\n" - f"平均持仓天数: {avg_holding_days:.1f}\n" - f"平均胜率: {avg_win_rate:.1f}%" - ) - - plt.text(0.02, 0.6, stats_text, - transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), - verticalalignment='top', - fontsize=10) - - # 获取前5名和后5名股票 - top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5] - bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5] - - # 添加排名信息 - rank_text = "Top 5 Stocks:\n" - for symbol, profit in top_5: - rank_text += f"{symbol}: {profit:,.0f}\n" - rank_text += "\nBottom 5 Stocks:\n" - for symbol, profit in bottom_5: - rank_text += f"{symbol}: {profit:,.0f}\n" - - plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top') - + # 保存图表 + if not os.path.exists('results/charts'): + os.makedirs('results/charts') + plt.savefig('results/charts/all_stocks_backtest.png') plt.tight_layout() + return plt, total_profits - def plot_profit_distribution(self, total_profits): - """绘制所有股票的收益分布小提琴图""" - 
plt.figure(figsize=(12, 8)) - - # 将收益数据转换为数组,确保使用最终累计收益 - profits = np.array([profit for profit in total_profits.values()]) - - # 计算统计数据 - mean_profit = np.mean(profits) - median_profit = np.median(profits) - max_profit = np.max(profits) - min_profit = np.min(profits) - std_profit = np.std(profits) - - # 绘制小提琴图 - violin_parts = plt.violinplot(profits, positions=[0], showmeans=True, showmedians=True) - - # 设置小提琴图的颜色 - violin_parts['bodies'][0].set_facecolor('lightblue') - violin_parts['bodies'][0].set_alpha(0.7) - violin_parts['cmeans'].set_color('red') - violin_parts['cmedians'].set_color('blue') - - # 添加统计信息 - stats_text = ( - f"统计信息:\n" - f"平均收益: {mean_profit:,.0f}\n" - f"中位数收益: {median_profit:,.0f}\n" - f"最大收益: {max_profit:,.0f}\n" - f"最小收益: {min_profit:,.0f}\n" - f"标准差: {std_profit:,.0f}" - ) - - plt.text(0.65, 0.95, stats_text, - transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), - verticalalignment='top', - fontsize=10) - - # 设置图表样式 - plt.title('股票收益分布图', fontsize=14, pad=20) - plt.ylabel('收益 (元)', fontsize=12) - plt.xticks([0], ['所有股票'], fontsize=10) - plt.grid(True, axis='y', alpha=0.3) - - # 添加零线 - plt.axhline(y=0, color='r', linestyle='--', alpha=0.3) - - # 计算盈利和亏损的股票数量 - profit_count = np.sum(profits > 0) - loss_count = np.sum(profits < 0) - total_count = len(profits) - - # 添加盈亏比例信息 - ratio_text = ( - f"盈亏比例:\n" - f"盈利: {profit_count}支 ({profit_count/total_count*100:.1f}%)\n" - f"亏损: {loss_count}支 ({loss_count/total_count*100:.1f}%)\n" - f"总计: {total_count}支" - ) - - plt.text(0.65, 0.6, ratio_text, - transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), - verticalalignment='top', - fontsize=10) - - plt.tight_layout() - return plt - - def plot_profit_matrix(self, df): - """绘制止盈比例分析矩阵的结果图""" - plt.figure(figsize=(15, 10)) - - # 创建子图 - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) - - # 1. 
总收益曲线 - ax1.plot(df.index, df['total_profit'], 'b-', linewidth=2) - ax1.set_title('总收益 vs 止盈比例') - ax1.set_xlabel('止盈比例 (%)') - ax1.set_ylabel('总收益 (元)') - ax1.grid(True) - - # 标记最优止盈比例点 - best_profit_pct = df['total_profit'].idxmax() - best_profit = df['total_profit'].max() - ax1.plot(best_profit_pct, best_profit, 'ro') - ax1.annotate(f'最优: {best_profit_pct:.1f}%\n{best_profit:,.0f}元', - (best_profit_pct, best_profit), - xytext=(10, 10), textcoords='offset points') - - # 2. 平均每笔收益和中位数收益 - ax2.plot(df.index, df['avg_profit_per_trade'], 'g-', linewidth=2, label='平均收益') - ax2.plot(df.index, df['median_profit_per_trade'], 'r--', linewidth=2, label='中位数收益') - ax2.set_title('每笔交易收益 vs 止盈比例') - ax2.set_xlabel('止盈比例 (%)') - ax2.set_ylabel('收益 (元)') - ax2.legend() - ax2.grid(True) - - # 3. 平均持仓天数曲线 - ax3.plot(df.index, df['avg_holding_days'], 'r-', linewidth=2) - ax3.set_title('平均持仓天数 vs 止盈比例') - ax3.set_xlabel('止盈比例 (%)') - ax3.set_ylabel('持仓天数') - ax3.grid(True) - - # 4. 胜率对比 - ax4.plot(df.index, df['trade_win_rate'], 'c-', linewidth=2, label='交易胜率') - ax4.plot(df.index, df['stock_win_rate'], 'm--', linewidth=2, label='股票胜率') - ax4.set_title('胜率 vs 止盈比例') - ax4.set_xlabel('止盈比例 (%)') - ax4.set_ylabel('胜率 (%)') - ax4.legend() - ax4.grid(True) - - # 调整布局 - plt.tight_layout() - return plt - - def analyze_profit_matrix(self, start_date, end_date, stop_loss_pct, pullback_entry_pct): - """分析不同止盈比例的表现矩阵""" - # 创建results目录和临时文件目录 - if not os.path.exists('results/temp'): - os.makedirs('results/temp') - - # 从30%到5%,每次减少1% - profit_percentages = list(np.arange(0.30, 0.04, -0.01)) - - for take_profit_pct in profit_percentages: - # 检查是否已经分析过这个止盈点 - temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv' - if os.path.exists(temp_file): - print(f"\n止盈比例 {take_profit_pct*100:.1f}% 已分析,跳过") - continue - - print(f"\n分析止盈比例: {take_profit_pct*100:.1f}%") - - try: - # 运行分析 - results = self.analyze_all_stocks( - start_date, - end_date, - take_profit_pct, - stop_loss_pct, - 
pullback_entry_pct - ) - - # 过滤掉空结果 - valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty} - - if valid_results: - # 计算统计数据 - total_stats = { - 'total_trades': 0, - 'profitable_trades': 0, - 'loss_trades': 0, - 'total_holding_days': 0, - 'total_profit': 0, - 'total_capital_usage': 0, - 'total_win_rate': 0, - 'stock_count': len(valid_results), - 'profitable_stocks': 0, - 'all_trade_profits': [] - } - - # 收集每支股票的统计数据 - for symbol, result in valid_results.items(): - if 'stats' in result.iloc[-1]: - stats = result.iloc[-1]['stats'] - total_stats['total_trades'] += stats['total_trades'] - total_stats['profitable_trades'] += stats['profitable_trades'] - total_stats['loss_trades'] += stats['loss_trades'] - total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades'] - total_stats['total_profit'] += stats['final_profit'] - total_stats['total_capital_usage'] += stats['max_capital_usage'] - total_stats['total_win_rate'] += stats['win_rate'] - - if stats['final_profit'] > 0: - total_stats['profitable_stocks'] += 1 - - if hasattr(stats, 'trade_profits'): - total_stats['all_trade_profits'].extend(stats['trade_profits']) - - # 计算统计数据 - avg_trades = total_stats['total_trades'] / total_stats['stock_count'] - avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 - avg_profit_per_trade = total_stats['total_profit'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 - median_profit_per_trade = np.median(total_stats['all_trade_profits']) if total_stats['all_trade_profits'] else 0 - total_profit = total_stats['total_profit'] - trade_win_rate = (total_stats['profitable_trades'] / total_stats['total_trades'] * 100) if total_stats['total_trades'] > 0 else 0 - stock_win_rate = (total_stats['profitable_stocks'] / total_stats['stock_count'] * 100) if total_stats['stock_count'] > 0 else 0 - - # 创建单个止盈点的结果DataFrame - 
result_df = pd.DataFrame([{ - 'take_profit_pct': take_profit_pct * 100, - 'total_profit': total_profit, - 'avg_profit_per_trade': avg_profit_per_trade, - 'median_profit_per_trade': median_profit_per_trade, - 'avg_holding_days': avg_holding_days, - 'trade_win_rate': trade_win_rate, - 'stock_win_rate': stock_win_rate, - 'total_trades': total_stats['total_trades'], - 'stock_count': total_stats['stock_count'] - }]) - - # 保存单个止盈点的结果 - result_df.to_csv(temp_file, index=False) - print(f"保存止盈点 {take_profit_pct*100:.1f}% 的分析结果") - - # 清理内存 - del valid_results - del results - if 'total_stats' in locals(): - del total_stats - - except Exception as e: - print(f"处理止盈点 {take_profit_pct*100:.1f}% 时出错: {str(e)}") - continue - - # 合并所有结果 - print("\n合并所有分析结果...") - all_results = [] - for take_profit_pct in profit_percentages: - temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv' - if os.path.exists(temp_file): - try: - df = pd.read_csv(temp_file) - all_results.append(df) - except Exception as e: - print(f"读取文件 {temp_file} 时出错: {str(e)}") - - if all_results: - # 合并所有结果 - df = pd.concat(all_results, ignore_index=True) - df = df.round({ - 'take_profit_pct': 1, - 'total_profit': 0, - 'avg_profit_per_trade': 0, - 'median_profit_per_trade': 0, - 'avg_holding_days': 1, - 'trade_win_rate': 1, - 'stock_win_rate': 1, - 'total_trades': 0, - 'stock_count': 0 - }) - - # 设置止盈比例为索引并排序 - df.set_index('take_profit_pct', inplace=True) - df.sort_index(ascending=False, inplace=True) - - # 保存最终结果 - df.to_csv('results/profit_matrix_analysis.csv') - - # 打印结果 - print("\n止盈比例分析矩阵:") - print("=" * 120) - print(df.to_string()) - print("=" * 120) - - try: - # 绘制并保存矩阵分析图表 - plt = self.plot_profit_matrix(df) - plt.savefig('results/profit_matrix_analysis.png', bbox_inches='tight', dpi=300) - plt.close() - except Exception as e: - print(f"绘制图表时出错: {str(e)}") - - return df - else: - print("没有找到任何分析结果") - return pd.DataFrame() - -def main(): +def run_backtest(stocks_buy_dates, end_date): + """ + 
运行回测的便捷方法 + :param stocks_buy_dates: 字典,键为股票代码,值为买入日期列表, + 例如 {'SH600522': ['2022-05-10', '2022-06-10'], 'SZ002340': ['2022-06-15']} + :param end_date: 所有股票共同的结束日期,格式为'YYYY-MM-DD' + :return: 回测结果 + """ # 数据库连接配置 db_config = { 'host': '192.168.18.199', @@ -1421,56 +453,40 @@ def main(): } try: - connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}" - # 创建results目录 - if not os.path.exists('results'): os.makedirs('results') - # 创建分析器实例 - analyzer = StockAnalyzer(connection_string) + # 创建回测器实例 + backtester = StockBacktester(db_config) - # 设置分析参数 - start_date = '2022-05-05' - end_date = '2023-08-28' - stop_loss_pct = -0.99 # -99%止损 - pullback_entry_pct = -0.05 # -5%回调建仓 + # 运行回测 + results, stats_list = backtester.backtest_multiple_stocks(stocks_buy_dates, end_date) - # 运行止盈比例分析 - profit_matrix = analyzer.analyze_profit_matrix( - start_date, - end_date, - stop_loss_pct, - pullback_entry_pct - ) + # 绘制图表 + if results: + # 绘制所有股票的图表 + backtester.plot_all_stocks(results, stats_list) + + # 绘制每支股票的单独图表 + for symbol, result in results.items(): + backtester.plot_results(result, symbol) - if not profit_matrix.empty: - # 使用最优止盈比例进行完整分析 - take_profit_pct = profit_matrix['total_profit'].idxmax() / 100 - print(f"\n使用最优止盈比例 {take_profit_pct*100:.1f}% 运行详细分析") - - # 运行分析并只生成所需的图表 - results = analyzer.analyze_all_stocks( - start_date, - end_date, - take_profit_pct, - stop_loss_pct, - pullback_entry_pct - ) - - # 过滤掉空结果 - valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty} - - if valid_results: - # 只生成所需的两个图表 - plt, _ = analyzer.plot_all_stocks(valid_results) - plt.savefig('results/all_stocks_analysis.png', bbox_inches='tight', dpi=300) - plt.close() + return results, stats_list except Exception as e: - print(f"程序运行出错: {str(e)}") - raise + print(f"回测运行出错: {str(e)}") + import traceback + traceback.print_exc() + return None, None if 
__name__ == "__main__": - main() \ No newline at end of file + # 测试回测 - 现在使用股票对应多个买入日期 + stocks_buy_dates = { + 'SH600522': ['2022-05-10', '2022-06-10'], # 两个买入日期 + 'SZ002340': ['2022-06-15'], # 一个买入日期 + 'SH601615': ['2022-07-20', '2022-08-01'] # 两个买入日期 + } + end_date = '2022-10-20' # 所有股票共同的结束日期 + + run_backtest(stocks_buy_dates, end_date) \ No newline at end of file diff --git a/src/stock_simulation.py b/src/stock_simulation.py new file mode 100644 index 0000000..e5a4662 --- /dev/null +++ b/src/stock_simulation.py @@ -0,0 +1,1249 @@ +import pandas as pd +import numpy as np +from sqlalchemy import create_engine, text +from datetime import datetime, timedelta, time +import uuid +import logging +import sys +from typing import Dict, List, Optional, Tuple, Union + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("simulation.log"), + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger("stock_simulation") + +class StockSimulator: + def __init__(self, db_connection_string): + """ + 初始化股票模拟交易平台 + :param db_connection_string: 数据库连接字符串 + """ + self.engine = create_engine(db_connection_string, future=True) + self.MORNING_OPEN = time(9, 30) # 早市开盘时间 9:30 + self.AFTERNOON_CLOSE = time(15, 0) # 午市收盘时间 15:00 + + def _is_before_open(self, dt: datetime) -> bool: + """判断是否在开盘前""" + return dt.time() < self.MORNING_OPEN + + def _is_after_close(self, dt: datetime) -> bool: + """判断是否在收盘后""" + return dt.time() > self.AFTERNOON_CLOSE + + def _read_sql(self, query: str) -> pd.DataFrame: + raw_conn = self.engine.raw_connection() + try: + df = pd.read_sql(query, raw_conn) + finally: + raw_conn.close() + return df + + def _get_stock_info(self, stock_code: str) -> Dict: + """获取股票基本信息""" + try: + query = f""" + SELECT gp_name FROM gp_code_all + WHERE gp_code = '{stock_code}' OR gp_code_two = '{stock_code}' OR gp_code_three = '{stock_code}' + LIMIT 1 + """ + result = 
self._read_sql(query) + + if result.empty: + return {"success": False, "message": f"股票代码 {stock_code} 不存在"} + + return { + "success": True, + "stock_code": stock_code, + "stock_name": result.iloc[0]['gp_name'] + } + except Exception as e: + logger.error(f"获取股票信息出错: {str(e)}") + return {"success": False, "message": f"获取股票信息出错: {str(e)}"} + + def _get_latest_trade_date(self) -> str: + """获取最新交易日期""" + try: + query = """ + SELECT MAX(timestamp) as latest_date FROM gp_day_data + """ + result = self._read_sql(query) + + if result.empty or pd.isna(result.iloc[0]['latest_date']): + # 如果查询失败,返回当前日期 + return datetime.now().strftime('%Y-%m-%d') + + return result.iloc[0]['latest_date'].strftime('%Y-%m-%d') + except Exception as e: + logger.error(f"获取最新交易日期出错: {str(e)}") + # 如果出错,返回当前日期 + return datetime.now().strftime('%Y-%m-%d') + + def _get_next_trade_date(self, date_str: str) -> str: + """获取下一个交易日期""" + try: + query = f""" + SELECT MIN(timestamp) as next_date + FROM gp_day_data + WHERE timestamp > '{date_str}' + """ + result = self._read_sql(query) + + if result.empty or pd.isna(result.iloc[0]['next_date']): + # 如果查询失败,返回当前日期加1天 + next_date = datetime.strptime(date_str, '%Y-%m-%d') + timedelta(days=1) + return next_date.strftime('%Y-%m-%d') + + return result.iloc[0]['next_date'].strftime('%Y-%m-%d') + except Exception as e: + logger.error(f"获取下一个交易日期出错: {str(e)}") + # 如果出错,返回当前日期加1天 + next_date = datetime.strptime(date_str, '%Y-%m-%d') + timedelta(days=1) + return next_date.strftime('%Y-%m-%d') + + def _get_stock_price(self, stock_code: str, trade_date: str, price_type: str = 'close') -> float: + """ + 获取指定股票在指定日期的价格 + :param stock_code: 股票代码 + :param trade_date: 交易日期 + :param price_type: 价格类型 (open/close) + :return: 股票价格 + """ + try: + query = f""" + SELECT CAST({price_type} as DECIMAL(10,2)) as price + FROM gp_day_data + WHERE symbol = '{stock_code}' AND timestamp = '{trade_date}' + """ + result = self._read_sql(query) + + if result.empty: + return None + + return 
float(result.iloc[0]['price']) + except Exception as e: + logger.error(f"获取股票价格出错: {str(e)}") + return None + + def _calculate_buy_quantity(self, available_capital: float, price: float, single_position_limit: float) -> int: + """ + 计算可买入的股票数量,必须是100的整数倍 + """ + # 计算可用资金和单支标的限额的较小值 + max_amount = min(available_capital, single_position_limit) + + # 计算理论上可买入的最大数量 + theoretical_max = max_amount / price + + # 取100的整数倍,确保不少于100股 + quantity = int(theoretical_max / 100) * 100 + + # 如果不足100股,检查是否有足够资金买入100股 + if quantity == 0 and available_capital >= price * 100: + quantity = 100 + + return quantity + + def _execute_database_operation(self, query: str, params: Optional[Dict] = None) -> bool: + """执行数据库操作""" + try: + with self.engine.begin() as conn: + if params: + conn.execute(text(query), params) + else: + conn.execute(text(query)) + return True + except Exception as e: + logger.error(f"数据库操作出错: {str(e)}") + return False + + def _get_account_info(self, account_id: str) -> Dict: + """获取账户信息""" + try: + query = f""" + SELECT * FROM simulation_account WHERE id = '{account_id}' + """ + result = self._read_sql(query) + + if result.empty: + return {"success": False, "message": f"账户ID {account_id} 不存在"} + + account_data = result.iloc[0].to_dict() + + if account_data['status'] == 'closed': + return {"success": False, "message": f"账户ID {account_id} 已关闭"} + + return { + "success": True, + "account_data": account_data + } + except Exception as e: + logger.error(f"获取账户信息出错: {str(e)}") + return {"success": False, "message": f"获取账户信息出错: {str(e)}"} + + def _check_holding_exists(self, account_id: str, stock_code: str) -> bool: + """检查是否已持有该股票""" + try: + query = f""" + SELECT * FROM simulation_holding + WHERE account_id = '{account_id}' AND stock_code = '{stock_code}' + """ + result = self._read_sql(query) + + return not result.empty + except Exception as e: + logger.error(f"检查持仓出错: {str(e)}") + return False + + def init_simulation_account(self, total_capital: float, 
single_position_limit: float) -> Dict: + """ + 初始化模拟账户 + :param total_capital: 总仓位金额 + :param single_position_limit: 单支股票最大仓位 + :return: 账户信息 + """ + try: + account_id = str(uuid.uuid4()) + now = datetime.now() + + # 插入账户记录 + query = """ + INSERT INTO simulation_account ( + id, total_capital, single_position_limit, available_capital, + create_time, update_time + ) VALUES ( + :account_id, :total_capital, :single_position_limit, :available_capital, + :create_time, :update_time + ) + """ + + params = { + 'account_id': account_id, + 'total_capital': total_capital, + 'single_position_limit': single_position_limit, + 'available_capital': total_capital, + 'create_time': now, + 'update_time': now + } + + success = self._execute_database_operation(query, params) + + if not success: + return {"success": False, "message": "创建模拟账户失败"} + + # 记录操作日志 + log_query = """ + INSERT INTO simulation_operation_log ( + account_id, operation_type, operation_time, status, message + ) VALUES ( + :account_id, 'init', :operation_time, 'executed', + :message + ) + """ + + log_params = { + 'account_id': account_id, + 'operation_time': now, + 'message': f"初始化账户,总资金: {total_capital},单支最大仓位: {single_position_limit}" + } + + self._execute_database_operation(log_query, log_params) + + # 初始化净值记录 + today = now.strftime('%Y-%m-%d') + net_value_query = """ + INSERT INTO simulation_net_value ( + account_id, trade_date, net_value, available_capital, holding_value, + daily_profit_loss, accumulated_profit_loss, yield_rate + ) VALUES ( + :account_id, :trade_date, :net_value, :available_capital, :holding_value, + :daily_profit_loss, :accumulated_profit_loss, :yield_rate + ) + """ + + net_value_params = { + 'account_id': account_id, + 'trade_date': today, + 'net_value': total_capital, + 'available_capital': total_capital, + 'holding_value': 0, + 'daily_profit_loss': 0, + 'accumulated_profit_loss': 0, + 'yield_rate': 0 + } + + self._execute_database_operation(net_value_query, net_value_params) + + return { + 
"success": True, + "account_id": account_id, + "total_capital": total_capital, + "available_capital": total_capital, + "message": "模拟账户创建成功" + } + + except Exception as e: + logger.error(f"初始化模拟账户出错: {str(e)}") + return {"success": False, "message": f"初始化模拟账户出错: {str(e)}"} + + def buy_stock(self, account_id: str, stock_code: str) -> Dict: + """ + 买入股票 + :param account_id: 账户ID + :param stock_code: 股票代码 + :return: 操作结果 + """ + try: + now = datetime.now() + # 记录操作 + operation_query = """ + INSERT INTO simulation_operation_log ( + account_id, operation_type, stock_code, operation_time, status, message + ) VALUES ( + :account_id, 'buy', :stock_code, :operation_time, 'pending', + :message + ) + """ + + operation_params = { + 'account_id': account_id, + 'stock_code': stock_code, + 'operation_time': now, + 'message': f"买入股票 {stock_code}" + } + + self._execute_database_operation(operation_query, operation_params) + + # 获取账户信息 + account_info = self._get_account_info(account_id) + if not account_info['success']: + return account_info + + account_data = account_info['account_data'] + + # 获取股票信息 + stock_info = self._get_stock_info(stock_code) + if not stock_info['success']: + return stock_info + + # 确定交易日期和价格 + today = now.strftime('%Y-%m-%d') + latest_trade_date = self._get_latest_trade_date() + + # 判断买入执行时间 + is_before_open = self._is_before_open(now) + trade_date = today if is_before_open else self._get_next_trade_date(today) + price_type = 'open' # 按开盘价交易 + + # 获取股票价格 + price = self._get_stock_price(stock_code, trade_date, price_type) + + # 如果获取不到下一个交易日的价格,使用最新交易日的收盘价作为参考 + if price is None: + price = self._get_stock_price(stock_code, latest_trade_date, 'close') + if price is None: + return {"success": False, "message": f"无法获取股票 {stock_code} 的价格信息"} + + # 计算可买入数量 + available_capital = float(account_data['available_capital']) + single_position_limit = float(account_data['single_position_limit']) + + # 检查是否已持有该股票 + already_holding = self._check_holding_exists(account_id, 
stock_code) + if already_holding: + return {"success": False, "message": f"账户已持有股票 {stock_code},不能重复买入"} + + quantity = self._calculate_buy_quantity(available_capital, price, single_position_limit) + + if quantity < 100: + return {"success": False, "message": f"可用资金不足以购买最小交易单位(100股)的股票 {stock_code}"} + + # 计算交易金额 + trade_amount = quantity * price + + # 更新操作日志 + update_log_query = """ + UPDATE simulation_operation_log SET + execution_time = :execution_time, + execution_price = :execution_price, + status = 'executed', + message = :message + WHERE account_id = :account_id AND operation_type = 'buy' AND stock_code = :stock_code + ORDER BY id DESC LIMIT 1 + """ + + update_log_params = { + 'execution_time': now, + 'execution_price': price, + 'message': f"买入股票 {stock_code}, 数量: {quantity}, 价格: {price}, 金额: {trade_amount}", + 'account_id': account_id, + 'stock_code': stock_code + } + + self._execute_database_operation(update_log_query, update_log_params) + + # 添加持仓记录 + holding_query = """ + INSERT INTO simulation_holding ( + account_id, stock_code, stock_name, quantity, cost_price, + current_price, market_value + ) VALUES ( + :account_id, :stock_code, :stock_name, :quantity, :cost_price, + :current_price, :market_value + ) + """ + + holding_params = { + 'account_id': account_id, + 'stock_code': stock_code, + 'stock_name': stock_info['stock_name'], + 'quantity': quantity, + 'cost_price': price, + 'current_price': price, + 'market_value': trade_amount + } + + self._execute_database_operation(holding_query, holding_params) + + # 添加交易记录 + transaction_query = """ + INSERT INTO simulation_transaction ( + account_id, stock_code, stock_name, transaction_type, + transaction_price, quantity, transaction_amount, trade_time + ) VALUES ( + :account_id, :stock_code, :stock_name, 'buy', + :transaction_price, :quantity, :transaction_amount, :trade_time + ) + """ + + transaction_params = { + 'account_id': account_id, + 'stock_code': stock_code, + 'stock_name': stock_info['stock_name'], + 
'transaction_price': price, + 'quantity': quantity, + 'transaction_amount': trade_amount, + 'trade_time': now + } + + self._execute_database_operation(transaction_query, transaction_params) + + # 更新账户信息 + new_available_capital = available_capital - trade_amount + holding_value = float(account_data['holding_value']) + trade_amount + + update_account_query = """ + UPDATE simulation_account SET + available_capital = :available_capital, + holding_value = :holding_value, + update_time = :update_time + WHERE id = :account_id + """ + + update_account_params = { + 'available_capital': new_available_capital, + 'holding_value': holding_value, + 'update_time': now, + 'account_id': account_id + } + + self._execute_database_operation(update_account_query, update_account_params) + + # 更新净值记录 + net_value_date = datetime.strptime(trade_date, '%Y-%m-%d').date() + self._update_net_value(account_id, net_value_date) + + return { + "success": True, + "message": f"买入股票 {stock_code} 成功", + "stock_code": stock_code, + "stock_name": stock_info['stock_name'], + "quantity": quantity, + "price": price, + "amount": trade_amount, + "trade_date": trade_date + } + + except Exception as e: + logger.error(f"买入股票出错: {str(e)}") + return {"success": False, "message": f"买入股票出错: {str(e)}"} + + def sell_stock(self, account_id: str, stock_code: str) -> Dict: + """ + 卖出股票 + :param account_id: 账户ID + :param stock_code: 股票代码 + :return: 操作结果 + """ + try: + now = datetime.now() + + # 记录操作 + operation_query = """ + INSERT INTO simulation_operation_log ( + account_id, operation_type, stock_code, operation_time, status, message + ) VALUES ( + :account_id, 'sell', :stock_code, :operation_time, 'pending', + :message + ) + """ + + operation_params = { + 'account_id': account_id, + 'stock_code': stock_code, + 'operation_time': now, + 'message': f"卖出股票 {stock_code}" + } + + self._execute_database_operation(operation_query, operation_params) + + # 获取账户信息 + account_info = self._get_account_info(account_id) + if not 
account_info['success']: + return account_info + + account_data = account_info['account_data'] + + # 检查是否持有该股票 + holding_query = f""" + SELECT * FROM simulation_holding + WHERE account_id = '{account_id}' AND stock_code = '{stock_code}' + """ + holding = self._read_sql(holding_query) + + if holding.empty: + return {"success": False, "message": f"账户未持有股票 {stock_code},无法卖出"} + + holding_data = holding.iloc[0] + + # 确定交易日期和价格 + today = now.strftime('%Y-%m-%d') + latest_trade_date = self._get_latest_trade_date() + + # 判断卖出执行时间 + is_after_close = self._is_after_close(now) + trade_date = today if not is_after_close else self._get_next_trade_date(today) + price_type = 'close' # 按收盘价交易 + + # 获取股票价格 + price = self._get_stock_price(stock_code, trade_date, price_type) + + # 如果获取不到当日收盘价,使用最新交易日的收盘价 + if price is None: + price = self._get_stock_price(stock_code, latest_trade_date, 'close') + if price is None: + price = float(holding_data['current_price']) # 使用当前价格 + + # 计算交易金额和盈亏 + quantity = int(holding_data['quantity']) + cost_price = float(holding_data['cost_price']) + + trade_amount = quantity * price + profit_loss = (price - cost_price) * quantity + + # 更新操作日志 + update_log_query = """ + UPDATE simulation_operation_log SET + execution_time = :execution_time, + execution_price = :execution_price, + status = 'executed', + message = :message + WHERE account_id = :account_id AND operation_type = 'sell' AND stock_code = :stock_code + ORDER BY id DESC LIMIT 1 + """ + + update_log_params = { + 'execution_time': now, + 'execution_price': price, + 'message': f"卖出股票 {stock_code}, 数量: {quantity}, 价格: {price}, 金额: {trade_amount}, 盈亏: {profit_loss}", + 'account_id': account_id, + 'stock_code': stock_code + } + + self._execute_database_operation(update_log_query, update_log_params) + + # 添加交易记录 + transaction_query = """ + INSERT INTO simulation_transaction ( + account_id, stock_code, stock_name, transaction_type, + transaction_price, quantity, transaction_amount, profit_loss, trade_time 
+ ) VALUES ( + :account_id, :stock_code, :stock_name, 'sell', + :transaction_price, :quantity, :transaction_amount, :profit_loss, :trade_time + ) + """ + + transaction_params = { + 'account_id': account_id, + 'stock_code': stock_code, + 'stock_name': holding_data['stock_name'], + 'transaction_price': price, + 'quantity': quantity, + 'transaction_amount': trade_amount, + 'profit_loss': profit_loss, + 'trade_time': now + } + + self._execute_database_operation(transaction_query, transaction_params) + + # 删除持仓记录 + delete_holding_query = f""" + DELETE FROM simulation_holding + WHERE account_id = '{account_id}' AND stock_code = '{stock_code}' + """ + + self._execute_database_operation(delete_holding_query) + + # 更新账户信息 + available_capital = float(account_data['available_capital']) + trade_amount + holding_value = float(account_data['holding_value']) - float(holding_data['market_value']) + total_profit_loss = float(account_data['profit_loss']) + profit_loss + + # 更新盈亏统计 + win_count = int(account_data['win_count']) + loss_count = int(account_data['loss_count']) + + if profit_loss > 0: + win_count += 1 + else: + loss_count += 1 + + update_account_query = """ + UPDATE simulation_account SET + available_capital = :available_capital, + holding_value = :holding_value, + profit_loss = :profit_loss, + win_count = :win_count, + loss_count = :loss_count, + update_time = :update_time + WHERE id = :account_id + """ + + update_account_params = { + 'available_capital': available_capital, + 'holding_value': holding_value, + 'profit_loss': total_profit_loss, + 'win_count': win_count, + 'loss_count': loss_count, + 'update_time': now, + 'account_id': account_id + } + + self._execute_database_operation(update_account_query, update_account_params) + + # 更新净值记录 + net_value_date = datetime.strptime(trade_date, '%Y-%m-%d').date() + self._update_net_value(account_id, net_value_date) + + return { + "success": True, + "message": f"卖出股票 {stock_code} 成功", + "stock_code": stock_code, + "stock_name": 
holding_data['stock_name'], + "quantity": quantity, + "price": price, + "amount": trade_amount, + "profit_loss": profit_loss, + "trade_date": trade_date + } + + except Exception as e: + logger.error(f"卖出股票出错: {str(e)}") + return {"success": False, "message": f"卖出股票出错: {str(e)}"} + + def _update_net_value(self, account_id: str, trade_date: datetime.date) -> bool: + """更新账户净值""" + try: + # 获取账户信息 + account_info = self._get_account_info(account_id) + if not account_info['success']: + return False + + account_data = account_info['account_data'] + + # 获取持仓信息 + holdings_query = f""" + SELECT * FROM simulation_holding WHERE account_id = '{account_id}' + """ + holdings = self._read_sql(holdings_query) + + # 计算持仓市值和持仓盈亏 + holding_value = 0 + holding_profit_loss = 0 + + for _, holding in holdings.iterrows(): + stock_code = holding['stock_code'] + cost_price = float(holding['cost_price']) + quantity = int(holding['quantity']) + + # 获取当前价格 + current_price = self._get_stock_price( + stock_code, + trade_date.strftime('%Y-%m-%d'), + 'close' + ) + + if current_price is None: + # 如果获取不到,使用成本价 + current_price = cost_price + + market_value = current_price * quantity + profit_loss = (current_price - cost_price) * quantity + profit_loss_percent = (current_price - cost_price) / cost_price if cost_price > 0 else 0 + + # 更新持仓信息 + update_holding_query = """ + UPDATE simulation_holding SET + current_price = :current_price, + market_value = :market_value, + profit_loss = :profit_loss, + profit_loss_percent = :profit_loss_percent, + update_time = :update_time + WHERE account_id = :account_id AND stock_code = :stock_code + """ + + update_holding_params = { + 'current_price': current_price, + 'market_value': market_value, + 'profit_loss': profit_loss, + 'profit_loss_percent': profit_loss_percent, + 'update_time': datetime.now(), + 'account_id': account_id, + 'stock_code': stock_code + } + + self._execute_database_operation(update_holding_query, update_holding_params) + + holding_value += 
market_value + holding_profit_loss += profit_loss + + # 计算账户净值和收益率 + total_capital = float(account_data['total_capital']) + available_capital = float(account_data['available_capital']) + profit_loss = float(account_data['profit_loss']) # 已清算盈亏 + + net_value = available_capital + holding_value + accumulated_profit_loss = profit_loss + holding_profit_loss + yield_rate = accumulated_profit_loss / total_capital if total_capital > 0 else 0 + + # 查询是否已有当日净值记录 + check_query = f""" + SELECT * FROM simulation_net_value + WHERE account_id = '{account_id}' AND trade_date = '{trade_date}' + """ + existing_record = self._read_sql(check_query) + + if not existing_record.empty: + # 如果已有记录,更新 + daily_profit_loss = net_value - float(existing_record.iloc[0]['net_value']) + + update_query = """ + UPDATE simulation_net_value SET + net_value = :net_value, + available_capital = :available_capital, + holding_value = :holding_value, + daily_profit_loss = :daily_profit_loss, + accumulated_profit_loss = :accumulated_profit_loss, + yield_rate = :yield_rate, + create_time = :create_time + WHERE account_id = :account_id AND trade_date = :trade_date + """ + + update_params = { + 'net_value': net_value, + 'available_capital': available_capital, + 'holding_value': holding_value, + 'daily_profit_loss': daily_profit_loss, + 'accumulated_profit_loss': accumulated_profit_loss, + 'yield_rate': yield_rate, + 'create_time': datetime.now(), + 'account_id': account_id, + 'trade_date': trade_date + } + + self._execute_database_operation(update_query, update_params) + else: + # 如果没有记录,创建新记录 + # 查询前一天的净值 + prev_date = trade_date - timedelta(days=1) + prev_query = f""" + SELECT * FROM simulation_net_value + WHERE account_id = '{account_id}' AND trade_date <= '{prev_date}' + ORDER BY trade_date DESC LIMIT 1 + """ + prev_record = self._read_sql(prev_query) + + daily_profit_loss = 0 + if not prev_record.empty: + daily_profit_loss = net_value - float(prev_record.iloc[0]['net_value']) + + insert_query = """ + 
INSERT INTO simulation_net_value ( + account_id, trade_date, net_value, available_capital, holding_value, + daily_profit_loss, accumulated_profit_loss, yield_rate + ) VALUES ( + :account_id, :trade_date, :net_value, :available_capital, :holding_value, + :daily_profit_loss, :accumulated_profit_loss, :yield_rate + ) + """ + + insert_params = { + 'account_id': account_id, + 'trade_date': trade_date, + 'net_value': net_value, + 'available_capital': available_capital, + 'holding_value': holding_value, + 'daily_profit_loss': daily_profit_loss, + 'accumulated_profit_loss': accumulated_profit_loss, + 'yield_rate': yield_rate + } + + self._execute_database_operation(insert_query, insert_params) + + # 更新账户表的数据 + update_account_query = """ + UPDATE simulation_account SET + holding_value = :holding_value, + holding_profit_loss = :holding_profit_loss, + strategy_yield = :strategy_yield, + update_time = :update_time + WHERE id = :account_id + """ + + update_account_params = { + 'holding_value': holding_value, + 'holding_profit_loss': holding_profit_loss, + 'strategy_yield': yield_rate, + 'update_time': datetime.now(), + 'account_id': account_id + } + + self._execute_database_operation(update_account_query, update_account_params) + + return True + + except Exception as e: + logger.error(f"更新净值出错: {str(e)}") + return False + + def clear_account(self, account_id: str) -> Dict: + """ + 清仓账户所有持仓 + :param account_id: 账户ID + :return: 操作结果 + """ + try: + now = datetime.now() + + # 记录操作 + operation_query = """ + INSERT INTO simulation_operation_log ( + account_id, operation_type, operation_time, status, message + ) VALUES ( + :account_id, 'clear', :operation_time, 'pending', + :message + ) + """ + + operation_params = { + 'account_id': account_id, + 'operation_time': now, + 'message': f"清仓账户 {account_id} 所有持仓" + } + + self._execute_database_operation(operation_query, operation_params) + + # 获取账户信息 + account_info = self._get_account_info(account_id) + if not account_info['success']: + 
return account_info + + account_data = account_info['account_data'] + + # 获取所有持仓 + holdings_query = f""" + SELECT * FROM simulation_holding WHERE account_id = '{account_id}' + """ + holdings = self._read_sql(holdings_query) + + if holdings.empty: + return {"success": False, "message": f"账户 {account_id} 无持仓,无需清仓"} + + # 确定交易日期和价格 + today = now.strftime('%Y-%m-%d') + latest_trade_date = self._get_latest_trade_date() + + # 判断清仓执行时间 + is_after_close = self._is_after_close(now) + trade_date = today if not is_after_close else self._get_next_trade_date(today) + price_type = 'close' # 按收盘价交易 + + total_profit_loss = 0 + clear_summary = [] + + # 逐个处理持仓 + for _, holding in holdings.iterrows(): + stock_code = holding['stock_code'] + stock_name = holding['stock_name'] + quantity = int(holding['quantity']) + cost_price = float(holding['cost_price']) + + # 获取股票价格 + price = self._get_stock_price(stock_code, trade_date, price_type) + + # 如果获取不到当日收盘价,使用最新交易日的收盘价 + if price is None: + price = self._get_stock_price(stock_code, latest_trade_date, 'close') + if price is None: + price = float(holding['current_price']) # 使用当前价格 + + # 计算交易金额和盈亏 + trade_amount = quantity * price + profit_loss = (price - cost_price) * quantity + + # 添加交易记录 + transaction_query = """ + INSERT INTO simulation_transaction ( + account_id, stock_code, stock_name, transaction_type, + transaction_price, quantity, transaction_amount, profit_loss, trade_time + ) VALUES ( + :account_id, :stock_code, :stock_name, 'clear', + :transaction_price, :quantity, :transaction_amount, :profit_loss, :trade_time + ) + """ + + transaction_params = { + 'account_id': account_id, + 'stock_code': stock_code, + 'stock_name': stock_name, + 'transaction_price': price, + 'quantity': quantity, + 'transaction_amount': trade_amount, + 'profit_loss': profit_loss, + 'trade_time': now + } + + self._execute_database_operation(transaction_query, transaction_params) + + total_profit_loss += profit_loss + + clear_summary.append({ + 'stock_code': 
stock_code, + 'stock_name': stock_name, + 'quantity': quantity, + 'cost_price': cost_price, + 'price': price, + 'amount': trade_amount, + 'profit_loss': profit_loss + }) + + # 删除所有持仓记录 + delete_holdings_query = f""" + DELETE FROM simulation_holding WHERE account_id = '{account_id}' + """ + + self._execute_database_operation(delete_holdings_query) + + # 更新账户信息 + available_capital = float(account_data['available_capital']) + float(account_data['holding_value']) + account_profit_loss = float(account_data['profit_loss']) + total_profit_loss + + # 统计盈亏情况 + win_count = int(account_data['win_count']) + loss_count = int(account_data['loss_count']) + + for item in clear_summary: + if item['profit_loss'] > 0: + win_count += 1 + else: + loss_count += 1 + + # 计算策略收益率 + strategy_yield = account_profit_loss / float(account_data['total_capital']) + + update_account_query = """ + UPDATE simulation_account SET + available_capital = :available_capital, + holding_value = 0, + holding_profit_loss = 0, + profit_loss = :profit_loss, + strategy_yield = :strategy_yield, + win_count = :win_count, + loss_count = :loss_count, + status = 'closed', + update_time = :update_time + WHERE id = :account_id + """ + + update_account_params = { + 'available_capital': available_capital, + 'profit_loss': account_profit_loss, + 'strategy_yield': strategy_yield, + 'win_count': win_count, + 'loss_count': loss_count, + 'update_time': now, + 'account_id': account_id + } + + self._execute_database_operation(update_account_query, update_account_params) + + # 更新操作日志 + update_log_query = """ + UPDATE simulation_operation_log SET + execution_time = :execution_time, + status = 'executed', + message = :message + WHERE account_id = :account_id AND operation_type = 'clear' + ORDER BY id DESC LIMIT 1 + """ + + update_log_params = { + 'execution_time': now, + 'message': f"成功清仓账户 {account_id} 所有持仓,总盈亏: {total_profit_loss}", + 'account_id': account_id + } + + self._execute_database_operation(update_log_query, 
update_log_params) + + # 更新净值记录 + net_value_date = datetime.strptime(trade_date, '%Y-%m-%d').date() + self._update_net_value(account_id, net_value_date) + + return { + "success": True, + "message": f"清仓账户 {account_id} 成功", + "total_profit_loss": total_profit_loss, + "strategy_yield": strategy_yield, + "clear_summary": clear_summary, + "trade_date": trade_date + } + + except Exception as e: + logger.error(f"清仓账户出错: {str(e)}") + return {"success": False, "message": f"清仓账户出错: {str(e)}"} + + def get_all_accounts(self) -> Dict: + """ + 获取所有模拟账户状态 + :return: 所有账户信息 + """ + try: + query = """ + SELECT + id, total_capital, available_capital, holding_value, + profit_loss, holding_profit_loss, + strategy_yield, win_count, loss_count, status, + create_time, update_time + FROM simulation_account + """ + + accounts = self._read_sql(query) + + if accounts.empty: + return {"success": True, "message": "没有模拟账户", "accounts": []} + + # 转换为字典列表 + accounts_list = [] + for _, account in accounts.iterrows(): + # 计算总统计数据 + total_trades = account['win_count'] + account['loss_count'] + win_rate = account['win_count'] / total_trades * 100 if total_trades > 0 else 0 + + account_dict = { + 'account_id': account['id'], + 'total_capital': float(account['total_capital']), + 'available_capital': float(account['available_capital']), + 'holding_value': float(account['holding_value']), + 'total_value': float(account['available_capital']) + float(account['holding_value']), + 'profit_loss': float(account['profit_loss']), + 'holding_profit_loss': float(account['holding_profit_loss']), + 'total_profit_loss': float(account['profit_loss']) + float(account['holding_profit_loss']), + 'strategy_yield': float(account['strategy_yield']) * 100, # 转为百分比 + 'win_count': account['win_count'], + 'loss_count': account['loss_count'], + 'total_trades': total_trades, + 'win_rate': win_rate, + 'status': account['status'], + 'create_time': account['create_time'].strftime('%Y-%m-%d %H:%M:%S'), + 'update_time': 
account['update_time'].strftime('%Y-%m-%d %H:%M:%S') + } + accounts_list.append(account_dict) + + return { + "success": True, + "message": f"获取到 {len(accounts_list)} 个模拟账户", + "accounts": accounts_list + } + + except Exception as e: + logger.error(f"获取所有账户出错: {str(e)}") + return {"success": False, "message": f"获取所有账户出错: {str(e)}"} + + def get_account_detail(self, account_id: str) -> Dict: + """ + 获取账户详细信息 + :param account_id: 账户ID + :return: 账户详细信息 + """ + try: + # 获取账户基本信息 + account_info = self._get_account_info(account_id) + if not account_info['success']: + return account_info + + account_data = account_info['account_data'] + + # 获取持仓信息 + holdings_query = f""" + SELECT * FROM simulation_holding WHERE account_id = '{account_id}' + """ + holdings = self._read_sql(holdings_query) + + holdings_list = [] + for _, holding in holdings.iterrows(): + holding_dict = { + 'stock_code': holding['stock_code'], + 'stock_name': holding['stock_name'], + 'quantity': int(holding['quantity']), + 'cost_price': float(holding['cost_price']), + 'current_price': float(holding['current_price']), + 'market_value': float(holding['market_value']), + 'profit_loss': float(holding['profit_loss']), + 'profit_loss_percent': float(holding['profit_loss_percent']) * 100 # 转为百分比 + } + holdings_list.append(holding_dict) + + # 获取净值变动 + net_value_query = f""" + SELECT * FROM simulation_net_value + WHERE account_id = '{account_id}' + ORDER BY trade_date + """ + net_values = self._read_sql(net_value_query) + + net_values_list = [] + for _, net_value in net_values.iterrows(): + net_value_dict = { + 'trade_date': net_value['trade_date'].strftime('%Y-%m-%d'), + 'net_value': float(net_value['net_value']), + 'available_capital': float(net_value['available_capital']), + 'holding_value': float(net_value['holding_value']), + 'daily_profit_loss': float(net_value['daily_profit_loss']), + 'accumulated_profit_loss': float(net_value['accumulated_profit_loss']), + 'yield_rate': float(net_value['yield_rate']) * 100 # 
转为百分比 + } + net_values_list.append(net_value_dict) + + # 获取交易记录 + transactions_query = f""" + SELECT * FROM simulation_transaction + WHERE account_id = '{account_id}' + ORDER BY trade_time + """ + transactions = self._read_sql(transactions_query) + + transactions_list = [] + for _, transaction in transactions.iterrows(): + transaction_dict = { + 'stock_code': transaction['stock_code'], + 'stock_name': transaction['stock_name'], + 'transaction_type': transaction['transaction_type'], + 'transaction_price': float(transaction['transaction_price']), + 'quantity': int(transaction['quantity']), + 'transaction_amount': float(transaction['transaction_amount']), + 'profit_loss': float(transaction['profit_loss']) if pd.notna(transaction['profit_loss']) else None, + 'trade_time': transaction['trade_time'].strftime('%Y-%m-%d %H:%M:%S') + } + transactions_list.append(transaction_dict) + + # 计算统计数据 + total_trades = account_data['win_count'] + account_data['loss_count'] + win_rate = account_data['win_count'] / total_trades * 100 if total_trades > 0 else 0 + + # 准备返回数据 + result = { + "success": True, + "account_id": account_id, + "total_capital": float(account_data['total_capital']), + "available_capital": float(account_data['available_capital']), + "holding_value": float(account_data['holding_value']), + "total_value": float(account_data['available_capital']) + float(account_data['holding_value']), + "profit_loss": float(account_data['profit_loss']), # 已清算盈亏 + "holding_profit_loss": float(account_data['holding_profit_loss']), # 持仓盈亏 + "total_profit_loss": float(account_data['profit_loss']) + float(account_data['holding_profit_loss']), + "strategy_yield": float(account_data['strategy_yield']) * 100, # 转为百分比 + "win_count": account_data['win_count'], + "loss_count": account_data['loss_count'], + "total_trades": total_trades, + "win_rate": win_rate, + "status": account_data['status'], + "holdings": holdings_list, + "net_values": net_values_list, + "transactions": transactions_list, + 
"create_time": account_data['create_time'].strftime('%Y-%m-%d %H:%M:%S'), + "update_time": account_data['update_time'].strftime('%Y-%m-%d %H:%M:%S') + } + + return result + + except Exception as e: + logger.error(f"获取账户详情出错: {str(e)}") + return {"success": False, "message": f"获取账户详情出错: {str(e)}"} + + +def main(): + """主函数,用于测试""" + # 数据库连接配置 + db_config = { + 'host': '192.168.18.199', + 'port': 3306, + 'user': 'root', + 'password': 'Chlry#$.8', + 'database': 'db_gp_cj' + } + + # 创建数据库连接字符串 + connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}" + + # 创建模拟器实例 + simulator = StockSimulator(connection_string) + + # 测试初始化账户 + # result = simulator.init_simulation_account(5000000, 200000) + # print("初始化账户结果:", result) + # + # if result['success']: + account_id = 'a111336c-b77c-47b2-9ea9-2717827200ba' + + # 测试买入股票 + buy_result = simulator.buy_stock(account_id, 'SH600522') + print("买入股票结果:", buy_result) + # + # # 测试获取账户详情 + # detail_result = simulator.get_account_detail(account_id) + # print("账户详情:", detail_result) + # + # # 测试卖出股票 + # sell_result = simulator.sell_stock(account_id, 'SH600522') + # print("卖出股票结果:", sell_result) + # + # # 测试再次买入股票 + # buy_result = simulator.buy_stock(account_id, 'SH601187') + # print("再次买入股票结果:", buy_result) + # + # # 测试清仓账户 + # clear_result = simulator.clear_account(account_id) + # print("清仓账户结果:", clear_result) + # + # # 测试获取所有账户 + # all_accounts = simulator.get_all_accounts() + # print("所有账户:", all_accounts) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/tq_demo.py b/src/tq_demo.py new file mode 100644 index 0000000..f256fe0 --- /dev/null +++ b/src/tq_demo.py @@ -0,0 +1,81 @@ +# 用前须知 + +## xtdata提供和MiniQmt的交互接口,本质是和MiniQmt建立连接,由MiniQmt处理行情数据请求,再把结果回传返回到python层。使用的行情服务器以及能获取到的行情数据和MiniQmt是一致的,要检查数据或者切换连接时直接操作MiniQmt即可。 + +## 对于数据获取接口,使用时需要先确保MiniQmt已有所需要的数据,如果不足可以通过补充数据接口补充,再调用数据获取接口获取。 + +## 
对于订阅接口,直接设置数据回调,数据到来时会由回调返回。订阅接收到的数据一般会保存下来,同种数据不需要再单独补充。 + +# 代码讲解 + +# 从本地python导入xtquant库,如果出现报错则说明安装失败 +from xtquant import xtdata +import time + +# 设定一个标的列表 +code_list = ["000001.SZ"] +# 设定获取数据的周期 +period = "1d" + +# 下载标的行情数据 +if 1: + ## 为了方便用户进行数据管理,xtquant的大部分历史数据都是以压缩形式存储在本地的 + ## 比如行情数据,需要通过download_history_data下载,财务数据需要通过 + ## 所以在取历史数据之前,我们需要调用数据下载接口,将数据下载到本地 + for i in code_list: + xtdata.download_history_data(i, period=period, incrementally=True) # 增量下载行情数据(开高低收,等等)到本地 + + xtdata.download_financial_data(code_list) # 下载财务数据到本地 + xtdata.download_sector_data() # 下载板块数据到本地 + # 更多数据的下载方式可以通过数据字典查询 + +# 读取本地历史行情数据 +history_data = xtdata.get_market_data_ex([], code_list, period=period, count=-1) +print(history_data) +print("=" * 20) + +# 如果需要盘中的实时行情,需要向服务器进行订阅后才能获取 +# 订阅后,get_market_data函数于get_market_data_ex函数将会自动拼接本地历史行情与服务器实时行情 + +# 向服务器订阅数据 +for i in code_list: + xtdata.subscribe_quote(i, period=period, count=-1) # 设置count = -1来取到当天所有实时行情 + +# 等待订阅完成 +time.sleep(1) + +# 获取订阅后的行情 +kline_data = xtdata.get_market_data_ex([], code_list, period=period) +print(kline_data) + +# 获取订阅后的行情,并以固定间隔进行刷新,预期会循环打印10次 +for i in range(10): + # 这边做演示,就用for来循环了,实际使用中可以用while True + kline_data = xtdata.get_market_data_ex([], code_list, period=period) + print(kline_data) + time.sleep(3) # 三秒后再次获取行情 + + +# 如果不想用固定间隔触发,可以以用订阅后的回调来执行 +# 这种模式下当订阅的callback回调函数将会异步的执行,每当订阅的标的tick发生变化更新,callback回调函数就会被调用一次 +# 本地已有的数据不会触发callback + +# 定义的回测函数 +## 回调函数中,data是本次触发回调的数据,只有一条 +def f(data): + # print(data) + + code_list = list(data.keys()) # 获取到本次触发的标的代码 + + kline_in_callabck = xtdata.get_market_data_ex([], code_list, period=period) # 在回调中获取klines数据 + print(kline_in_callabck) + + +for i in code_list: + xtdata.subscribe_quote(i, period=period, count=-1, callback=f) # 订阅时设定回调函数 + +# 使用回调时,必须要同时使用xtdata.run()来阻塞程序,否则程序运行到最后一行就直接结束退出了。 +xtdata.run() + + +