commit;
This commit is contained in:
		
							parent
							
								
									ae9b2f2fcf
								
							
						
					
					
						commit
						aadef5f0fa
					
				
							
								
								
									
										129
									
								
								README.md
								
								
								
								
							
							
						
						
									
										129
									
								
								README.md
								
								
								
								
							|  | @ -211,6 +211,127 @@ pip install -r requirements.txt | |||
| 
 | ||||
| 筛选逻辑:选择行业环境稳定、有高质量新业务合作、财务状况良好、券商看好且上涨空间大的企业。 | ||||
| 
 | ||||
| ## Docker 部署说明 | ||||
| 
 | ||||
| 本系统支持通过 Docker 进行部署,可以单实例部署或多实例部署。 | ||||
| 
 | ||||
| ### 1. 单实例部署 | ||||
| 
 | ||||
| ```bash | ||||
| # 构建镜像 | ||||
| docker-compose build | ||||
| 
 | ||||
| # 启动服务 | ||||
| docker-compose up -d | ||||
| ``` | ||||
| 
 | ||||
| ### 2. 多实例部署 | ||||
| 
 | ||||
| 提供了三个脚本用于管理多实例部署,每个实例将在不同端口上运行(从5088开始递增): | ||||
| 
 | ||||
| #### 使用 docker run 部署多实例(推荐) | ||||
| 
 | ||||
| ```bash | ||||
| # 赋予脚本执行权限 | ||||
| chmod +x deploy-multiple.sh | ||||
| 
 | ||||
| # 部署3个实例,端口将为5088、5089、5090 | ||||
| ./deploy-multiple.sh 3 | ||||
| ``` | ||||
| 
 | ||||
| 每个实例将有自己独立的数据目录:`./instances/instance-N/`,包含独立的日志和报告文件。 | ||||
| 
 | ||||
| #### 使用 docker-compose 部署多实例 | ||||
| 
 | ||||
| ```bash | ||||
| # 赋予脚本执行权限 | ||||
| chmod +x deploy-compose.sh | ||||
| 
 | ||||
| # 部署3个实例,端口将为5088、5089、5090 | ||||
| ./deploy-compose.sh 3 | ||||
| ``` | ||||
| 
 | ||||
| 每个实例将使用独立的项目名称 `stock-app-N`。 | ||||
| 
 | ||||
| #### 管理多实例 | ||||
| 
 | ||||
| 使用 `manage-instances.sh` 脚本可以方便地管理已部署的实例: | ||||
| 
 | ||||
| ```bash | ||||
| # 赋予脚本执行权限 | ||||
| chmod +x manage-instances.sh | ||||
| 
 | ||||
| # 查看所有实例状态 | ||||
| ./manage-instances.sh status | ||||
| 
 | ||||
| # 列出正在运行的实例 | ||||
| ./manage-instances.sh list | ||||
| 
 | ||||
| # 启动指定实例 | ||||
| ./manage-instances.sh start 2 | ||||
| 
 | ||||
| # 启动所有实例 | ||||
| ./manage-instances.sh start all | ||||
| 
 | ||||
| # 停止指定实例 | ||||
| ./manage-instances.sh stop 1 | ||||
| 
 | ||||
| # 停止所有实例 | ||||
| ./manage-instances.sh stop all | ||||
| 
 | ||||
| # 重启指定实例 | ||||
| ./manage-instances.sh restart 3 | ||||
| 
 | ||||
| # 查看指定实例的日志 | ||||
| ./manage-instances.sh logs 1 | ||||
| 
 | ||||
| # 删除指定实例 | ||||
| ./manage-instances.sh remove 2 | ||||
| 
 | ||||
| # 删除所有实例 | ||||
| ./manage-instances.sh remove all | ||||
| ``` | ||||
| 
 | ||||
| ### 3. 中文字体安装 | ||||
| 
 | ||||
| 系统需要中文字体以正常生成PDF报告。如果部署时遇到 `找不到中文字体文件` 错误,可以使用以下方法解决: | ||||
| 
 | ||||
| #### 自动安装字体(推荐) | ||||
| 
 | ||||
| ```bash | ||||
| # 进入容器内部 | ||||
| docker exec -it [容器ID或名称] bash | ||||
| 
 | ||||
| # 在容器内执行字体安装脚本 | ||||
| cd /app | ||||
| python src/fundamentals_llm/setup_fonts.py | ||||
| ``` | ||||
| 
 | ||||
| #### 手动安装字体 | ||||
| 
 | ||||
| 1. 在宿主机上下载中文字体文件(如simhei.ttf、wqy-microhei.ttc等) | ||||
| 2. 创建字体目录并复制字体文件: | ||||
| ```bash | ||||
| mkdir -p src/fundamentals_llm/fonts | ||||
| cp 你下载的字体文件.ttf src/fundamentals_llm/fonts/simhei.ttf | ||||
| ``` | ||||
| 3. 重新构建镜像: | ||||
| ```bash | ||||
| docker-compose build | ||||
| ``` | ||||
| 
 | ||||
| ### 4. 访问服务 | ||||
| 
 | ||||
| 部署完成后,可以通过以下URL访问服务: | ||||
| 
 | ||||
| - 单实例:`http://服务器IP:5000` | ||||
| - 多实例:`http://服务器IP:端口`(端口 = 5087 + 实例序号,实例序号从1开始) | ||||
| 
 | ||||
| 例如: | ||||
| - 实例1:`http://服务器IP:5088` | ||||
| - 实例2:`http://服务器IP:5089` | ||||
| - 实例3:`http://服务器IP:5090` | ||||
| 
 | ||||
| ## 数据库查询示例 | ||||
| 
 | ||||
| 以下SQL查询示例展示了如何从数据库中筛选特定类型的企业: | ||||
|  | @ -251,3 +372,11 @@ AND stock_code IN ( | |||
| ``` | ||||
| 
 | ||||
| 注意:实际查询时可能需要根据数据库表结构调整SQL语句。建议根据具体需求组合使用不同的筛选条件。  | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| ```bash | ||||
| # 重新构建镜像 | ||||
| docker-compose build | ||||
| 
 | ||||
| # 重启所有实例 | ||||
| ./manage-instances.sh restart all | ||||
| ``` | ||||
|  | @ -13,5 +13,5 @@ volcengine-python-sdk[ark] | |||
| openai>=1.0 | ||||
| reportlab>=4.3.1 | ||||
| markdown2>=2.5.3 | ||||
| # 1. 按下 Win+R ,输入 regedit 打开注册表编辑器。 | ||||
| # 2. 设置 \HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem 路径下的变量 LongPathsEnabled 为 1 即可。 | ||||
| google-genai | ||||
| redis==5.2.1 | ||||
							
								
								
									
										570
									
								
								src/app.py
								
								
								
								
							
							
						
						
									
										570
									
								
								src/app.py
								
								
								
								
							|  | @ -1,27 +1,47 @@ | |||
| import sys | ||||
| import os | ||||
| from datetime import datetime, timedelta | ||||
| import pandas as pd | ||||
| import uuid | ||||
| import json | ||||
| from threading import Thread | ||||
| 
 | ||||
| from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db | ||||
| 
 | ||||
| # 添加项目根目录到 Python 路径 | ||||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||
| 
 | ||||
| from flask import Flask, jsonify, request | ||||
| from flask import Flask, jsonify, request, send_from_directory | ||||
| from flask_cors import CORS | ||||
| import logging | ||||
| 
 | ||||
| # 导入企业筛选器 | ||||
| from src.fundamentals_llm.enterprise_screener import EnterpriseScreener | ||||
| 
 | ||||
| # 导入股票回测器 | ||||
| from src.stock_analysis_v2 import run_backtest, StockBacktester | ||||
| 
 | ||||
# Runtime directories.
# These must exist BEFORE logging is configured: logging.FileHandler opens its
# target file as soon as it is constructed, so building the handler while
# 'logs/' is still missing raises FileNotFoundError on a fresh deployment.
os.makedirs('logs', exist_ok=True)
# 'results/tasks' receives the per-task JSON files and 'static/results' the
# per-task output directories created by the backtest endpoints.
os.makedirs('results', exist_ok=True)
os.makedirs('results/tasks', exist_ok=True)
os.makedirs('static/results', exist_ok=True)

# Logging: one stream handler plus a log file named after the start date.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),  # console output
        logging.FileHandler(
            f'logs/app_{datetime.now().strftime("%Y%m%d")}.log',
            encoding='utf-8',
        ),  # file output
    ]
)

logger = logging.getLogger(__name__)
logger.info("Flask应用启动")
| 
 | ||||
| # 创建 Flask 应用 | ||||
| app = Flask(__name__) | ||||
| app = Flask(__name__, static_folder='static') | ||||
| CORS(app)  # 启用跨域请求支持 | ||||
| 
 | ||||
| # 创建企业筛选器实例 | ||||
|  | @ -30,6 +50,329 @@ screener = EnterpriseScreener() | |||
| # 获取数据库连接 | ||||
| db = next(get_db()) | ||||
| 
 | ||||
# Project root: the parent of the directory that contains this file.
_this_file = os.path.abspath(__file__)
ROOT_DIR = os.path.dirname(os.path.dirname(_this_file))

# Directory the report download endpoint serves files from.
REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports')

# Create the reports directory up front and record where it lives.
os.makedirs(REPORTS_DIR, exist_ok=True)
logger.info(f"报告目录路径: {REPORTS_DIR}")

# In-memory registry of backtest tasks, keyed by task id.
backtest_tasks = {}
| 
 | ||||
def _summarize_backtest(task_id, stats_list):
    """Build the result payload stored for a successfully finished backtest.

    Args:
        task_id: Task identifier; echoed in the payload and embedded in the
            chart URLs.
        stats_list: Per-stock statistics mappings. Each entry is read for the
            keys ``symbol``, ``total_trades``, ``profitable_trades``,
            ``loss_trades``, ``win_rate``, ``avg_holding_days``,
            ``final_profit`` and ``entry_count``.

    Returns:
        A dict with ``task_id``, ``status`` (``"completed"``) and ``results``.
        Numeric fields are converted to built-in ``int``/``float`` so the
        payload can be handed to ``json.dump``.

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
    """
    # Aggregate figures across all stocks.
    stats_df = pd.DataFrame(stats_list)
    total_profit = stats_df['final_profit'].sum()
    avg_win_rate = stats_df['win_rate'].mean()
    avg_holding_days = stats_df['avg_holding_days'].mean()

    # NOTE: hard-coded placeholder; it is not derived from the backtest output.
    best_take_profit_pct = 0.15

    # URLs only; this function does not create the images.
    # NOTE(review): presumably run_backtest writes them - confirm.
    chart_urls = {
        "all_stocks": f"/static/results/{task_id}/all_stocks_analysis.png",
        "profit_matrix": f"/static/results/{task_id}/profit_matrix_analysis.png"
    }

    # Per-stock details, coerced to plain Python numbers.
    stock_stats = [
        {
            "symbol": stat['symbol'],
            "total_trades": int(stat['total_trades']),
            "profitable_trades": int(stat['profitable_trades']),
            "loss_trades": int(stat['loss_trades']),
            "win_rate": float(stat['win_rate']),
            "avg_holding_days": float(stat['avg_holding_days']),
            "final_profit": float(stat['final_profit']),
            "entry_count": int(stat['entry_count'])
        }
        for stat in stats_list
    ]

    return {
        "task_id": task_id,
        "status": "completed",
        "results": {
            "total_profit": float(total_profit),
            "win_rate": float(avg_win_rate),
            "avg_holding_days": float(avg_holding_days),
            "best_take_profit_pct": best_take_profit_pct,
            "stock_stats": stock_stats,
            "chart_urls": chart_urls
        }
    }


def run_backtest_task(task_id, stocks_buy_dates, end_date):
    """Run one backtest and record its outcome in ``backtest_tasks``.

    Used as the target of the worker thread started by the
    ``/api/backtest/run`` endpoint.

    The task entry moves to ``running``, then to ``completed`` (with the
    results attached and also written to ``results/tasks/<task_id>.json``) or
    to ``failed`` (with an ``error`` message).

    Args:
        task_id: Key of the task entry in ``backtest_tasks``.
        stocks_buy_dates: Passed through to ``run_backtest``.
        end_date: Passed through to ``run_backtest``.
    """
    try:
        logger.info(f"开始执行回测任务 {task_id}")
        backtest_tasks[task_id]['status'] = 'running'

        results, stats_list = run_backtest(stocks_buy_dates, end_date)

        if results and stats_list:
            result_data = _summarize_backtest(task_id, stats_list)

            # Persist the payload to a JSON file under results/tasks.
            task_result_path = os.path.join('results', 'tasks', f"{task_id}.json")
            with open(task_result_path, 'w', encoding='utf-8') as f:
                json.dump(result_data, f, ensure_ascii=False, indent=2)

            backtest_tasks[task_id].update({
                'status': 'completed',
                'results': result_data['results']
            })

            logger.info(f"回测任务 {task_id} 已完成")
        else:
            # The backtest ran but produced nothing usable.
            backtest_tasks[task_id]['status'] = 'failed'
            backtest_tasks[task_id]['error'] = "回测未产生有效结果"
            logger.error(f"回测任务 {task_id} 失败:回测未产生有效结果")
    except Exception as e:
        backtest_tasks[task_id]['status'] = 'failed'
        backtest_tasks[task_id]['error'] = str(e)
        # logger.exception logs at ERROR level and appends the traceback;
        # without it a failure inside this worker thread leaves only str(e).
        logger.exception(f"回测任务 {task_id} 失败:{str(e)}")
| 
 | ||||
@app.route('/api/backtest/run', methods=['POST'])
def start_backtest():
    """Validate a backtest request and start it on a background thread.

    Request body (JSON object):
        {
            "stocks_buy_dates": {
                "SH600522": ["2022-05-10", "2022-06-10"],
                "SZ002340": ["2022-06-15"]
            },
            "end_date": "2022-10-20"
        }
    ``stocks_buy_dates`` maps a stock code to its list of buy dates and
    ``end_date`` is shared by all stocks. Every date must be a string in
    ``YYYY-MM-DD`` format.

    Returns:
        ``{"task_id": "backtask-<16 hex chars>"}`` on success.
        HTTP 400 with ``{"status": "error", "message": ...}`` for an invalid
        request, HTTP 500 with the same shape for unexpected failures.
    """
    def bad_request(message):
        # Every client-side validation failure shares this response shape.
        return jsonify({
            "status": "error",
            "message": message
        }), 400

    try:
        # silent=True turns a malformed or non-JSON body into None, which the
        # check below reports as 400. Without it get_json() raises an HTTP
        # exception that the generic handler at the bottom turns into a 500.
        data = request.get_json(silent=True)

        # A JSON array or scalar has no .get(); reject it as a client error.
        if not data or not isinstance(data, dict):
            return bad_request("请求格式错误: 需要提供JSON数据")

        stocks_buy_dates = data.get('stocks_buy_dates')
        end_date = data.get('end_date')

        if not stocks_buy_dates or not isinstance(stocks_buy_dates, dict):
            return bad_request("请求格式错误: 需要提供stocks_buy_dates字典")

        if not end_date or not isinstance(end_date, str):
            return bad_request("请求格式错误: 需要提供有效的end_date")

        # Validate every date before any task state is created.
        try:
            datetime.strptime(end_date, '%Y-%m-%d')
            for stock_code, buy_dates in stocks_buy_dates.items():
                if not isinstance(buy_dates, list):
                    return bad_request(f"请求格式错误: 股票 {stock_code} 的买入日期必须是列表")
                for buy_date in buy_dates:
                    datetime.strptime(buy_date, '%Y-%m-%d')
        except (TypeError, ValueError) as e:
            # strptime raises TypeError for a non-string date (e.g. a number);
            # that is still a malformed request, not a server fault.
            return bad_request(f"日期格式错误: {str(e)}")

        task_id = f"backtask-{uuid.uuid4().hex[:16]}"

        # Per-task directory under static/results.
        task_dir = os.path.join('static', 'results', task_id)
        os.makedirs(task_dir, exist_ok=True)

        # Register the task before the worker starts, so the worker (which
        # updates this entry) always finds it.
        backtest_tasks[task_id] = {
            'status': 'pending',
            'created_at': datetime.now().isoformat(),
            'stocks_buy_dates': stocks_buy_dates,
            'end_date': end_date
        }

        # Daemon thread: it will not keep the process alive at shutdown.
        thread = Thread(target=run_backtest_task, args=(task_id, stocks_buy_dates, end_date))
        thread.daemon = True
        thread.start()

        logger.info(f"已创建回测任务: {task_id}")

        return jsonify({
            "task_id": task_id
        })

    except Exception as e:
        logger.error(f"创建回测任务失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"创建回测任务失败: {str(e)}"
        }), 500
| 
 | ||||
@app.route('/api/backtest/status', methods=['GET'])
def check_backtest_status():
    """Report the current state of a backtest task.

    Query parameters:
        task_id: Identifier of the task to look up.

    Returns:
        ``{"task_id": ..., "status": ..., "created_at": ...}`` where status is
        whatever the task entry currently holds ("pending", "running",
        "completed" or "failed"). An ``error`` field is added when the task
        has failed and recorded a message.
        HTTP 400 when ``task_id`` is missing, 404 when the task is unknown,
        500 on unexpected errors.
    """
    try:
        requested_id = request.args.get('task_id')

        # Guard clauses: missing parameter, then unknown task.
        if not requested_id:
            missing_param = {
                "status": "error",
                "message": "请求格式错误: 需要提供task_id参数"
            }
            return jsonify(missing_param), 400

        if requested_id not in backtest_tasks:
            unknown_task = {
                "status": "error",
                "message": f"任务不存在: {requested_id}"
            }
            return jsonify(unknown_task), 404

        record = backtest_tasks[requested_id]

        payload = {
            "task_id": requested_id,
            "status": record['status'],
            "created_at": record['created_at']
        }

        # Only failed tasks expose their error message.
        has_failed = record['status'] == 'failed'
        if has_failed and 'error' in record:
            payload['error'] = record['error']

        return jsonify(payload)

    except Exception as exc:
        logger.error(f"查询任务状态失败: {str(exc)}")
        failure = {
            "status": "error",
            "message": f"查询任务状态失败: {str(exc)}"
        }
        return jsonify(failure), 500
| 
 | ||||
@app.route('/api/backtest/result', methods=['GET'])
def get_backtest_result():
    """Return the results of a finished backtest task.

    Query parameters:
        task_id: Identifier of the task.

    Lookup order: the in-memory ``backtest_tasks`` registry first, then the
    persisted file ``results/tasks/<task_id>.json``.

    Returns:
        ``{"task_id": ..., "status": "completed", "results": {...}}`` where
        ``results`` holds ``total_profit``, ``win_rate``,
        ``avg_holding_days``, ``best_take_profit_pct``, ``stock_stats`` and
        ``chart_urls``.
        HTTP 400 when ``task_id`` is missing or the task has not completed,
        404 when the task is unknown, 500 on unexpected errors.
    """
    try:
        task_id = request.args.get('task_id')

        if not task_id:
            return jsonify({
                "status": "error",
                "message": "请求格式错误: 需要提供task_id参数"
            }), 400

        if task_id not in backtest_tasks:
            import re  # local import: only the file fallback needs it

            # task_id comes straight from the query string and becomes part
            # of a filesystem path. Accept only letters, digits, '-' and '_'
            # so that values such as "../../x" cannot escape results/tasks.
            # Ids generated by this service ("backtask-" + hex) always pass.
            if re.fullmatch(r'[A-Za-z0-9_-]+', task_id):
                task_result_path = os.path.join('results', 'tasks', f"{task_id}.json")
                if os.path.exists(task_result_path):
                    with open(task_result_path, 'r', encoding='utf-8') as f:
                        result_data = json.load(f)
                    return jsonify(result_data)

            # An invalid id and an unknown id get the same answer.
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        task_info = backtest_tasks[task_id]

        # Results exist only once the worker has marked the task completed.
        if task_info['status'] != 'completed':
            return jsonify({
                "status": "error",
                "message": f"任务尚未完成或已失败: {task_info['status']}"
            }), 400

        response = {
            "task_id": task_id,
            "status": "completed",
            "results": task_info['results']
        }

        return jsonify(response)

    except Exception as e:
        logger.error(f"获取任务结果失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取任务结果失败: {str(e)}"
        }), 500
| 
 | ||||
@app.route('/api/reports/<path:filename>')
def serve_report(filename):
    """Send a file from the reports directory as a download.

    Args:
        filename: Path of the requested file, relative to ``REPORTS_DIR``.

    Returns:
        The file as an attachment, or ``{"error": ...}`` with HTTP 404 when
        serving it raises any exception.
    """
    try:
        logger.info(f"请求文件: {filename}")
        logger.info(f"从目录: {REPORTS_DIR} 提供文件")
        download = send_from_directory(REPORTS_DIR, filename, as_attachment=True)
    except Exception as exc:
        # Every failure is reported to the client as "not found".
        logger.error(f"提供报告文件失败: {str(exc)}")
        not_found = {"error": "文件不存在"}
        return jsonify(not_found), 404
    else:
        return download
| 
 | ||||
| @app.route('/api/health', methods=['GET']) | ||||
| def health_check(): | ||||
|     """健康检查接口""" | ||||
|  | @ -185,10 +528,10 @@ def generate_reports(): | |||
|          | ||||
|         # 导入 PDF 生成器模块 | ||||
|         try: | ||||
|             from src.fundamentals_llm.pdf_generator import generate_investment_report | ||||
|             from src.fundamentals_llm.pdf_generator import PDFGenerator | ||||
|         except ImportError: | ||||
|             try: | ||||
|                 from fundamentals_llm.pdf_generator import generate_investment_report | ||||
|                 from fundamentals_llm.pdf_generator import PDFGenerator | ||||
|             except ImportError as e: | ||||
|                 logger.error(f"无法导入 PDF 生成器模块: {str(e)}") | ||||
|                 return jsonify({ | ||||
|  | @ -200,8 +543,14 @@ def generate_reports(): | |||
|         generated_reports = [] | ||||
|         for stock_code, stock_name in stocks: | ||||
|             try: | ||||
|                 # 创建 PDF 生成器实例 | ||||
|                 generator = PDFGenerator() | ||||
|                 # 调用 PDF 生成器 | ||||
|                 report_path = generate_investment_report(stock_code, stock_name) | ||||
|                 report_path = generator.generate_pdf( | ||||
|                     title=f"{stock_name}({stock_code}) 基本面分析报告", | ||||
|                     content_dict={},  # 这里需要传入实际的内容字典 | ||||
|                     filename=f"{stock_name}_{stock_code}_analysis.pdf" | ||||
|                 ) | ||||
|                 generated_reports.append({ | ||||
|                     "code": stock_code, | ||||
|                     "name": stock_name, | ||||
|  | @ -505,9 +854,10 @@ def analyze_and_recommend(): | |||
|             "message": f"分析和推荐股票失败: {str(e)}" | ||||
|         }), 500 | ||||
| 
 | ||||
| 
 | ||||
| @app.route('/api/comprehensive_analysis', methods=['POST']) | ||||
| def comprehensive_analysis(): | ||||
|     """综合分析接口 - 组合多种功能和参数 | ||||
|     """综合分析接口 - 使用队列方式处理被锁定的股票 | ||||
|      | ||||
|     请求体格式: | ||||
|     { | ||||
|  | @ -558,36 +908,9 @@ def comprehensive_analysis(): | |||
|         if not isinstance(limit, int) or limit <= 0: | ||||
|             limit = 10 | ||||
|          | ||||
|         # 导入必要的聊天机器人模块 | ||||
|         # 导入必要的模块 | ||||
|         try: | ||||
|             # 首先尝试导入聊天机器人模块 | ||||
|             try: | ||||
|                 from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot | ||||
|                 logger.info("成功从 src.fundamentals_llm.chat_bot 导入 ChatBot") | ||||
|             except ImportError as e1: | ||||
|                 try: | ||||
|                     from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot | ||||
|                     logger.info("成功从 fundamentals_llm.chat_bot 导入 ChatBot") | ||||
|                 except ImportError as e2: | ||||
|                     logger.error(f"无法导入在线聊天机器人模块: {str(e1)}, {str(e2)}") | ||||
|                     return jsonify({ | ||||
|                         "status": "error",  | ||||
|                         "message": f"服务器配置错误: 聊天机器人模块不可用,错误详情: {str(e2)}" | ||||
|                     }), 500 | ||||
|                      | ||||
|             # 然后尝试导入离线聊天机器人模块 | ||||
|             try: | ||||
|                 from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|                 logger.info("成功从 src.fundamentals_llm.chat_bot_with_offline 导入 ChatBot") | ||||
|             except ImportError as e1: | ||||
|                 try: | ||||
|                     from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|                     logger.info("成功从 fundamentals_llm.chat_bot_with_offline 导入 ChatBot") | ||||
|                 except ImportError as e2: | ||||
|                     logger.warning(f"无法导入离线聊天机器人模块: {str(e1)}, {str(e2)}") | ||||
|                     # 这里可以继续执行,因为某些功能可能不需要离线模型 | ||||
|              | ||||
|             # 最后导入基本面分析器 | ||||
|             # 先导入基本面分析器 | ||||
|             try: | ||||
|                 from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer | ||||
|                 logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer") | ||||
|  | @ -601,6 +924,19 @@ def comprehensive_analysis(): | |||
|                         "status": "error",  | ||||
|                         "message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}" | ||||
|                     }), 500 | ||||
|              | ||||
|             # 再导入其他可能需要的模块 | ||||
|             try: | ||||
|                 from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot | ||||
|                 from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|             except ImportError: | ||||
|                 try: | ||||
|                     from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot | ||||
|                     from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|                 except ImportError: | ||||
|                     # 这些模块不是必须的,所以继续执行 | ||||
|                     logger.warning("无法导入聊天机器人模块,但这不会影响基本功能") | ||||
|                      | ||||
|         except Exception as e: | ||||
|             logger.error(f"导入必要模块时出错: {str(e)}") | ||||
|             return jsonify({ | ||||
|  | @ -611,64 +947,141 @@ def comprehensive_analysis(): | |||
|         # 创建基本面分析器实例 | ||||
|         analyzer = FundamentalAnalyzer() | ||||
|          | ||||
|         # 为每个股票生成投资建议 | ||||
|         investment_advices = [] | ||||
|         for stock_code, stock_name in stocks: | ||||
|         # 准备结果容器 | ||||
|         investment_advices = {}  # 使用字典,股票代码作为键 | ||||
|         processing_queue = list(stocks)  # 初始处理队列 | ||||
|         max_attempts = 5  # 最大重试次数 | ||||
|         total_attempts = 0 | ||||
|          | ||||
|         # 导入数据库模块 | ||||
|         from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db | ||||
|          | ||||
|         # 开始处理队列 | ||||
|         while processing_queue and total_attempts < max_attempts: | ||||
|             total_attempts += 1 | ||||
|             logger.info(f"开始第 {total_attempts} 轮处理,队列中有 {len(processing_queue)} 只股票") | ||||
|              | ||||
|             # 暂存下一轮需要处理的股票 | ||||
|             next_round_queue = [] | ||||
|              | ||||
|             # 处理当前队列中的所有股票 | ||||
|             for stock_code, stock_name in processing_queue: | ||||
|                 try: | ||||
|                 # 生成投资建议 | ||||
|                     # 检查是否已有分析结果 | ||||
|                     db = next(get_db()) | ||||
|                     existing_result = get_analysis_result(db, stock_code, "investment_advice") | ||||
|                      | ||||
|                     # 如果已有近期结果,直接使用 | ||||
|                     if existing_result and existing_result.update_time > datetime.now() - timedelta(hours=12): | ||||
|                         investment_advices[stock_code] = { | ||||
|                             "code": stock_code, | ||||
|                             "name": stock_name, | ||||
|                             "advice": existing_result.ai_response, | ||||
|                             "reasoning": existing_result.reasoning_process, | ||||
|                             "references": existing_result.references, | ||||
|                             "status": "success", | ||||
|                             "from_cache": True | ||||
|                         } | ||||
|                         logger.info(f"使用缓存的 {stock_name}({stock_code}) 分析结果") | ||||
|                         continue | ||||
|                      | ||||
|                     # 检查是否被锁定 | ||||
|                     if analyzer.is_stock_locked(stock_code, "investment_advice"): | ||||
|                         # 已被锁定,放到下一轮队列 | ||||
|                         next_round_queue.append([stock_code, stock_name]) | ||||
|                         # 记录状态 | ||||
|                         if stock_code not in investment_advices: | ||||
|                             investment_advices[stock_code] = { | ||||
|                                 "code": stock_code, | ||||
|                                 "name": stock_name, | ||||
|                                 "status": "pending", | ||||
|                                 "message": f"股票 {stock_code} 正在被其他请求分析中,已加入等待队列" | ||||
|                             } | ||||
|                         logger.info(f"股票 {stock_name}({stock_code}) 已被锁定,放入下一轮队列") | ||||
|                         continue | ||||
|                      | ||||
|                     # 尝试锁定并分析 | ||||
|                     analyzer.lock_stock(stock_code, "investment_advice") | ||||
|                     try: | ||||
|                         # 执行分析 | ||||
|                         success, advice, reasoning, references = analyzer.query_analysis( | ||||
|                             stock_code, stock_name, "investment_advice" | ||||
|                         ) | ||||
|                          | ||||
|                         # 记录结果 | ||||
|                         if success: | ||||
|                     investment_advices.append({ | ||||
|                             investment_advices[stock_code] = { | ||||
|                                 "code": stock_code, | ||||
|                                 "name": stock_name, | ||||
|                                 "advice": advice, | ||||
|                                 "reasoning": reasoning, | ||||
|                                 "references": references, | ||||
|                                 "status": "success" | ||||
|                     }) | ||||
|                     logger.info(f"成功生成 {stock_name}({stock_code}) 的投资建议") | ||||
|                             } | ||||
|                             logger.info(f"成功分析 {stock_name}({stock_code})") | ||||
|                         else: | ||||
|                     investment_advices.append({ | ||||
|                             investment_advices[stock_code] = { | ||||
|                                 "code": stock_code, | ||||
|                                 "name": stock_name, | ||||
|                                 "status": "error", | ||||
|                         "error": advice | ||||
|                     }) | ||||
|                     logger.error(f"生成 {stock_name}({stock_code}) 的投资建议失败: {advice}") | ||||
|                                 "error": advice or "分析失败,无详细信息" | ||||
|                             } | ||||
|                             logger.error(f"分析 {stock_name}({stock_code}) 失败: {advice}") | ||||
|                     finally: | ||||
|                         # 确保释放锁 | ||||
|                         analyzer.unlock_stock(stock_code, "investment_advice") | ||||
|                          | ||||
|                 except Exception as e: | ||||
|                 logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}") | ||||
|                 investment_advices.append({ | ||||
|                     # 处理异常 | ||||
|                     investment_advices[stock_code] = { | ||||
|                         "code": stock_code, | ||||
|                         "name": stock_name, | ||||
|                         "status": "error", | ||||
|                         "error": str(e) | ||||
|                     } | ||||
|                     logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}") | ||||
|                     # 确保释放锁 | ||||
|                     try: | ||||
|                         analyzer.unlock_stock(stock_code, "investment_advice") | ||||
|                     except: | ||||
|                         pass | ||||
|              | ||||
|             # 如果还有下一轮要处理的股票,等待一段时间后继续 | ||||
|             if next_round_queue: | ||||
|                 logger.info(f"本轮结束,还有 {len(next_round_queue)} 只股票等待下一轮处理") | ||||
|                 # 等待30秒再处理下一轮 | ||||
|                 import time | ||||
|                 time.sleep(30) | ||||
|                 processing_queue = next_round_queue | ||||
|             else: | ||||
|                 # 所有股票都已处理完,退出循环 | ||||
|                 logger.info("所有股票处理完毕") | ||||
|                 processing_queue = [] | ||||
|          | ||||
|         # 处理仍在队列中的股票(达到最大重试次数但仍未处理的) | ||||
|         for stock_code, stock_name in processing_queue: | ||||
|             if stock_code in investment_advices and investment_advices[stock_code]["status"] == "pending": | ||||
|                 investment_advices[stock_code].update({ | ||||
|                     "status": "timeout", | ||||
|                     "message": f"等待超时,股票 {stock_code} 可能正在被长时间分析" | ||||
|                 }) | ||||
|                 logger.warning(f"股票 {stock_name}({stock_code}) 分析超时") | ||||
|          | ||||
|         # 将字典转换为列表 | ||||
|         investment_advice_list = list(investment_advices.values()) | ||||
|          | ||||
|         # 生成PDF报告(如果需要) | ||||
|         pdf_results = [] | ||||
|         if generate_pdf: | ||||
|             try: | ||||
|                 # 导入PDF生成器模块 | ||||
|                 # 针对已成功分析的股票生成PDF | ||||
|                 for stock_info in investment_advice_list: | ||||
|                     if stock_info["status"] == "success": | ||||
|                         stock_code = stock_info["code"] | ||||
|                         stock_name = stock_info["name"] | ||||
|                         try: | ||||
|                     from src.fundamentals_llm.pdf_generator import generate_investment_report | ||||
|                 except ImportError: | ||||
|                     try: | ||||
|                         from fundamentals_llm.pdf_generator import generate_investment_report | ||||
|                     except ImportError as e: | ||||
|                         logger.error(f"无法导入 PDF 生成器模块: {str(e)}") | ||||
|                         return jsonify({ | ||||
|                             "status": "error",  | ||||
|                             "message": f"服务器配置错误: PDF 生成器模块不可用, 错误详情: {str(e)}" | ||||
|                         }), 500 | ||||
|                  | ||||
|                 # 生成报告 | ||||
|                 for stock_code, stock_name in stocks: | ||||
|                     try: | ||||
|                         # 调用 PDF 生成器 | ||||
|                         report_path = generate_investment_report(stock_code, stock_name) | ||||
|                             report_path = analyzer.generate_pdf_report(stock_code, stock_name) | ||||
|                             if report_path: | ||||
|                                 pdf_results.append({ | ||||
|                                     "code": stock_code, | ||||
|                                     "name": stock_name, | ||||
|  | @ -676,6 +1089,14 @@ def comprehensive_analysis(): | |||
|                                     "status": "success" | ||||
|                                 }) | ||||
|                                 logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}") | ||||
|                             else: | ||||
|                                 pdf_results.append({ | ||||
|                                     "code": stock_code, | ||||
|                                     "name": stock_name, | ||||
|                                     "status": "error", | ||||
|                                     "error": "生成报告失败" | ||||
|                                 }) | ||||
|                                 logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败") | ||||
|                         except Exception as e: | ||||
|                             logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}") | ||||
|                             pdf_results.append({ | ||||
|  | @ -784,11 +1205,24 @@ def comprehensive_analysis(): | |||
|                 logger.error(f"应用企业画像筛选失败: {str(e)}") | ||||
|                 filtered_stocks = [] | ||||
|          | ||||
|         # 统计各种状态的股票数量 | ||||
|         success_count = sum(1 for item in investment_advice_list if item["status"] == "success") | ||||
|         pending_count = sum(1 for item in investment_advice_list if item["status"] == "pending") | ||||
|         timeout_count = sum(1 for item in investment_advice_list if item["status"] == "timeout") | ||||
|         error_count = sum(1 for item in investment_advice_list if item["status"] == "error") | ||||
|          | ||||
|         # 返回结果 | ||||
|         response = { | ||||
|             "status": "success", | ||||
|             "status": "success" if success_count > 0 else "partial_success" if success_count + pending_count > 0 else "failed", | ||||
|             "total_input_stocks": len(stocks), | ||||
|             "investment_advices": investment_advices | ||||
|             "stats": { | ||||
|                 "success": success_count, | ||||
|                 "pending": pending_count, | ||||
|                 "timeout": timeout_count, | ||||
|                 "error": error_count | ||||
|             }, | ||||
|             "rounds_attempted": total_attempts, | ||||
|             "investment_advices": investment_advice_list | ||||
|         } | ||||
|          | ||||
|         if profile_filter: | ||||
|  |  | |||
|  | @ -83,37 +83,37 @@ class ChatBot: | |||
|                 "role": "system", | ||||
|                 "content": """你是一个专业的股票分析助手,擅长进行深入的基本面分析。你的分析应该: | ||||
| 
 | ||||
| 1. 专业严谨 | ||||
| - 使用准确的专业术语 | ||||
| - 引用可靠的数据来源 | ||||
| - 分析逻辑清晰 | ||||
| - 结论有理有据 | ||||
|                 1. 专业严谨 | ||||
|                 - 使用准确的专业术语 | ||||
|                 - 引用可靠的数据来源 | ||||
|                 - 分析逻辑清晰 | ||||
|                 - 结论有理有据 | ||||
|                  | ||||
| 2. 全面细致 | ||||
| - 深入分析问题的各个方面 | ||||
| - 关注细节和关键信息 | ||||
| - 考虑多个影响因素 | ||||
| - 提供详实的论据支持 | ||||
|                 2. 全面细致 | ||||
|                 - 深入分析问题的各个方面 | ||||
|                 - 关注细节和关键信息 | ||||
|                 - 考虑多个影响因素 | ||||
|                 - 提供详实的论据支持 | ||||
|                  | ||||
| 3. 客观中立 | ||||
| - 保持独立判断 | ||||
| - 不夸大或贬低 | ||||
| - 平衡利弊分析 | ||||
| - 指出潜在风险 | ||||
|                 3. 客观中立 | ||||
|                 - 保持独立判断 | ||||
|                 - 不夸大或贬低 | ||||
|                 - 平衡利弊分析 | ||||
|                 - 指出潜在风险 | ||||
|                  | ||||
| 4. 实用性强 | ||||
| - 分析结论具体明确 | ||||
| - 建议具有可操作性 | ||||
| - 关注实际投资价值 | ||||
| - 提供清晰的决策参考 | ||||
|                 4. 实用性强 | ||||
|                 - 分析结论具体明确 | ||||
|                 - 建议具有可操作性 | ||||
|                 - 关注实际投资价值 | ||||
|                 - 提供清晰的决策参考 | ||||
|                  | ||||
| 5. 及时更新 | ||||
| - 关注最新信息 | ||||
| - 反映市场变化 | ||||
| - 动态调整分析 | ||||
| - 保持信息时效性 | ||||
|                 5. 及时更新 | ||||
|                 - 关注最新信息 | ||||
|                 - 反映市场变化 | ||||
|                 - 动态调整分析 | ||||
|                 - 保持信息时效性 | ||||
|                  | ||||
| 请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。""" | ||||
|                 请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。""" | ||||
|             } | ||||
|              | ||||
|             # 对话历史 | ||||
|  | @ -150,11 +150,15 @@ class ChatBot: | |||
|             logger.error(f"格式化参考资料时出错: {str(e)}") | ||||
|             return str(ref) | ||||
|          | ||||
|     def chat(self, user_input: str) -> Dict[str, Any]: | ||||
|     def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]: | ||||
|         """与AI进行对话 | ||||
|          | ||||
|         Args: | ||||
|             user_input: 用户输入的问题 | ||||
|             temperature: 控制输出的随机性,范围0-2,默认1.0 | ||||
|             top_p: 控制输出的多样性,范围0-1,默认0.7 | ||||
|             max_tokens: 控制输出的最大长度,默认4096 | ||||
|             frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 | ||||
|              | ||||
|         Returns: | ||||
|             Dict[str, Any]: 包含以下字段的字典: | ||||
|  | @ -173,6 +177,10 @@ class ChatBot: | |||
|             stream = self.client.chat.completions.create( | ||||
|                 model=self.model, | ||||
|                 messages=self.conversation_history, | ||||
|                 temperature=temperature, | ||||
|                 top_p=top_p, | ||||
|                 max_tokens=max_tokens, | ||||
|                 frequency_penalty=frequency_penalty, | ||||
|                 stream=True | ||||
|             ) | ||||
|              | ||||
|  |  | |||
|  | @ -0,0 +1,420 @@ | |||
| import logging | ||||
| import os | ||||
| import sys | ||||
| import json | ||||
| from typing import Dict, List, Optional, Any, Union | ||||
| from pydantic import BaseModel, Field | ||||
| from google import genai | ||||
| from google.genai.types import Tool, GenerateContentConfig, GoogleSearch | ||||
| 
 | ||||
| # 设置日志记录 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| # 定义基础数据结构 | ||||
class TextAnalysisResult(BaseModel):
    """文本分析结果,包含分析正文、推理过程和引用URL"""

    # Structured result of a free-text analysis.
    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, and that schema is embedded in the
    # prompt by JsonOutputParser.get_format_instructions(). They are runtime
    # text, so they are kept verbatim rather than translated.

    # Main body of the analysis (required).
    analysis_text: str = Field(description="详细的分析文本")
    # Optional reasoning trace; defaults to None when absent.
    reasoning_process: Optional[str] = Field(
        None,
        description="模型的推理过程",
    )
    # Optional list of reference strings; defaults to None when absent.
    references: Optional[List[str]] = Field(
        None,
        description="参考资料和引用URL列表",
    )
| 
 | ||||
class NumericalAnalysisResult(BaseModel):
    """数值分析结果,包含数值和分析描述"""

    # Structured result of a numeric assessment.
    # NOTE: the docstring and Field descriptions end up in the JSON schema
    # that pydantic generates, so they are runtime text and stay verbatim.

    # The assessed value, declared as a string (required).
    value: str = Field(
        description="评估值",
    )
    # Free-text explanation accompanying the value (required).
    description: str = Field(
        description="评估描述",
    )
| 
 | ||||
# Lookup table from the `analysis_type` string accepted by ChatBot.chat()
# to the pydantic model that the reply must conform to.
ANALYSIS_TYPE_MODEL_MAP = dict(
    # free-text analysis
    text_analysis_result=TextAnalysisResult,
    # numeric assessment
    numerical_analysis_result=NumericalAnalysisResult,
)
| 
 | ||||
class JsonOutputParser:
    """Build JSON format instructions for a model class and parse replies into it.

    The wrapped class must expose either the pydantic v2 entry points
    (``model_json_schema`` / ``model_validate``) or the v1 ones
    (``schema`` / ``parse_obj``); the v2 names are preferred when present so
    that no deprecation warnings are raised on pydantic 2.x.
    """

    def __init__(self, pydantic_object):
        """Store the target model class.

        Args:
            pydantic_object: The pydantic model class replies are parsed into.
        """
        self.pydantic_object = pydantic_object

    def get_format_instructions(self) -> str:
        """Return the prompt fragment that asks for JSON matching the model's schema."""
        model = self.pydantic_object
        build_schema = getattr(model, "model_json_schema", None) or model.schema
        schema_str = json.dumps(build_schema(), ensure_ascii=False, indent=2)

        return f"""请将你的回答格式化为符合以下JSON结构的格式:
{schema_str}

你的回答应该只包含一个有效的JSON对象,而不包含其他内容。请不要包含任何解释或前缀,仅输出JSON本身。
"""

    @staticmethod
    def _load_json(text: str) -> Any:
        """Decode the JSON payload of *text*, tolerating Markdown code fences.

        Candidates are tried in order: the text itself (with an enclosing
        ``` / ```json fence removed when the whole reply is fenced), then the
        first fenced block found anywhere in the reply. The second candidate
        covers replies that put prose before or after the fence.
        """
        import re

        stripped = text.strip()
        candidates = []
        if stripped.startswith("```") and stripped.endswith("```"):
            inner = stripped[3:-3]
            if inner.startswith("json"):
                inner = inner[4:]
            candidates.append(inner.strip())
        else:
            candidates.append(stripped)

        embedded = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.DOTALL)
        if embedded:
            candidates.append(embedded.group(1).strip())

        last_error = None
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except ValueError as err:  # json.JSONDecodeError is a ValueError
                last_error = err
        raise last_error

    def parse(self, text: str) -> Any:
        """Parse *text* into an instance of the wrapped model.

        Args:
            text: Raw reply, optionally wrapped in or containing a code fence.

        Returns:
            Whatever the model's validation entry point returns for the
            decoded JSON (normally a model instance).

        Raises:
            ValueError: If the text is not valid JSON or does not satisfy the
                model. The underlying error is chained as ``__cause__``.
        """
        model = self.pydantic_object
        try:
            json_obj = self._load_json(text)
            validate = getattr(model, "model_validate", None) or model.parse_obj
            return validate(json_obj)
        except Exception as e:
            raise ValueError(f"无法解析为JSON或无法匹配模式: {e}") from e
| 
 | ||||
# Absolute path of the directory one level above the directory containing this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def _default_model_config(platform: str, model_type: str) -> dict:
    """Fallback configuration used when no project config module can be imported.

    The API key is read from the environment (``GEMINI_API_KEY``, then
    ``GOOGLE_API_KEY``) instead of being embedded in source control; it is
    ``None`` when neither variable is set. Both arguments are accepted only
    for signature compatibility with the real ``get_model_config``.
    """
    return {
        "api_key": os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"),
        "model": "gemini-2.0-flash",
    }


# Resolve get_model_config: try the project config modules first, then fall
# back to the environment-based default above.
try:
    # Make both ROOT_DIR's parent and ROOT_DIR itself importable.
    sys.path.append(os.path.dirname(ROOT_DIR))
    sys.path.append(ROOT_DIR)

    try:
        from scripts.config import get_model_config

        logger.info("成功从scripts.config导入配置")
    except ImportError:
        try:
            from src.scripts.config import get_model_config

            logger.info("成功从src.scripts.config导入配置")
        except ImportError:
            logger.warning("无法导入配置模块,使用默认配置")
            get_model_config = _default_model_config
except Exception as e:
    # Anything other than ImportError raised while importing the config modules.
    logger.error(f"导入配置时出错: {str(e)},使用默认配置")
    get_model_config = _default_model_config
| 
 | ||||
class ChatBot:
    """Chat bot backed by the Gemini API with the Google Search tool enabled.

    Replies can optionally be requested in a structured JSON form (see
    ``analysis_type`` on :meth:`chat`) and are then unpacked into plain dicts.
    """

    def __init__(self, platform: str = "Gemini", model_type: str = "gemini-2.0-flash"):
        """Create the API client and the system instruction.

        Args:
            platform: Platform name passed to ``get_model_config``.
            model_type: Model type passed to ``get_model_config``.

        Raises:
            Exception: Any error raised while reading the configuration or
                creating the client is logged and re-raised.
        """
        try:
            config = get_model_config(platform, model_type)
            self.api_key = config["api_key"]
            self.model = config["model"]  # model name comes straight from the config

            logger.info(f"初始化ChatBot,使用平台: {platform}, 模型: {self.model}")

            self.client = genai.Client(api_key=self.api_key)

            # System prompt sent with every request (runtime text, kept verbatim).
            self.system_instruction = """你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下:

                            1. 分析风格:
                            - 专业、客观、理性
                            - 注重数据支撑
                            - 关注风险控制
                            - 重视投资性价比

                            2. 分析框架:
                            - 公司基本面分析
                            - 行业竞争格局
                            - 估值水平评估
                            - 风险因素识别
                            - 投资机会判断

                            3. 输出要求:
                            - 简明扼要,重点突出
                            - 逻辑清晰,层次分明
                            - 数据准确,论据充分
                            - 结论明确,建议具体

                            请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。"""

            # Kept for interface compatibility; chat() does not send it to the API.
            self.conversation_history = []
        except Exception as e:
            logger.error(f"初始化ChatBot时出错: {str(e)}")
            raise

    @staticmethod
    def _to_int(value) -> int:
        """Coerce a JSON scalar to ``int``; anything that cannot be read as a number becomes 0.

        Handles signed and decimal strings ("-1", "2.0"), which a plain
        ``str.isdigit()`` check rejects.
        """
        if isinstance(value, str):
            text = value.strip()
            try:
                return int(text)
            except ValueError:
                try:
                    return int(float(text))
                except (ValueError, OverflowError):
                    return 0
        if isinstance(value, (int, float)):
            try:
                return int(value)
            except (ValueError, OverflowError):  # NaN / infinity
                return 0
        return 0

    @staticmethod
    def _fallback_result(full_response: str, is_text_analysis: bool, is_numerical_analysis: bool) -> dict:
        """Build a result straight from the raw reply when structured parsing failed.

        For a numerical analysis the first integer in the text is used
        (0 when there is none); every other case returns the raw text.
        """
        if is_numerical_analysis and not is_text_analysis:
            import re

            # A leading '-' counts as a sign only when it does not follow an
            # ASCII letter or digit ("1-5" and "level-3" stay positive). No
            # \b is used so that digits next to CJK characters still match.
            match = re.search(r"(?<![A-Za-z0-9])-?\d+|\d+", full_response)
            return {
                "response": int(match.group()) if match else 0,
                "description": full_response
            }
        return {
            "response": full_response,
            "reasoning_process": None,
            "references": None
        }

    def chat(self, user_input: str, analysis_type: Optional[str] = None, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096) -> dict:
        """Send *user_input* to the model and return the reply as a dict.

        Args:
            user_input: The user's question.
            analysis_type: Key of ``ANALYSIS_TYPE_MODEL_MAP``; when it matches,
                the model is asked for JSON conforming to the mapped class.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling parameter forwarded to the API.
            max_tokens: Forwarded as ``max_output_tokens``.

        Returns:
            dict: Always contains ``"response"``. Text results also carry
            ``"reasoning_process"`` and ``"references"``; numerical results
            carry ``"description"``. On any error the dict holds an error
            message in ``"response"``; no exception is raised.
        """
        try:
            google_search_tool = Tool(
                google_search=GoogleSearch()
            )

            # Pick the target model class and, if there is one, extend the
            # system instruction with the JSON format instructions.
            model_class = ANALYSIS_TYPE_MODEL_MAP.get(analysis_type) if analysis_type else None
            parser = None
            combined_instruction = self.system_instruction
            if model_class is not None:
                parser = JsonOutputParser(pydantic_object=model_class)
                format_instructions = parser.get_format_instructions()
                combined_instruction = f"{self.system_instruction}\n\n{format_instructions}"

            response = self.client.models.generate_content_stream(
                model=self.model,
                config=GenerateContentConfig(
                    system_instruction=combined_instruction,
                    tools=[google_search_tool],
                    response_modalities=["TEXT"],
                    temperature=temperature,
                    top_p=top_p,
                    max_output_tokens=max_tokens
                ),
                contents=user_input
            )

            pieces = []
            for chunk in response:
                # The SDK types `text` as optional; a chunk without a text
                # part must not be concatenated or echoed as "None".
                if chunk.text:
                    pieces.append(chunk.text)
                    print(chunk.text, end="")
            full_response = "".join(pieces)

            if parser is None:
                # Plain reply, no structured parsing requested.
                return {
                    "response": full_response,
                    "reasoning_process": None,
                    "references": None
                }

            is_text_analysis = model_class == TextAnalysisResult
            is_numerical_analysis = model_class == NumericalAnalysisResult

            try:
                # First attempt: lenient extraction of known fields.
                json_result = self.parse_json_response(full_response, model_class)
                if json_result is not None:
                    return json_result

                # Second attempt: strict validation through the parser.
                parsed_result = parser.parse(full_response)
                structured_result = parsed_result.dict()

                if isinstance(parsed_result, TextAnalysisResult):
                    return {
                        "response": structured_result["analysis_text"],
                        "reasoning_process": structured_result.get("reasoning_process"),
                        "references": structured_result.get("references")
                    }
                elif isinstance(parsed_result, NumericalAnalysisResult):
                    return {
                        "response": structured_result["value"],
                        "description": structured_result.get("description", "")
                    }
            except Exception as e:
                logger.warning(f"解析响应失败: {str(e)},返回原始响应")

            # Last resort (also reached if the parsed object is of neither
            # known type, so the method never returns None).
            return self._fallback_result(full_response, is_text_analysis, is_numerical_analysis)

        except Exception as e:
            logger.error(f"聊天失败: {str(e)}")
            error_msg = f"抱歉,发生错误: {str(e)}"
            print(f"\n{error_msg}")
            return {"response": error_msg, "reasoning_process": None, "references": None}

    def clear_history(self):
        """Empty ``conversation_history`` and tell the user."""
        self.conversation_history = []
        print("对话历史已清除")

    def run(self):
        """Run an interactive console loop until the user types 'quit' or presses Ctrl-C."""
        print("欢迎使用Gemini AI助手!输入 'quit' 退出,输入 'clear' 清除对话历史。")
        print("-" * 50)

        while True:
            try:
                user_input = input("\n你: ").strip()

                # 'quit' ends the session.
                if user_input.lower() == 'quit':
                    print("感谢使用,再见!")
                    break

                # 'clear' resets the stored history.
                if user_input.lower() == 'clear':
                    self.clear_history()
                    continue

                # Ignore empty input.
                if not user_input:
                    continue

                # chat() streams the reply to stdout itself.
                self.chat(user_input)
                print("-" * 50)

            except KeyboardInterrupt:
                print("\n感谢使用,再见!")
                break
            except Exception as e:
                logger.error(f"运行错误: {str(e)}")
                print(f"发生错误: {str(e)}")

    def parse_json_response(self, full_response: str, model_class) -> Optional[dict]:
        """Extract the known result fields from a JSON reply.

        Two layouts are recognised: the fields at the top level of the JSON
        object, or the same fields nested under a ``"properties"`` object.

        Args:
            full_response: Complete reply text, optionally containing a
                ``` or ```json fenced block.
            model_class: ``TextAnalysisResult`` or ``NumericalAnalysisResult``.

        Returns:
            The result dict, or ``None`` when the text is not JSON or has an
            unrecognised structure (the failure is logged, not raised, so the
            caller can try other strategies).
        """
        is_text_analysis = model_class == TextAnalysisResult
        is_numerical_analysis = model_class == NumericalAnalysisResult

        try:
            json_text = full_response
            # Use the content of the first ```json fenced block if there is one...
            if "```json" in json_text and "```" in json_text.split("```json", 1)[1]:
                json_text = json_text.split("```json", 1)[1].split("```", 1)[0].strip()
            # ...otherwise the content of the first plain ``` fenced block.
            elif "```" in json_text and "```" in json_text.split("```", 1)[1]:
                json_text = json_text.split("```", 1)[1].split("```", 1)[0].strip()

            json_obj = json.loads(json_text)
            if not isinstance(json_obj, dict):
                raise ValueError("无法识别的JSON结构,尝试使用parser解析")

            # Layout 1: fields at the top level; layout 2: under "properties".
            layouts = [json_obj]
            if isinstance(json_obj.get("properties"), dict):
                layouts.append(json_obj["properties"])

            for fields in layouts:
                if is_text_analysis and "analysis_text" in fields:
                    return {
                        "response": fields["analysis_text"],
                        "reasoning_process": fields.get("reasoning_process"),
                        "references": fields.get("references", [])
                    }
                if is_numerical_analysis and "value" in fields:
                    return {
                        "response": self._to_int(fields["value"]),
                        "description": fields.get("description", "")
                    }

            raise ValueError("无法识别的JSON结构,尝试使用parser解析")

        except Exception as e:
            logger.warning(f"JSON解析失败: {str(e)}")
            return None
| 
 | ||||
if __name__ == "__main__":
    # Script entry point: log at INFO level with timestamp, logger name and
    # level, then start the interactive console session.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)

    ChatBot().run()
|  | @ -3,7 +3,7 @@ from openai import OpenAI | |||
| import os | ||||
| import sys | ||||
| import time | ||||
| import random | ||||
| 
 | ||||
| 
 | ||||
| # 设置日志记录 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -74,26 +74,26 @@ class ChatBot: | |||
|                 "role": "system", | ||||
|                 "content": """你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下: | ||||
|                              | ||||
| 1. 分析风格: | ||||
| - 专业、客观、理性 | ||||
| - 注重数据支撑 | ||||
| - 关注风险控制 | ||||
| - 重视投资性价比 | ||||
|                             1. 分析风格: | ||||
|                             - 专业、客观、理性 | ||||
|                             - 注重数据支撑 | ||||
|                             - 关注风险控制 | ||||
|                             - 重视投资性价比 | ||||
|                              | ||||
| 2. 分析框架: | ||||
| - 公司基本面分析 | ||||
| - 行业竞争格局 | ||||
| - 估值水平评估 | ||||
| - 风险因素识别 | ||||
| - 投资机会判断 | ||||
|                             2. 分析框架: | ||||
|                             - 公司基本面分析 | ||||
|                             - 行业竞争格局 | ||||
|                             - 估值水平评估 | ||||
|                             - 风险因素识别 | ||||
|                             - 投资机会判断 | ||||
|                              | ||||
| 3. 输出要求: | ||||
| - 简明扼要,重点突出 | ||||
| - 逻辑清晰,层次分明 | ||||
| - 数据准确,论据充分 | ||||
| - 结论明确,建议具体 | ||||
|                             3. 输出要求: | ||||
|                             - 简明扼要,重点突出 | ||||
|                             - 逻辑清晰,层次分明 | ||||
|                             - 数据准确,论据充分 | ||||
|                             - 结论明确,建议具体 | ||||
|                              | ||||
| 请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。""" | ||||
|                             请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。""" | ||||
|             } | ||||
|              | ||||
|             # 对话历史 | ||||
|  | @ -102,8 +102,19 @@ class ChatBot: | |||
|             logger.error(f"初始化ChatBot时出错: {str(e)}") | ||||
|             raise | ||||
| 
 | ||||
|     def chat(self, user_input: str) -> str: | ||||
|         """处理用户输入并返回AI回复""" | ||||
|     def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> str: | ||||
|         """处理用户输入并返回AI回复 | ||||
|          | ||||
|         Args: | ||||
|             user_input: 用户输入的问题 | ||||
|             temperature: 控制输出的随机性,范围0-2,默认1.0 | ||||
|             top_p: 控制输出的多样性,范围0-1,默认0.7 | ||||
|             max_tokens: 控制输出的最大长度,默认4096 | ||||
|             frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 | ||||
|              | ||||
|         Returns: | ||||
|             str: AI的回答内容 | ||||
|         """ | ||||
|         try: | ||||
|             # # 添加用户消息到对话历史-多轮 | ||||
|             self.conversation_history.append({ | ||||
|  | @ -116,7 +127,10 @@ class ChatBot: | |||
|             stream = self.client.chat.completions.create( | ||||
|                 model=self.model, | ||||
|                 messages=self.conversation_history, | ||||
|                 temperature=0, | ||||
|                 temperature=temperature, | ||||
|                 top_p=top_p, | ||||
|                 max_tokens=max_tokens, | ||||
|                 frequency_penalty=frequency_penalty, | ||||
|                 stream=True, | ||||
|                 timeout=600 | ||||
|             ) | ||||
|  |  | |||
|  | @ -130,9 +130,15 @@ class EnterpriseScreener: | |||
|             { | ||||
|                 'dimension': 'investment_advice', | ||||
|                 'field': 'investment_advice_type', | ||||
|                 'operator': '!=', | ||||
|                 'value': '不建议' | ||||
|             } | ||||
|                 'operator': 'in', | ||||
|                 'value': ["'短期'","'中期'","'长期'"] | ||||
|             }, | ||||
|             { | ||||
|                 'dimension': 'financial_report', | ||||
|                 'field': 'financial_report_level', | ||||
|                 'operator': '>=', | ||||
|                 'value': -1 | ||||
|             }, | ||||
|         ] | ||||
|         return self._screen_stocks_by_conditions(conditions, limit) | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,7 +1,11 @@ | |||
| import logging | ||||
| import os | ||||
| from datetime import datetime | ||||
| from typing import Dict, List, Optional, Tuple, Callable | ||||
| import sys | ||||
| from datetime import datetime, timedelta | ||||
| import time | ||||
| import redis | ||||
| from typing import Dict, List, Optional, Tuple, Callable, Any | ||||
| 
 | ||||
| # 修改导入路径,使用相对导入 | ||||
| try: | ||||
|     # 尝试相对导入 | ||||
|  | @ -9,6 +13,7 @@ try: | |||
|     from .chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|     from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result | ||||
|     from .pdf_generator import PDFGenerator | ||||
|     from .text_processor import TextProcessor | ||||
| except ImportError: | ||||
|     # 如果相对导入失败,尝试尝试绝对导入 | ||||
|     try: | ||||
|  | @ -16,18 +21,17 @@ except ImportError: | |||
|         from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|         from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result | ||||
|         from src.fundamentals_llm.pdf_generator import PDFGenerator | ||||
|         from src.fundamentals_llm.text_processor import TextProcessor | ||||
|     except ImportError: | ||||
|         # 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况) | ||||
|         from chat_bot import ChatBot | ||||
|         from chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|         from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result | ||||
|         from pdf_generator import PDFGenerator | ||||
|         from text_processor import TextProcessor | ||||
| import json | ||||
| import re | ||||
| 
 | ||||
| # 设置日志记录 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| # 获取项目根目录的绝对路径 | ||||
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
| 
 | ||||
|  | @ -58,6 +62,34 @@ logging.basicConfig( | |||
|     datefmt=date_format | ||||
| ) | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| logger.info("测试日志输出 - 程序启动") | ||||
| 
 | ||||
| from typing import Dict, List, Optional, Any, Union | ||||
| from pydantic import BaseModel, Field | ||||
| 
 | ||||
| # Basic data structures | ||||
| # TextAnalysisResult is a pydantic model for a free-text analysis result: the analysis body, | ||||
| # an optional reasoning trace and an optional list of reference URLs. | ||||
| # NOTE(review): the class docstring and the Field descriptions are left untranslated because | ||||
| # pydantic can surface them in the generated JSON schema -- confirm before rewording them. | ||||
| class TextAnalysisResult(BaseModel): | ||||
|     """文本分析结果,包含分析正文、推理过程和引用URL""" | ||||
|     # Detailed analysis text; no default is given, so the field must be supplied. | ||||
|     analysis_text: str = Field(description="详细的分析文本") | ||||
|     # The model's reasoning process; optional, defaults to None. | ||||
|     reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None) | ||||
|     # Reference material / cited URLs; optional, defaults to None. | ||||
|     references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None) | ||||
| 
 | ||||
| # NumericalAnalysisResult is a pydantic model for a scored analysis result: the rating value | ||||
| # plus a textual description of the assessment. Both fields have no default and must be supplied. | ||||
| # NOTE(review): the class docstring and the Field descriptions are left untranslated because | ||||
| # pydantic can surface them in the generated JSON schema -- confirm before rewording them. | ||||
| class NumericalAnalysisResult(BaseModel): | ||||
|     """数值分析结果,包含数值和分析描述""" | ||||
|     # The assessed value, carried as a string (not as a number). | ||||
|     value: str = Field(description="评估值") | ||||
|     # Free-text description of the assessment. | ||||
|     description: str = Field(description="评估描述") | ||||
| 
 | ||||
| # 添加Redis客户端 | ||||
| redis_client = redis.Redis( | ||||
|     host='192.168.18.208',  # Redis服务器地址,根据实际情况调整 | ||||
|     port=6379, | ||||
|     password='wlkj2018', | ||||
|     db=14, | ||||
|     socket_timeout=5, | ||||
|     decode_responses=True | ||||
| ) | ||||
| 
 | ||||
| class FundamentalAnalyzer: | ||||
|     """基本面分析器""" | ||||
|      | ||||
|  | @ -66,9 +98,10 @@ class FundamentalAnalyzer: | |||
|         # 使用联网模型进行基本面分析 | ||||
|         self.chat_bot = ChatBot(model_type="online_bot") | ||||
|         # 使用离线模型进行其他分析 | ||||
|         self.offline_bot = OfflineChatBot(platform="tl_private", model_type="ds-v1") | ||||
|         self.offline_bot = OfflineChatBot(platform="volc", model_type="offline_model") | ||||
|         # 千问打杂 | ||||
|         self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") | ||||
|         # self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") | ||||
|         self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="GLM") | ||||
|         self.db = next(get_db()) | ||||
|          | ||||
|         # 定义维度映射 | ||||
|  | @ -161,6 +194,8 @@ class FundamentalAnalyzer: | |||
|             - 关键战略决策和转型 | ||||
| 
 | ||||
|             请提供专业、客观的分析,突出关键信息,避免冗长描述。""" | ||||
|             #开头清上下文缓存 | ||||
|             self.chat_bot.clear_history() | ||||
|             # 获取AI分析结果 | ||||
|             result = self.chat_bot.chat(prompt) | ||||
|              | ||||
|  | @ -554,8 +589,8 @@ class FundamentalAnalyzer: | |||
|                     请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。""" | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             # 使用离线模型进行分析 | ||||
|             space_value_str = self.offline_bot_tl_qw.chat(prompt) | ||||
|             space_value_str = self._clean_model_output(space_value_str) | ||||
|             space_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|             space_value_str = TextProcessor.clean_thought_process(space_value_str) | ||||
|             # 提取数值 | ||||
|             space_value = 0  # 默认值 | ||||
|              | ||||
|  | @ -658,9 +693,9 @@ class FundamentalAnalyzer: | |||
|                     请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             # 使用离线模型进行分析 | ||||
|             events_value_str = self.offline_bot_tl_qw.chat(prompt) | ||||
|             events_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|             # 数据清洗 | ||||
|             events_value_str = self._clean_model_output(events_value_str) | ||||
|             events_value_str = TextProcessor.clean_thought_process(events_value_str) | ||||
|             # 提取数值 | ||||
|             events_value = 0  # 默认值 | ||||
|              | ||||
|  | @ -700,14 +735,14 @@ class FundamentalAnalyzer: | |||
|     def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool: | ||||
|         """分析股吧讨论内容""" | ||||
|         try: | ||||
|             prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出: | ||||
|             prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在400字以内(主要讨论话题200字,重要信息汇总200字),请严格按照以下格式输出: | ||||
| 
 | ||||
|                     1. 主要讨论话题(150字左右): | ||||
|                     1. 主要讨论话题: | ||||
|                     - 近期热点事件 | ||||
|                     - 投资者关注焦点 | ||||
|                     - 市场情绪倾向 | ||||
|                      | ||||
|                     2. 重要信息汇总(150字左右): | ||||
|                     2. 重要信息汇总: | ||||
|                     - 公司相关动态 | ||||
|                     - 行业政策变化 | ||||
|                     - 市场预期变化 | ||||
|  | @ -762,10 +797,10 @@ class FundamentalAnalyzer: | |||
|                     请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             # 使用离线模型进行分析 | ||||
|             emotion_value_str = self.offline_bot_tl_qw.chat(prompt) | ||||
|             emotion_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
| 
 | ||||
|             # 数据清洗 | ||||
|             emotion_value_str = self._clean_model_output(emotion_value_str) | ||||
|             emotion_value_str = TextProcessor.clean_thought_process(emotion_value_str) | ||||
|             # 提取数值 | ||||
|             emotion_value = 0  # 默认值 | ||||
|              | ||||
|  | @ -1015,10 +1050,10 @@ class FundamentalAnalyzer: | |||
|                      | ||||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:2,1,0,-1或-2。""" | ||||
|              | ||||
|             response = self.chat_bot.chat(prompt) | ||||
|             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|              | ||||
|             # 提取数值 | ||||
|             rating_str = self._extract_numeric_value_from_response(response) | ||||
|             rating_str = TextProcessor.extract_numeric_value_from_response(response) | ||||
|              | ||||
|             # 尝试将响应转换为整数 | ||||
|             try: | ||||
|  | @ -1047,20 +1082,20 @@ class FundamentalAnalyzer: | |||
|         """ | ||||
|         try: | ||||
|             prompt = f"""请仔细分析以下目标股价文本,判断上涨空间和下跌空间的关系,并返回对应的数值: | ||||
| - 如果上涨空间大于下跌空间,返回数值"1" | ||||
| - 如果上涨空间和下跌空间差不多,返回数值"0" | ||||
| - 如果下跌空间大于上涨空间,返回数值"-1" | ||||
| - 如果文本中没有相关信息,返回数值"0" | ||||
|                         - 如果上涨空间大于下跌空间,返回数值"1" | ||||
|                         - 如果上涨空间和下跌空间差不多,返回数值"0" | ||||
|                         - 如果下跌空间大于上涨空间,返回数值"-1" | ||||
|                         - 如果文本中没有相关信息,返回数值"0" | ||||
|                          | ||||
| 目标股价文本: | ||||
| {price_text} | ||||
|                         目标股价文本: | ||||
|                         {price_text} | ||||
|                          | ||||
| 只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" | ||||
|                         只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" | ||||
|              | ||||
|             response = self.chat_bot.chat(prompt) | ||||
|             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|              | ||||
|             # 提取数值 | ||||
|             odds_str = self._extract_numeric_value_from_response(response) | ||||
|             odds_str = TextProcessor.extract_numeric_value_from_response(response) | ||||
|              | ||||
|             # 尝试将响应转换为整数 | ||||
|             try: | ||||
|  | @ -1200,8 +1235,8 @@ class FundamentalAnalyzer: | |||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||
|              | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             response = self.offline_bot_tl_qw.chat(prompt) | ||||
|             pe_hist_str = self._clean_model_output(response) | ||||
|             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|             pe_hist_str =  TextProcessor.clean_thought_process(response) | ||||
|              | ||||
|             try: | ||||
|                 pe_hist = int(pe_hist_str) | ||||
|  | @ -1240,8 +1275,8 @@ class FundamentalAnalyzer: | |||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||
|              | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             response = self.offline_bot_tl_qw.chat(prompt) | ||||
|             pb_hist_str = self._clean_model_output(response) | ||||
|             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|             pb_hist_str = TextProcessor.clean_thought_process(response) | ||||
|              | ||||
|             try: | ||||
|                 pb_hist = int(pb_hist_str) | ||||
|  | @ -1280,8 +1315,8 @@ class FundamentalAnalyzer: | |||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||
|              | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             response = self.offline_bot_tl_qw.chat(prompt) | ||||
|             pe_ind_str = self._clean_model_output(response) | ||||
|             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|             pe_ind_str = TextProcessor.clean_thought_process(response) | ||||
|              | ||||
|             try: | ||||
|                 pe_ind = int(pe_ind_str) | ||||
|  | @ -1320,8 +1355,8 @@ class FundamentalAnalyzer: | |||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||
|              | ||||
|             self.offline_bot_tl_qw.clear_history() | ||||
|             response = self.offline_bot_tl_qw.chat(prompt) | ||||
|             pb_ind_str = self._clean_model_output(response) | ||||
|             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|             pb_ind_str = TextProcessor.clean_thought_process(response) | ||||
|              | ||||
|             try: | ||||
|                 pb_ind = int(pb_ind_str) | ||||
|  | @ -1373,9 +1408,9 @@ class FundamentalAnalyzer: | |||
|             {json.dumps(all_results, ensure_ascii=False, indent=2)}""" | ||||
|             self.offline_bot.clear_history() | ||||
|             # 使用离线模型生成建议 | ||||
|             result = self.offline_bot.chat(prompt) | ||||
|             result = self.offline_bot.chat(prompt,max_tokens=20000) | ||||
|             # 清理模型输出 | ||||
|             result = self._clean_model_output(result) | ||||
|             result = TextProcessor.clean_thought_process(result) | ||||
|             # 保存到数据库 | ||||
|             success = save_analysis_result( | ||||
|                 self.db, | ||||
|  | @ -1441,93 +1476,6 @@ class FundamentalAnalyzer: | |||
|             logger.error(f"提取投资建议类型失败: {str(e)}") | ||||
|             return None | ||||
|          | ||||
|     def _clean_model_output(self, output: str) -> str: | ||||
|         """清理模型输出,移除推理过程,只保留最终结果 | ||||
|          | ||||
|         Args: | ||||
|             output: 模型原始输出文本 | ||||
|              | ||||
|         Returns: | ||||
|             str: 清理后的输出文本 | ||||
|         """ | ||||
|         try: | ||||
|             # 找到</think>标签的位置 | ||||
|             think_end = output.find('</think>') | ||||
|             if think_end != -1: | ||||
|                 # 移除</think>标签及其之前的所有内容 | ||||
|                 output = output[think_end + len('</think>'):] | ||||
|              | ||||
|             # 处理可能存在的空行 | ||||
|             lines = output.split('\n') | ||||
|             cleaned_lines = [] | ||||
|             for line in lines: | ||||
|                 line = line.strip() | ||||
|                 if line:  # 只保留非空行 | ||||
|                     cleaned_lines.append(line) | ||||
|              | ||||
|             # 重新组合文本 | ||||
|             output = '\n'.join(cleaned_lines) | ||||
|              | ||||
|             return output.strip() | ||||
|              | ||||
|         except Exception as e: | ||||
|             logger.error(f"清理模型输出失败: {str(e)}") | ||||
|             return output.strip() | ||||
| 
 | ||||
|     def _extract_numeric_value_from_response(self, response: str) -> str: | ||||
|         """从模型响应中提取数值,移除参考资料和推理过程 | ||||
|          | ||||
|         Args: | ||||
|             response: 模型原始响应文本或响应对象 | ||||
|              | ||||
|         Returns: | ||||
|             str: 提取的数值字符串 | ||||
|         """ | ||||
|         try: | ||||
|             # 处理响应对象(包含response字段的字典) | ||||
|             if isinstance(response, dict) and "response" in response: | ||||
|                 response = response["response"] | ||||
|                  | ||||
|             # 确保响应是字符串 | ||||
|             if not isinstance(response, str): | ||||
|                 logger.warning(f"响应不是字符串类型: {type(response)}") | ||||
|                 return "0" | ||||
|                  | ||||
|             # 移除推理过程部分 | ||||
|             reasoning_start = response.find("推理过程:") | ||||
|             if reasoning_start != -1: | ||||
|                 response = response[:reasoning_start].strip() | ||||
|                  | ||||
|             # 移除参考资料部分(通常以 [数字] 开头的行) | ||||
|             lines = response.split("\n") | ||||
|             cleaned_lines = [] | ||||
|              | ||||
|             for line in lines: | ||||
|                 # 跳过参考资料行(通常以 [数字] 开头) | ||||
|                 if re.match(r'\[\d+\]', line.strip()): | ||||
|                     continue | ||||
|                 cleaned_lines.append(line) | ||||
|                  | ||||
|             response = "\n".join(cleaned_lines).strip() | ||||
|              | ||||
|             # 提取数值 | ||||
|             # 先尝试直接将整个响应转换为数值 | ||||
|             if response.strip() in ["-2", "-1", "0", "1", "2"]: | ||||
|                 return response.strip() | ||||
|                  | ||||
|             # 如果整个响应不是数值,尝试匹配第一个数值 | ||||
|             match = re.search(r'([-]?[0-9])', response) | ||||
|             if match: | ||||
|                 return match.group(1) | ||||
|                  | ||||
|             # 如果没有找到数值,返回默认值 | ||||
|             logger.warning(f"未能从响应中提取数值: {response}") | ||||
|             return "0" | ||||
|              | ||||
|         except Exception as e: | ||||
|             logger.error(f"从响应中提取数值失败: {str(e)}") | ||||
|             return "0" | ||||
| 
 | ||||
|     def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]: | ||||
|         """尝试多次从投资建议中提取建议类型 | ||||
|          | ||||
|  | @ -1549,7 +1497,7 @@ class FundamentalAnalyzer: | |||
|                  | ||||
|                 # 使用千问离线模型提取建议类型 | ||||
| 
 | ||||
|                 result = self.offline_bot_tl_qw.chat(prompt) | ||||
|                 result = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||
|                  | ||||
|                 # 检查是否是错误响应 | ||||
|                 if isinstance(result, str) and "抱歉,发生错误" in result: | ||||
|  | @ -1557,7 +1505,7 @@ class FundamentalAnalyzer: | |||
|                     continue | ||||
|                  | ||||
|                 # 清理模型输出 | ||||
|                 cleaned_result = self._clean_model_output(result) | ||||
|                 cleaned_result = TextProcessor.clean_thought_process(result) | ||||
|                  | ||||
|                 # 检查结果是否为有效类型 | ||||
|                 if cleaned_result in valid_types: | ||||
|  | @ -1644,6 +1592,24 @@ class FundamentalAnalyzer: | |||
|             Optional[str]: 生成的PDF文件路径,如果失败则返回None | ||||
|         """ | ||||
|         try: | ||||
|             # 检查是否已存在PDF文件 | ||||
|             reports_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'reports') | ||||
|             os.makedirs(reports_dir, exist_ok=True) | ||||
|              | ||||
|             # 构建可能的文件名格式 | ||||
|             possible_filenames = [ | ||||
|                 f"{stock_name}_{stock_code}_analysis.pdf", | ||||
|                 f"{stock_name}_{stock_code}.SZ_analysis.pdf", | ||||
|                 f"{stock_name}_{stock_code}.SH_analysis.pdf" | ||||
|             ] | ||||
|              | ||||
|             # 检查是否存在已生成的PDF文件 | ||||
|             for filename in possible_filenames: | ||||
|                 filepath = os.path.join(reports_dir, filename) | ||||
|                 if os.path.exists(filepath): | ||||
|                     logger.info(f"找到已存在的PDF报告: {filepath}") | ||||
|                     return filepath | ||||
|              | ||||
|             # 维度名称映射 | ||||
|             dimension_names = { | ||||
|                 "company_profile": "公司简介", | ||||
|  | @ -1669,27 +1635,156 @@ class FundamentalAnalyzer: | |||
|                 logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果") | ||||
|                 return None | ||||
|              | ||||
|             # 确保字体目录存在 | ||||
|             fonts_dir = os.path.join(os.path.dirname(__file__), "fonts") | ||||
|             os.makedirs(fonts_dir, exist_ok=True) | ||||
|              | ||||
|             # 检查是否存在字体文件,如果不存在则创建简单的默认字体标记文件 | ||||
|             font_path = os.path.join(fonts_dir, "simhei.ttf") | ||||
|             if not os.path.exists(font_path): | ||||
|                 # 尝试从系统字体目录复制 | ||||
|                 try: | ||||
|                     import shutil | ||||
|                     # 尝试常见的系统字体位置 | ||||
|                     system_fonts = [ | ||||
|                         "C:/Windows/Fonts/simhei.ttf",  # Windows | ||||
|                         "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # Linux | ||||
|                         "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",  # 其他Linux | ||||
|                         "/System/Library/Fonts/PingFang.ttc"  # macOS | ||||
|                     ] | ||||
|                      | ||||
|                     for system_font in system_fonts: | ||||
|                         if os.path.exists(system_font): | ||||
|                             shutil.copy2(system_font, font_path) | ||||
|                             logger.info(f"已复制字体文件: {system_font} -> {font_path}") | ||||
|                             break | ||||
|                 except Exception as font_error: | ||||
|                     logger.warning(f"复制字体文件失败: {str(font_error)}") | ||||
|              | ||||
|             # 创建PDF生成器实例 | ||||
|             generator = PDFGenerator() | ||||
|              | ||||
|             # 生成PDF报告 | ||||
|             try: | ||||
|                 # 第一次尝试生成 | ||||
|                 filepath = generator.generate_pdf( | ||||
|                     title=f"{stock_name}({stock_code}) 基本面分析报告", | ||||
|                     content_dict=content_dict, | ||||
|                     output_dir=reports_dir, | ||||
|                     filename=f"{stock_name}_{stock_code}_analysis.pdf" | ||||
|                 ) | ||||
|                  | ||||
|                 if filepath: | ||||
|                     logger.info(f"PDF报告已生成: {filepath}") | ||||
|             else: | ||||
|                 logger.error("PDF报告生成失败") | ||||
|              | ||||
|                     return filepath | ||||
|             except Exception as pdf_error: | ||||
|                 logger.error(f"生成PDF报告第一次尝试失败: {str(pdf_error)}") | ||||
|                  | ||||
|                 # 如果是字体问题,可能需要使用备选方案 | ||||
|                 if "找不到中文字体文件" in str(pdf_error): | ||||
|                     # 导出为文本文件作为备选 | ||||
|                     try: | ||||
|                         txt_filename = f"{stock_name}_{stock_code}_analysis.txt" | ||||
|                         txt_filepath = os.path.join(reports_dir, txt_filename) | ||||
|                          | ||||
|                         with open(txt_filepath, 'w', encoding='utf-8') as f: | ||||
|                             f.write(f"{stock_name}({stock_code}) 基本面分析报告\n") | ||||
|                             f.write(f"生成时间:{datetime.now().strftime('%Y年%m月%d日 %H:%M:%S')}\n\n") | ||||
|                              | ||||
|                             for section_title, content in content_dict.items(): | ||||
|                                 if content: | ||||
|                                     f.write(f"## {section_title}\n\n") | ||||
|                                     f.write(f"{content}\n\n") | ||||
|                          | ||||
|                         logger.info(f"由于PDF生成失败,已生成文本报告: {txt_filepath}") | ||||
|                         return txt_filepath | ||||
|                     except Exception as txt_error: | ||||
|                         logger.error(f"生成文本报告失败: {str(txt_error)}") | ||||
|              | ||||
|             logger.error("PDF报告生成失败") | ||||
|             return None | ||||
|              | ||||
|         except Exception as e: | ||||
|             logger.error(f"生成PDF报告失败: {str(e)}") | ||||
|             return None | ||||
| 
 | ||||
|     def is_stock_locked(self, stock_code: str, dimension: str) -> bool: | ||||
|         """检查股票是否已被锁定 | ||||
|          | ||||
|         Args: | ||||
|             stock_code: 股票代码 | ||||
|             dimension: 分析维度 | ||||
|              | ||||
|         Returns: | ||||
|             bool: 是否被锁定 | ||||
|         """ | ||||
|         try: | ||||
|             lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" | ||||
|              | ||||
|             # 检查是否存在锁 | ||||
|             existing_lock = redis_client.get(lock_key) | ||||
|             if existing_lock: | ||||
|                 lock_time = int(existing_lock) | ||||
|                 current_time = int(time.time()) | ||||
|                  | ||||
|                 # 锁超过30分钟(1800秒)视为过期 | ||||
|                 if current_time - lock_time > 1800: | ||||
|                     # 锁已过期,可以释放 | ||||
|                     redis_client.delete(lock_key) | ||||
|                     logger.info(f"股票 {stock_code} 维度 {dimension} 的过期锁已释放") | ||||
|                     return False | ||||
|                 else: | ||||
|                     # 锁未过期,被锁定 | ||||
|                     return True | ||||
|              | ||||
|             # 不存在锁 | ||||
|             return False | ||||
|         except Exception as e: | ||||
|             logger.error(f"检查股票 {stock_code} 锁状态时出错: {str(e)}") | ||||
|             # 出错时保守处理,返回未锁定 | ||||
|             return False | ||||
|      | ||||
|     def lock_stock(self, stock_code: str, dimension: str) -> bool: | ||||
|         """锁定股票 | ||||
|          | ||||
|         Args: | ||||
|             stock_code: 股票代码 | ||||
|             dimension: 分析维度 | ||||
|              | ||||
|         Returns: | ||||
|             bool: 是否成功锁定 | ||||
|         """ | ||||
|         try: | ||||
|             lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" | ||||
|             current_time = int(time.time()) | ||||
|              | ||||
|             # 设置锁,过期时间1小时 | ||||
|             redis_client.set(lock_key, current_time, ex=3600) | ||||
|             logger.info(f"股票 {stock_code} 维度 {dimension} 已锁定") | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             logger.error(f"锁定股票 {stock_code} 时出错: {str(e)}") | ||||
|             return False | ||||
|      | ||||
|     def unlock_stock(self, stock_code: str, dimension: str) -> bool: | ||||
|         """解锁股票 | ||||
|          | ||||
|         Args: | ||||
|             stock_code: 股票代码 | ||||
|             dimension: 分析维度 | ||||
|              | ||||
|         Returns: | ||||
|             bool: 是否成功解锁 | ||||
|         """ | ||||
|         try: | ||||
|             lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" | ||||
|             redis_client.delete(lock_key) | ||||
|             logger.info(f"股票 {stock_code} 维度 {dimension} 已解锁") | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             logger.error(f"解锁股票 {stock_code} 时出错: {str(e)}") | ||||
|             return False | ||||
| 
 | ||||
| def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool: | ||||
|     """测试单个分析方法""" | ||||
|     try: | ||||
|  | @ -1713,11 +1808,6 @@ def test_single_stock(analyzer: FundamentalAnalyzer, stock_code: str, stock_name | |||
| 
 | ||||
| def main(): | ||||
|     """主函数""" | ||||
|     # 设置日志级别 | ||||
|     logging.basicConfig( | ||||
|         level=logging.INFO, | ||||
|         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||||
|     ) | ||||
| 
 | ||||
|     # 测试股票列表 | ||||
|     test_stocks = [ | ||||
|  |  | |||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -11,25 +11,25 @@ import markdown2 | |||
| from bs4 import BeautifulSoup | ||||
| import os | ||||
| from datetime import datetime | ||||
| from fpdf import FPDF | ||||
| import matplotlib.pyplot as plt | ||||
| import shutil | ||||
| 
 | ||||
| import matplotlib | ||||
| matplotlib.use('Agg') | ||||
| 
 | ||||
| # 修改导入路径,使用相对导入 | ||||
| try: | ||||
|     # 尝试相对导入 | ||||
|     from .chat_bot_with_offline import ChatBot | ||||
|     from .fundamental_analysis_database import get_db, get_analysis_result | ||||
|     from .chat_bot import ChatBot | ||||
|     from .chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
| except ImportError: | ||||
|     # 如果相对导入失败,尝试绝对导入 | ||||
|     try: | ||||
|         from src.fundamentals_llm.chat_bot_with_offline import ChatBot | ||||
|         from src.fundamentals_llm.fundamental_analysis_database import get_db, get_analysis_result | ||||
|         from src.fundamentals_llm.chat_bot import ChatBot | ||||
|         from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
|     except ImportError: | ||||
|         # 最后尝试直接导入 | ||||
|         from chat_bot_with_offline import ChatBot | ||||
|         from fundamental_analysis_database import get_db, get_analysis_result | ||||
|         from chat_bot import ChatBot | ||||
|         from chat_bot_with_offline import ChatBot as OfflineChatBot | ||||
| 
 | ||||
| # 设置日志记录 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -39,30 +39,9 @@ class PDFGenerator: | |||
|      | ||||
|     def __init__(self): | ||||
|         """初始化PDF生成器""" | ||||
|         # 注册中文字体 | ||||
|         try: | ||||
|             # 尝试使用系统自带的中文字体 | ||||
|             if os.name == 'nt':  # Windows | ||||
|                 font_path = "C:/Windows/Fonts/simhei.ttf"  # 黑体 | ||||
|             else:  # Linux/Mac | ||||
|                 font_path = "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf" | ||||
|              | ||||
|             if os.path.exists(font_path): | ||||
|                 pdfmetrics.registerFont(TTFont('SimHei', font_path)) | ||||
|                 self.font_name = 'SimHei' | ||||
|             else: | ||||
|                 # 如果找不到系统字体,尝试使用当前目录下的字体 | ||||
|                 font_path = os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf") | ||||
|                 if os.path.exists(font_path): | ||||
|                     pdfmetrics.registerFont(TTFont('SimHei', font_path)) | ||||
|                     self.font_name = 'SimHei' | ||||
|                 else: | ||||
|                     raise FileNotFoundError("找不到中文字体文件") | ||||
|         except Exception as e: | ||||
|             logger.error(f"注册中文字体失败: {str(e)}") | ||||
|             raise | ||||
|          | ||||
|         self.chat_bot = ChatBot() | ||||
|         # 尝试注册中文字体 | ||||
|         self.font_name = self._register_chinese_font() | ||||
|         self.chat_bot = OfflineChatBot(platform="volc", model_type="offline_model")  # 使用离线模式 | ||||
|         self.styles = getSampleStyleSheet() | ||||
|          | ||||
|         # 创建自定义样式 | ||||
|  | @ -139,6 +118,64 @@ class PDFGenerator: | |||
|             textColor=colors.HexColor('#333333') | ||||
|         )) | ||||
|      | ||||
|     def _register_chinese_font(self): | ||||
|         """查找并注册中文字体 | ||||
|          | ||||
|         Returns: | ||||
|             str: 注册的字体名称,如果失败则使用默认字体 | ||||
|         """ | ||||
|         try: | ||||
|             # 可能的字体文件位置列表 | ||||
|             font_locations = [ | ||||
|                 # 当前目录/子目录字体位置 | ||||
|                 os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf"), | ||||
|                 os.path.join(os.path.dirname(__file__), "fonts", "wqy-microhei.ttc"), | ||||
|                 os.path.join(os.path.dirname(__file__), "fonts", "wqy-zenhei.ttc"), | ||||
|                 # Docker中可能的位置 | ||||
|                 "/app/src/fundamentals_llm/fonts/simhei.ttf", | ||||
|                 # Windows系统字体位置 | ||||
|                 "C:/Windows/Fonts/simhei.ttf", | ||||
|                 "C:/Windows/Fonts/simfang.ttf", | ||||
|                 "C:/Windows/Fonts/simsun.ttc", | ||||
|                 # Linux/Mac系统字体位置 | ||||
|                 "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", | ||||
|                 "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", | ||||
|                 "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", | ||||
|                 "/usr/share/fonts/wqy-zenhei/wqy-zenhei.ttc", | ||||
|                 "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf", | ||||
|                 "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", | ||||
|                 "/System/Library/Fonts/PingFang.ttc"  # macOS | ||||
|             ] | ||||
|              | ||||
|             # 尝试各个位置 | ||||
|             for font_path in font_locations: | ||||
|                 if os.path.exists(font_path): | ||||
|                     logger.info(f"使用字体文件: {font_path}") | ||||
|                     # 为防止字体文件名称不同,统一拷贝到字体目录并重命名 | ||||
|                     font_dir = os.path.join(os.path.dirname(__file__), "fonts") | ||||
|                     os.makedirs(font_dir, exist_ok=True) | ||||
|                     target_path = os.path.join(font_dir, "simhei.ttf") | ||||
|                      | ||||
|                     # 如果字体文件不在目标位置,拷贝过去 | ||||
|                     if font_path != target_path and not os.path.exists(target_path): | ||||
|                         try: | ||||
|                             shutil.copy2(font_path, target_path) | ||||
|                             logger.info(f"已将字体文件复制到: {target_path}") | ||||
|                         except Exception as e: | ||||
|                             logger.warning(f"复制字体文件失败: {str(e)}") | ||||
|                      | ||||
|                     # 注册字体 | ||||
|                     pdfmetrics.registerFont(TTFont('SimHei', target_path)) | ||||
|                     return 'SimHei' | ||||
|              | ||||
|             # 如果所有位置都找不到,使用默认字体 | ||||
|             logger.warning("找不到中文字体文件,将使用默认字体") | ||||
|             return 'Helvetica' | ||||
|              | ||||
|         except Exception as e: | ||||
|             logger.error(f"注册中文字体失败: {str(e)}") | ||||
|             return 'Helvetica'  # 使用默认字体 | ||||
|      | ||||
|     def _convert_markdown_to_flowables(self, markdown_text: str) -> List: | ||||
|         """将Markdown文本转换为PDF流对象 | ||||
|          | ||||
|  | @ -191,18 +228,18 @@ class PDFGenerator: | |||
|         """ | ||||
|         try: | ||||
|             prompt = f"""请对以下内容进行优化和格式化。要求: | ||||
| 1. 保持原文的专业性和准确性 | ||||
| 2. 将零散的内容整合成连贯的段落,并对不重要的内容精简 | ||||
| 3. 使用适当的标点符号和换行 | ||||
| 4. 使用Markdown格式进行排版 | ||||
| 5. 移除所有引用内容(包括"参考资料:"等) | ||||
| 6. 不要返回其他多余的内容 | ||||
| 7. 确保内容结构清晰,层次分明 | ||||
| 8. 将零散的内容整合成完整的段落,避免过于零散的表述 | ||||
| 9. 使用自然流畅的语言,避免过于机械的结构化表达 | ||||
|                     1. 保持原文的专业性和准确性 | ||||
|                     2. 将零散的内容整合成连贯的段落,并对不重要的内容精简 | ||||
|                     3. 使用适当的标点符号和换行 | ||||
|                     4. 使用Markdown格式进行排版 | ||||
|                     5. 移除所有引用内容(包括"参考资料:"等) | ||||
|                     6. 不要返回其他多余的内容 | ||||
|                     7. 确保内容结构清晰,层次分明 | ||||
|                     8. 将零散的内容整合成完整的段落,避免过于零散的表述 | ||||
|                     9. 使用自然流畅的语言,避免过于机械的结构化表达 | ||||
|                      | ||||
| 原始内容: | ||||
| {content}""" | ||||
|                     原始内容: | ||||
|                     {content}""" | ||||
|              | ||||
|             result = self.chat_bot.chat(prompt) | ||||
|             return result | ||||
|  |  | |||
|  | @ -0,0 +1,190 @@ | |||
| #!/usr/bin/env python | ||||
| # -*- coding: utf-8 -*- | ||||
| 
 | ||||
| """ | ||||
| 中文字体安装工具 | ||||
| 
 | ||||
| 此脚本用于下载和安装中文字体,以支持PDF生成功能。 | ||||
| """ | ||||
| 
 | ||||
| import os | ||||
| import sys | ||||
| import logging | ||||
| import shutil | ||||
| import tempfile | ||||
| import platform | ||||
| 
 | ||||
# Configure logging once at import time so every message emitted by this
# script carries a timestamp, the logger name and the severity level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
| 
 | ||||
def download_wqy_font():
    """
    Download the WenQuanYi Micro Hei archive and unpack the font it contains.

    The archive is fetched into a freshly created temporary directory and the
    first ``.ttc``/``.ttf`` member is written next to it.  On success the
    temporary directory is left in place because the returned path lives
    inside it; on any failure the directory is removed again.

    Returns:
        str: Path of the extracted font file, or None if the download or the
        extraction failed (the error is logged, never raised).
    """
    temp_dir = None
    try:
        # Imported lazily so the rest of the script works without `requests`.
        import requests
        import tarfile

        temp_dir = tempfile.mkdtemp()
        tar_file = os.path.join(temp_dir, "wqy-microhei.tar.gz")

        # Download location of the WenQuanYi Micro Hei release archive.
        url = "https://mirrors.tuna.tsinghua.edu.cn/osdn/wqy/wqy-microhei-0.2.0-beta.tar.gz"

        logger.info(f"开始下载文泉驿微米黑字体: {url}")
        # A (connect, read) timeout keeps an unreachable mirror from hanging
        # the installer forever; the context manager releases the connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(tar_file, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        with tarfile.open(tar_file) as tar:
            for member in tar.getmembers():
                # Only regular files are considered, so links or device
                # entries inside the archive can never be materialised.
                if not member.isfile() or not member.name.endswith(('.ttc', '.ttf')):
                    continue
                source = tar.extractfile(member)
                if source is None:
                    continue
                # Write under the bare file name: the archive's directory
                # layout (and any "../" in it) is deliberately discarded.
                extracted_file = os.path.join(temp_dir, os.path.basename(member.name))
                with source, open(extracted_file, 'wb') as out:
                    shutil.copyfileobj(source, out)
                break
            else:
                extracted_file = None

        if extracted_file is not None:
            # The tarball is no longer needed once the font has been unpacked.
            os.remove(tar_file)
            return extracted_file

        logger.error("在压缩包中未找到字体文件")

    except Exception as e:
        logger.error(f"下载字体失败: {str(e)}")

    # Failure path: do not leave the temporary directory behind.
    if temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return None
|      | ||||
def find_system_fonts():
    """
    Look for a Chinese-capable font in well-known system locations.

    The candidate list depends on ``platform.system()``: Windows and macOS
    have their own lists, every other platform uses the Linux list.

    Returns:
        str: Path of the first candidate that exists, or None if none does.
    """
    candidates_by_platform = {
        "Windows": (
            "C:/Windows/Fonts/simhei.ttf",
            "C:/Windows/Fonts/simsun.ttc",
            "C:/Windows/Fonts/msyh.ttf",
            "C:/Windows/Fonts/simfang.ttf",
        ),
        "Darwin": (  # macOS
            "/System/Library/Fonts/PingFang.ttc",
            "/Library/Fonts/Microsoft/SimHei.ttf",
            "/Library/Fonts/Microsoft/SimSun.ttf",
        ),
    }
    # Anything that is neither Windows nor macOS is treated as Linux.
    linux_candidates = (
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",
        "/usr/share/fonts/wenquanyi/wqy-microhei/wqy-microhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf",
    )
    candidates = candidates_by_platform.get(platform.system(), linux_candidates)

    # Candidates are ordered by preference; take the first one on disk.
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        logger.warning("在系统中未找到任何中文字体")
        return None

    logger.info(f"在系统中找到中文字体: {found}")
    return found
| 
 | ||||
def _copy_font_atomically(source, destination):
    """Copy *source* to *destination* without ever exposing a partial file.

    The data is first written to a sibling ``.part`` file and then renamed
    over the destination, so an interrupted copy cannot leave a truncated
    font that a later run would mistake for a finished installation.
    """
    partial = destination + ".part"
    try:
        shutil.copy2(source, partial)
        os.replace(partial, destination)
    except Exception:
        # Best-effort cleanup of the staging file; the original error is
        # what the caller needs to see, so a failed cleanup is ignored.
        try:
            os.remove(partial)
        except OSError:
            pass
        raise


def install_font(target_dir=None):
    """
    Install a Chinese font as ``simhei.ttf`` inside *target_dir*.

    Resolution order: an already installed ``simhei.ttf`` is reused; otherwise
    a font found by ``find_system_fonts()`` is copied; otherwise the font is
    downloaded with ``download_wqy_font()`` and copied.

    Args:
        target_dir: Directory to install into.  Defaults to a ``fonts``
            directory next to this script.

    Returns:
        str: Path of the installed font file, or None if every strategy
        failed (failures are logged, never raised).
    """
    if target_dir is None:
        # Default location: "<directory of this script>/fonts".
        current_dir = os.path.dirname(os.path.abspath(__file__))
        target_dir = os.path.join(current_dir, "fonts")

    # Creating the directory can fail (read-only mount, permissions); honour
    # the "return None on failure" contract instead of raising.
    try:
        os.makedirs(target_dir, exist_ok=True)
    except OSError as e:
        logger.error(f"无法创建字体目录 {target_dir}: {str(e)}")
        logger.error("安装中文字体失败")
        return None

    target_font = os.path.join(target_dir, "simhei.ttf")

    # Nothing to do when the font has been installed before.
    if os.path.exists(target_font):
        logger.info(f"中文字体已存在: {target_font}")
        return target_font

    # First choice: reuse a font that is already present on the system.
    system_font = find_system_fonts()
    if system_font:
        try:
            _copy_font_atomically(system_font, target_font)
            logger.info(f"已从系统复制字体: {system_font} -> {target_font}")
            return target_font
        except Exception as e:
            # Fall through to the download strategy.
            logger.error(f"复制系统字体失败: {str(e)}")

    # Second choice: download a font.
    logger.info("尝试下载中文字体...")
    downloaded_font = download_wqy_font()
    if downloaded_font:
        try:
            _copy_font_atomically(downloaded_font, target_font)
            logger.info(f"已安装下载的字体: {downloaded_font} -> {target_font}")
            return target_font
        except Exception as e:
            logger.error(f"复制下载的字体失败: {str(e)}")

    logger.error("安装中文字体失败")
    return None
| 
 | ||||
def main():
    """
    Command-line entry point.

    An optional first argument selects the installation directory.  The
    outcome is reported on stdout; on failure the user is told where to place
    a font file manually.
    """
    print("中文字体安装工具")
    print("----------------")

    cli_args = sys.argv[1:]
    target_dir = cli_args[0] if cli_args else None
    if cli_args:
        print(f"使用指定的目标目录: {target_dir}")

    installed_path = install_font(target_dir)

    if installed_path:
        print(f"\n✓ 成功安装中文字体: {installed_path}")
        print("\n现在可以正常生成PDF报告了。")
        return

    # Installation failed: explain the manual fallback.
    print("\n✗ 安装中文字体失败")
    print("\n请手动将中文字体(.ttf或.ttc文件)复制到以下目录:")
    # An empty directory argument is treated like "not given" here.
    manual_dir = target_dir or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fonts')
    print(f"  {manual_dir}")
    print("\n并将其重命名为'simhei.ttf'")
| 
 | ||||
# Run the installer only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|  | @ -0,0 +1,172 @@ | |||
| import re | ||||
| import logging | ||||
| from typing import Optional, List, Dict, Any | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
class TextProcessor:
    """Helpers for post-processing text produced by a large language model.

    All methods are static and never raise: on an unexpected error they log
    it and return a fallback value (documented per method).
    """

    @staticmethod
    def clean_model_output(output: str) -> str:
        """Normalise model output into a single clean line of text.

        Bracketed citation markers (``[...]``), HTML tags and the markdown
        characters ``* _ ` ~`` are removed, then every run of whitespace
        (including newlines) is collapsed to one space and the result is
        trimmed.  Whitespace is normalised last so that the removals cannot
        leave doubled or leading/trailing spaces behind.

        Args:
            output: Raw model output.

        Returns:
            str: The cleaned text; the input is returned unchanged if
            cleaning fails.
        """
        try:
            # Citation markers such as "[1]" or "[source]".
            output = re.sub(r'\[[^\]]*\]', '', output)

            # HTML tags.
            output = re.sub(r'<[^>]+>', '', output)

            # Markdown emphasis / code / strike-through characters.
            output = re.sub(r'[*_`~]', '', output)

            # Collapse whitespace only after the removals above.
            output = re.sub(r'\s+', ' ', output)

            return output.strip()

        except Exception as e:
            logger.error(f"清理模型输出失败: {str(e)}")
            return output

    @staticmethod
    def clean_thought_process(output: str) -> str:
        """Drop the model's reasoning section and blank lines.

        Everything up to and including the first ``</think>`` tag is removed.
        The remaining lines are stripped and empty lines are discarded.

        Args:
            output: Raw model output.

        Returns:
            str: The remaining text, one non-empty line per line; an empty
            string if *output* is not a string.
        """
        try:
            think_end = output.find('</think>')
            if think_end != -1:
                # Keep only what follows the closing tag.
                output = output[think_end + len('</think>'):]

            stripped_lines = (line.strip() for line in output.split('\n'))
            return '\n'.join(line for line in stripped_lines if line).strip()

        except Exception as e:
            logger.error(f"清理模型输出失败: {str(e)}")
            # The failure is typically caused by a non-string input, so the
            # fallback must not call string methods on it unconditionally.
            return output.strip() if isinstance(output, str) else ""

    @staticmethod
    def extract_numeric_value_from_response(response: str) -> str:
        """Extract a number from a model response.

        A decimal number anywhere in the cleaned text takes precedence over
        an integer; within each kind the first occurrence wins.

        Args:
            response: Model response text.

        Returns:
            str: The matched number as text, the cleaned response when it
            contains no number, or the raw response if extraction fails.
        """
        try:
            cleaned_response = TextProcessor.clean_model_output(response)

            # Ordered by priority.  A separate "digits followed by a dot"
            # pattern would be unreachable after the integer pattern, so it
            # is not listed.
            numeric_patterns = (
                r'[-+]?\d*\.\d+',  # decimal number
                r'[-+]?\d+',       # integer
            )

            for pattern in numeric_patterns:
                match = re.search(pattern, cleaned_response)
                if match:
                    return match.group(0)

            # No number found: hand back the cleaned text.
            return cleaned_response

        except Exception as e:
            logger.error(f"提取数值失败: {str(e)}")
            return response

    @staticmethod
    def extract_list_from_response(response: str) -> List[str]:
        """Extract list items from a model response, one item per line.

        The response is examined line by line *before* whitespace is
        collapsed, so each list line yields its own item.  A line counts as a
        numbered item when it starts with ``<digits>.`` (not followed by a
        digit, so "3.5 ..." is not a list entry) and as a bullet item when it
        starts with ``-``, ``•`` or ``* ``.  Numbered items take precedence:
        bullets are returned only when no numbered item exists.

        Args:
            response: Model response text.

        Returns:
            List[str]: The cleaned items; empty if none were found or
            extraction failed.
        """
        try:
            numbered_items = []
            bullet_items = []

            for raw_line in response.splitlines():
                line = raw_line.strip()
                if not line:
                    continue

                # Bullets are matched on the raw line because cleaning would
                # strip a leading "*".  A "*" bullet needs a following space
                # so that "**bold**" is not mistaken for one.
                bullet = re.match(r'(?:[-•]\s*|\*\s+)(.+)', line)
                if bullet:
                    item = TextProcessor.clean_model_output(bullet.group(1))
                    # Ignore rules such as "---" that carry no text.
                    if re.search(r'\w', item):
                        bullet_items.append(item)
                    continue

                cleaned_line = TextProcessor.clean_model_output(line)
                numbered = re.match(r'\d+\.(?!\d)\s*(.+)', cleaned_line)
                if numbered:
                    numbered_items.append(numbered.group(1).strip())

            return numbered_items or bullet_items

        except Exception as e:
            logger.error(f"提取列表项失败: {str(e)}")
            return []

    @staticmethod
    def extract_key_value_pairs(response: str) -> Dict[str, str]:
        """Extract ``key: value`` or ``key=value`` pairs, one pair per line.

        Each line is cleaned and matched on its own, so a multi-line response
        yields one entry per line.  Colon-separated pairs take precedence:
        ``=``-separated pairs are returned only when no line contains a
        colon-separated pair.

        Args:
            response: Model response text.

        Returns:
            Dict[str, str]: The extracted pairs; empty if none were found or
            extraction failed.
        """
        try:
            colon_pairs = {}
            equals_pairs = {}

            for raw_line in response.splitlines():
                line = TextProcessor.clean_model_output(raw_line)
                if not line:
                    continue

                colon = re.match(r'([^:]+):\s*(.+)', line)
                if colon:
                    colon_pairs[colon.group(1).strip()] = colon.group(2).strip()
                    continue

                equals = re.match(r'([^=]+)=\s*(.+)', line)
                if equals:
                    equals_pairs[equals.group(1).strip()] = equals.group(2).strip()

            return colon_pairs or equals_pairs

        except Exception as e:
            logger.error(f"提取键值对失败: {str(e)}")
            return {}
|  | @ -66,12 +66,22 @@ MODEL_CONFIGS = { | |||
|             "doubao": "doubao-1-5-pro-32k-250115" | ||||
|         } | ||||
|     }, | ||||
|     # 谷歌Gemini | ||||
|     "Gemini": { | ||||
|         "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", | ||||
|         "api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU", | ||||
|         "models": { | ||||
|             "offline_model": "gemini-2.0-flash" | ||||
|         } | ||||
|     }, | ||||
|     # 天链苹果 | ||||
|     "tl_private": { | ||||
|         "base_url": "http://192.168.32.118:1234/v1/", | ||||
|         "base_url": "http://192.168.16.174:1234/v1/", | ||||
|         "api_key": "none", | ||||
|         "models": { | ||||
|             "ds-v1": "mlx-community/DeepSeek-R1-4bit" | ||||
|             "glm-z1": "glm-z1-rumination-32b-0414", | ||||
|             "glm-4": "glm-4-32b-0414-abliterated", | ||||
|             "ds_v1": "mlx-community/DeepSeek-R1-4bit", | ||||
|         } | ||||
|     }, | ||||
|     # 天链-千问 | ||||
|  | @ -79,7 +89,8 @@ MODEL_CONFIGS = { | |||
|         "base_url": "http://192.168.16.178:11434/v1", | ||||
|         "api_key": "sk-WaVRJKkyhrFlH4ZV35B9Aa61759b400c9cA002D00f3f1019", | ||||
|         "models": { | ||||
|             "qwq": "qwq:32b" | ||||
|             "qwq": "qwq:32b", | ||||
|             "GLM": "hf-mirror.com/Cobra4687/GLM-4-32B-0414-abliterated-Q4_K_M-GGUF:Q4_K_M" | ||||
|         } | ||||
|     }, | ||||
|     # Deepseek配置 | ||||
|  |  | |||
|  | @ -0,0 +1,986 @@ | |||
| import pandas as pd | ||||
| import numpy as np | ||||
| from sqlalchemy import create_engine | ||||
| from datetime import datetime, timedelta | ||||
| from tqdm import tqdm | ||||
| import matplotlib.pyplot as plt | ||||
| import os | ||||
| 
 | ||||
# v2: this version only produces tables.

# Chinese font support for matplotlib.
plt.rcParams['font.sans-serif'] = ['SimHei']  # lets Chinese labels render correctly
plt.rcParams['axes.unicode_minus'] = False  # lets the minus sign render correctly
| 
 | ||||
| class StockAnalyzer: | ||||
|     def __init__(self, db_connection_string): | ||||
|         """ | ||||
|         Initialize the stock analyzer | ||||
|         :param db_connection_string: Database connection string | ||||
|         """ | ||||
|         self.engine = create_engine(db_connection_string) | ||||
|         self.initial_capital = 300000  # 30万本金 | ||||
|         self.total_position = 1000000  # 100万建仓规模 | ||||
|         self.borrowed_capital = 700000  # 70万借入资金 | ||||
|         self.annual_interest_rate = 0.05  # 年化5%利息 | ||||
|         self.daily_interest_rate = self.annual_interest_rate / 365 | ||||
|         self.commission_rate = 0.05  # 5%手续费 | ||||
|         self.min_holding_days = 7  # 最少持仓天数 | ||||
| 
 | ||||
|     def get_stock_list(self): | ||||
|         """获取所有股票列表""" | ||||
|         # query = "SELECT DISTINCT gp_code as symbol FROM gp_code_all where mark1 = '1'" | ||||
|         # return pd.read_sql(query, self.engine)['symbol'].tolist() | ||||
|         return ['SH600522', | ||||
|                 'SZ002340', | ||||
|                 'SZ000733', | ||||
|                 'SH601615', | ||||
|                 'SH600157', | ||||
|                 'SH688005', | ||||
|                 'SH600903', | ||||
|                 'SH600956', | ||||
|                 'SH601187', | ||||
|                 'SH603983']  # 你可以在这里添加更多股票代码 | ||||
| 
 | ||||
|     def get_stock_data(self, symbol, start_date, end_date, include_history=False): | ||||
|         """获取指定股票在时间范围内的数据""" | ||||
|         try: | ||||
|             if include_history: | ||||
|                 # 如果需要历史数据,获取start_date之前240天的数据 | ||||
|                 query = f""" | ||||
|                 WITH history_data AS ( | ||||
|                     SELECT  | ||||
|                         symbol, | ||||
|                         timestamp, | ||||
|                         CAST(close as DECIMAL(10,2)) as close | ||||
|                     FROM gp_day_data  | ||||
|                     WHERE symbol = '{symbol}'  | ||||
|                     AND timestamp < '{start_date}' | ||||
|                     ORDER BY timestamp DESC | ||||
|                     LIMIT 240 | ||||
|                 ) | ||||
|                 SELECT  | ||||
|                     symbol, | ||||
|                     timestamp, | ||||
|                     CAST(open as DECIMAL(10,2)) as open, | ||||
|                     CAST(high as DECIMAL(10,2)) as high, | ||||
|                     CAST(low as DECIMAL(10,2)) as low, | ||||
|                     CAST(close as DECIMAL(10,2)) as close, | ||||
|                     CAST(volume as DECIMAL(20,0)) as volume | ||||
|                 FROM gp_day_data  | ||||
|                 WHERE symbol = '{symbol}'  | ||||
|                 AND timestamp BETWEEN '{start_date}' AND '{end_date}' | ||||
|                 UNION ALL | ||||
|                 SELECT  | ||||
|                     symbol, | ||||
|                     timestamp, | ||||
|                     NULL as open, | ||||
|                     NULL as high, | ||||
|                     NULL as low, | ||||
|                     close, | ||||
|                     NULL as volume | ||||
|                 FROM history_data | ||||
|                 ORDER BY timestamp | ||||
|                 """ | ||||
|             else: | ||||
|                 query = f""" | ||||
|                 SELECT  | ||||
|                     symbol, | ||||
|                     timestamp, | ||||
|                     CAST(open as DECIMAL(10,2)) as open, | ||||
|                     CAST(high as DECIMAL(10,2)) as high, | ||||
|                     CAST(low as DECIMAL(10,2)) as low, | ||||
|                     CAST(close as DECIMAL(10,2)) as close, | ||||
|                     CAST(volume as DECIMAL(20,0)) as volume | ||||
|                 FROM gp_day_data  | ||||
|                 WHERE symbol = '{symbol}'  | ||||
|                 AND timestamp BETWEEN '{start_date}' AND '{end_date}' | ||||
|                 ORDER BY timestamp | ||||
|                 """ | ||||
|              | ||||
|             # 使用chunksize分批读取数据 | ||||
|             chunks = [] | ||||
|             for chunk in pd.read_sql(query, self.engine, chunksize=1000): | ||||
|                 chunks.append(chunk) | ||||
|             df = pd.concat(chunks, ignore_index=True) | ||||
|              | ||||
|             # 确保数值列的类型正确 | ||||
|             numeric_columns = ['open', 'high', 'low', 'close', 'volume'] | ||||
|             for col in numeric_columns: | ||||
|                 if col in df.columns: | ||||
|                     df[col] = pd.to_numeric(df[col], errors='coerce') | ||||
|              | ||||
|             # 删除任何包含 NaN 的行,但只考虑分析期间需要的列 | ||||
|             if include_history: | ||||
|                 historical_mask = (df['timestamp'] < start_date) & df['close'].notna() | ||||
|                 analysis_mask = (df['timestamp'] >= start_date) & df[['open', 'high', 'low', 'close', 'volume']].notna().all(axis=1) | ||||
|                 df = df[historical_mask | analysis_mask] | ||||
|             else: | ||||
|                 df = df.dropna() | ||||
|              | ||||
|             return df | ||||
|              | ||||
|         except Exception as e: | ||||
|             print(f"获取股票 {symbol} 数据时出错: {str(e)}") | ||||
|             return pd.DataFrame() | ||||
| 
 | ||||
    def calculate_trade_signals(self, df, take_profit_pct, stop_loss_pct, pullback_entry_pct, start_date):
        """Simulate the trading strategy day by day and return the result.

        Rows of *df* before *start_date* are history, rows from *start_date*
        on form the analysis window.  The baseline is the mean of the last
        240 historical closes.  For every analysis day the reference price is
        the midpoint of high and low, and:

        * each open position held for at least ``self.min_holding_days`` is
          closed when the day's high reaches the take-profit level, else when
          its loss reaches the fixed maximum loss, else when it has been held
          for the maximum holding period;
        * a new position is opened (up to five at once) when the day's low is
          at least 5% below the previous close and the reference price is
          within 70%-130% of the baseline.

        Positions still open after the last day are closed at that day's
        reference price.

        NOTE(review): ``stop_loss_pct`` and ``pullback_entry_pct`` are never
        read in this method; the loss limit and the entry threshold are the
        hard-coded values below.  Confirm whether that is intended.

        :param df: DataFrame with ``timestamp``, ``high``, ``low``, ``close``
        :param take_profit_pct: Take-profit threshold as a fraction of the
            entry price
        :param stop_loss_pct: Unused (see note above)
        :param pullback_entry_pct: Unused (see note above)
        :param start_date: First day of the analysis window
        :return: DataFrame with one row per analysis day (``timestamp``,
            cumulative ``profit``, ``position`` flag, ``action``); the last
            row additionally carries a ``stats`` dict.  An empty DataFrame is
            returned when there are fewer than 240 history rows or fewer
            than 10 analysis rows.
        """
        # NOTE(review): the next line is a duplicated docstring left in the
        # code; it is a no-op string expression.
        """计算交易信号和收益"""
        # Require enough history for the 240-day baseline.
        historical_data = df[df['timestamp'] < start_date]
        if len(historical_data) < 240:
            return pd.DataFrame()

        # Baseline price: mean close of the last 240 historical rows.
        ma240_base_price = historical_data['close'].tail(240).mean()

        # Signals are computed on the analysis window only.
        analysis_df = df[df['timestamp'] >= start_date].copy()
        if len(analysis_df) < 10:  # too few analysis days to be meaningful
            return pd.DataFrame()

        results = []
        positions = []  # open positions; each is a dict with entry price and entry date
        max_positions = 5  # maximum number of simultaneously open positions
        cumulative_profit = 0  # running total of realised net profit
        trades_count = 0  # number of entries; only incremented, never read below
        max_holding_period = timedelta(days=180)  # longest holding period (about 6 months)
        max_loss_amount = -300000  # a position is closed once its loss reaches 300,000
        
        # Statistics accumulated over all closed positions.
        total_holding_days = 0  # sum of holding days
        max_capital_usage = 0  # peak capital tied up in open positions
        total_trades = 0  # number of closed positions
        profitable_trades = 0  # closed positions with positive net profit
        loss_trades = 0  # closed positions with zero or negative net profit

        # Close of the previous analysis day; None on the first day, which
        # therefore can never trigger an entry.
        prev_close = None

        for index, row in analysis_df.iterrows():
            # Reference price of the day: midpoint of high and low.
            current_price = (float(row['high']) + float(row['low'])) / 2
            day_result = {
                'timestamp': row['timestamp'],
                'profit': cumulative_profit,  # cumulative profit so far
                'position': False,
                'action': 'hold'
            }

            # Evaluate every open position; closures are collected first and
            # removed after the loop so the list is not mutated while iterating.
            positions_to_remove = []

            for pos in positions:
                pos_days_held = (row['timestamp'] - pos['entry_date']).days
                pos_size = self.total_position / pos['entry_price']
                
                # Decide whether this position is closed today.
                should_close = False
                close_type = None
                close_price = current_price

                # No exit rule applies before the minimum holding period.
                if pos_days_held >= self.min_holding_days:
                    # Take profit: filled exactly at the take-profit level.
                    if float(row['high']) >= pos['entry_price'] * (1 + take_profit_pct):
                        should_close = True
                        close_type = 'sell_profit'
                        close_price = pos['entry_price'] * (1 + take_profit_pct)
                    # Maximum loss reached at the reference price.
                    elif (close_price - pos['entry_price']) * pos_size <= max_loss_amount:
                        should_close = True
                        close_type = 'sell_loss'
                    # Maximum holding period reached.
                    elif (row['timestamp'] - pos['entry_date']) >= max_holding_period:
                        should_close = True
                        close_type = 'time_limit'

                if should_close:
                    profit = (close_price - pos['entry_price']) * pos_size
                    # Commission is charged on gains only.
                    commission = profit * self.commission_rate if profit > 0 else 0
                    # Interest accrues on the borrowed capital per day held.
                    interest_cost = self.borrowed_capital * self.daily_interest_rate * pos_days_held
                    net_profit = profit - commission - interest_cost
                    
                    # Update statistics.
                    total_holding_days += pos_days_held
                    total_trades += 1
                    if net_profit > 0:
                        profitable_trades += 1
                    else:
                        loss_trades += 1
                    
                    cumulative_profit += net_profit
                    positions_to_remove.append(pos)
                    # If several positions close on one day, the action of
                    # the last one processed is recorded.
                    day_result['action'] = close_type
                    day_result['profit'] = cumulative_profit

            # Remove the positions closed today.
            for pos in positions_to_remove:
                positions.remove(pos)

            # Entry check.
            if len(positions) < max_positions and prev_close is not None:
                # Intraday drop of the low relative to the previous close.
                daily_drop = (prev_close - float(row['low'])) / prev_close
                
                # Reference price must lie within 70%-130% of the baseline.
                price_in_range = (
                    current_price >= ma240_base_price * 0.7 and 
                    current_price <= ma240_base_price * 1.3
                )
                
                # Enter when the drop is at least 5% and the price is in range.
                if daily_drop >= 0.05 and price_in_range:
                    positions.append({
                        'entry_price': current_price,
                        'entry_date': row['timestamp']
                    })
                    trades_count += 1

            # Remember today's close for tomorrow's drop calculation.
            prev_close = float(row['close'])

            # Record the day.
            day_result['position'] = len(positions) > 0
            results.append(day_result)

            # Capital in use: one full position size per open position.
            current_capital_usage = sum(self.total_position for pos in positions)
            max_capital_usage = max(max_capital_usage, current_capital_usage)

        # Force-close everything still open at the end of the window, at the
        # last day's reference price.
        if positions:
            final_price = (float(analysis_df.iloc[-1]['high']) + float(analysis_df.iloc[-1]['low'])) / 2
            final_total_profit = 0
            
            for pos in positions:
                position_size = self.total_position / pos['entry_price']
                days_held = (analysis_df.iloc[-1]['timestamp'] - pos['entry_date']).days
                final_profit = (final_price - pos['entry_price']) * position_size
                commission = final_profit * self.commission_rate if final_profit > 0 else 0
                interest_cost = self.borrowed_capital * self.daily_interest_rate * days_held
                net_profit = final_profit - commission - interest_cost
                
                # Update statistics.
                total_holding_days += days_held
                total_trades += 1
                if net_profit > 0:
                    profitable_trades += 1
                else:
                    loss_trades += 1
                
                final_total_profit += net_profit
            
            results[-1]['action'] = 'final_sell'
            cumulative_profit += final_total_profit  # include the forced closures
            results[-1]['profit'] = cumulative_profit  # last day shows the final total

        # Summary statistics (guarded against division by zero).
        avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0
        win_rate = profitable_trades / total_trades * 100 if total_trades > 0 else 0

        # Attach the statistics to the last daily record only.
        if len(results) > 0:
            results[-1]['stats'] = {
                'total_trades': total_trades,
                'profitable_trades': profitable_trades,
                'loss_trades': loss_trades,
                'win_rate': win_rate,
                'avg_holding_days': avg_holding_days,
                'max_capital_usage': max_capital_usage,
                'final_profit': cumulative_profit
            }

        return pd.DataFrame(results)
| 
 | ||||
def analyze_stock(self, symbol, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
    """Run the trade-signal calculation for a single stock.

    The price data for ``symbol`` is loaded through ``get_stock_data``
    (called with ``include_history=True``) and passed on to
    ``calculate_trade_signals`` together with the strategy parameters.

    Returns:
        The value returned by ``calculate_trade_signals``, or ``None`` when
        fewer than 240 rows were loaded.
    """
    frame = self.get_stock_data(symbol, start_date, end_date, include_history=True)

    # Guard clause: the whole data set must contain at least 240 rows.
    has_enough_rows = len(frame) >= 240
    if not has_enough_rows:
        print(f"股票 {symbol} 总数据不足240天,跳过分析")
        return None

    signal_options = {
        'take_profit_pct': take_profit_pct,
        'stop_loss_pct': stop_loss_pct,
        'pullback_entry_pct': pullback_entry_pct,
        'start_date': start_date,
    }
    return self.calculate_trade_signals(frame, **signal_options)
| 
 | ||||
def analyze_all_stocks(self, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
    """Back-test every stock returned by ``get_stock_list``.

    For each symbol the price data is loaded, checked for sufficient history
    (at least 240 rows before ``start_date``) and sufficient analysis data
    (at least 10 rows from ``start_date`` on), and then passed to
    ``calculate_trade_signals``. A stock is kept only when its result
    contains at least one closing action.

    Every stock ends up in exactly one summary bucket (processed, skipped or
    error); non-empty buckets are written as CSV files under
    ``results/analysis_summary``.

    Args:
        start_date: first day of the analysis period (compared against the
            'timestamp' column).
        end_date: last day passed to ``get_stock_data``.
        take_profit_pct: take-profit ratio; also used in the CSV file names.
        stop_loss_pct: stop-loss ratio forwarded to the signal calculation.
        pullback_entry_pct: pull-back entry ratio forwarded to the signal
            calculation.

    Returns:
        dict mapping symbol -> result DataFrame for stocks that traded.
    """
    stocks = self.get_stock_list()
    all_results = {}
    analysis_summary = {'processed': [], 'skipped': [], 'error': []}
    # Actions that mean a position was actually closed.
    closing_actions = ('sell_profit', 'sell_loss', 'time_limit', 'final_sell')

    def record_skip(symbol, reason):
        """Append one entry to the 'skipped' bucket."""
        analysis_summary['skipped'].append({
            'symbol': symbol,
            'reason': reason,
            'take_profit_pct': take_profit_pct,
        })

    print(f"\n开始分析,共 {len(stocks)} 支股票")

    for symbol in tqdm(stocks, desc="Analyzing stocks"):
        # One failing stock must not abort the whole run, hence the broad
        # per-symbol handler; the failure is reported and recorded.
        try:
            df = self.get_stock_data(symbol, start_date, end_date, include_history=True)

            if df.empty:
                print(f"\n股票 {symbol} 数据为空,跳过分析")
                record_skip(symbol, 'empty_data')
                continue

            # Rows strictly before the analysis window.
            history_days = len(df[df['timestamp'] < start_date])
            if history_days < 240:
                record_skip(symbol, f'insufficient_history_{history_days}days')
                continue

            # Rows inside the analysis window.
            analysis_days = len(df[df['timestamp'] >= start_date])
            if analysis_days < 10:
                print(f"\n股票 {symbol} 分析期数据不足10天({analysis_days}天),跳过分析")
                record_skip(symbol, f'insufficient_analysis_data_{analysis_days}days')
                continue

            result = self.calculate_trade_signals(
                df,
                take_profit_pct,
                stop_loss_pct,
                pullback_entry_pct,
                start_date
            )

            if result is None or result.empty:
                # Previously such stocks vanished from the summary entirely.
                record_skip(symbol, 'empty_result')
                continue

            actions = [action for action in result['action'] if pd.notna(action)]
            if not any(action in closing_actions for action in actions):
                record_skip(symbol, 'no_valid_signals')
                continue

            all_results[symbol] = result
            analysis_summary['processed'].append({
                'symbol': symbol,
                'take_profit_pct': take_profit_pct,
                'trades': sum(1 for action in actions if action != 'hold'),
                'profit': result['profit'].iloc[-1]
            })

        except Exception as e:
            print(f"\n处理股票 {symbol} 时出错: {str(e)}")
            analysis_summary['error'].append({
                'symbol': symbol,
                'error': str(e),
                'take_profit_pct': take_profit_pct
            })

    # Print the run summary.
    print("\n分析完成:")
    print(f"成功分析: {len(analysis_summary['processed'])} 支股票")
    print(f"跳过分析: {len(analysis_summary['skipped'])} 支股票")
    print(f"出错股票: {len(analysis_summary['error'])} 支股票")

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('results/analysis_summary', exist_ok=True)

    # Write one CSV per non-empty bucket.
    for bucket, stem in (
        ('processed', 'processed_stocks'),
        ('skipped', 'skipped_stocks'),
        ('error', 'error_stocks'),
    ):
        rows = analysis_summary[bucket]
        if rows:
            pd.DataFrame(rows).to_csv(
                f'results/analysis_summary/{stem}_{take_profit_pct*100:.1f}.csv',
                index=False
            )

    return all_results
| 
 | ||||
def plot_results(self, results, symbol):
    """Plot the profit curve of a single stock.

    Args:
        results: DataFrame with 'timestamp' and 'profit' columns.
        symbol: stock code shown in the chart title.

    Returns:
        The ``plt`` module, with the new figure as the current figure.
    """
    plt.figure(figsize=(12, 6))
    # 'profit' already stores the running cumulative profit (the signal
    # calculation writes the cumulative value into it, and plot_all_stocks
    # plots it unchanged), so applying cumsum() here would accumulate twice.
    plt.plot(results['timestamp'], results['profit'], label='Cumulative Profit')
    plt.title(f'Stock {symbol} Trading Results')
    plt.xlabel('Date')
    plt.ylabel('Profit (CNY)')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    return plt
| 
 | ||||
def plot_all_stocks(self, all_results):
    """Plot every stock's profit curve together with the cross-stock median.

    The figure also carries three text boxes: the overall profit range, the
    aggregated trade statistics, and the five best / five worst stocks.

    Args:
        all_results: dict mapping symbol -> result DataFrame with
            'timestamp' and 'profit' columns. When the last row has a
            'stats' entry it is used for the aggregated statistics.

    Returns:
        ``(plt, total_profits)`` where ``total_profits`` maps each symbol to
        the last value of its 'profit' column.
    """
    plt.figure(figsize=(15, 8))

    total_profits = {}
    max_profit = float('-inf')
    min_profit = float('inf')

    # One thin line per stock; 'profit' is used as is (already cumulative).
    for symbol, results in all_results.items():
        profits = results['profit']
        plt.plot(results['timestamp'], profits, label=symbol, alpha=0.5, linewidth=1)

        max_profit = max(max_profit, profits.max())
        min_profit = min(min_profit, profits.min())
        total_profits[symbol] = profits.iloc[-1]

    # Union of all dates that appear in any result.
    all_dates = set()
    for results in all_results.values():
        all_dates.update(results['timestamp'])
    all_dates = sorted(all_dates)

    # Align every stock onto the common date axis.
    aligned = {}
    for symbol, results in all_results.items():
        daily_profits = pd.Series(index=all_dates, data=np.nan)
        for date, profit in zip(results['timestamp'], results['profit']):
            daily_profits[date] = profit
        # Carry the last known value forward, then use 0 for the dates
        # before the stock's first row. Series.ffill() replaces the
        # deprecated fillna(method='ffill').
        aligned[symbol] = daily_profits.ffill().fillna(0)
    # Building the frame in one step avoids column-by-column fragmentation.
    profits_df = pd.DataFrame(aligned, index=all_dates)

    median_line = profits_df.median(axis=1)
    plt.plot(all_dates, median_line, 'r-', label='Median', linewidth=2)

    plt.title('All Stocks Trading Results')
    plt.xlabel('Date')
    plt.ylabel('Profit (CNY)')
    plt.grid(True)
    plt.xticks(rotation=45)

    # Profit-range box.
    plt.text(0.02, 0.98,
            f'Profit Range:\nMax: {max_profit:,.0f}\nMin: {min_profit:,.0f}\nMedian: {median_line.iloc[-1]:,.0f}',
            transform=plt.gca().transAxes,
            bbox=dict(facecolor='white', alpha=0.8))

    # Aggregate the per-stock statistics stored in each result's last row.
    total_trades = 0
    profitable_trades = 0
    total_holding_days = 0
    total_capital_usage = 0
    total_profit = 0
    win_rate_sum = 0
    profitable_stocks = 0
    loss_stocks = 0

    for results in all_results.values():
        last_row = results.iloc[-1]
        if 'stats' not in last_row:
            continue
        stats = last_row['stats']

        total_trades += stats['total_trades']
        profitable_trades += stats['profitable_trades']
        # avg * count restores the stock's total holding days.
        total_holding_days += stats['avg_holding_days'] * stats['total_trades']
        total_capital_usage += stats['max_capital_usage']
        total_profit += stats['final_profit']
        win_rate_sum += stats['win_rate']
        if stats['final_profit'] > 0:
            profitable_stocks += 1
        else:
            loss_stocks += 1

    # Overall figures (guarded against division by zero).
    total_stocks = profitable_stocks + loss_stocks
    avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0
    win_rate = profitable_trades / total_trades * 100 if total_trades > 0 else 0
    stock_win_rate = profitable_stocks / total_stocks * 100 if total_stocks > 0 else 0

    # Per-stock averages.
    avg_trades_per_stock = total_trades / total_stocks if total_stocks > 0 else 0
    avg_profit_per_stock = total_profit / total_stocks if total_stocks > 0 else 0
    avg_capital_usage = total_capital_usage / total_stocks if total_stocks > 0 else 0
    avg_win_rate = win_rate_sum / total_stocks if total_stocks > 0 else 0

    stats_text = (
        f"总体统计:\n"
        f"总交易次数: {total_trades}\n"
        f"平均持仓天数: {avg_holding_days:.1f}\n"
        f"交易胜率: {win_rate:.1f}%\n"
        f"股票胜率: {stock_win_rate:.1f}%\n"
        f"盈利股票数: {profitable_stocks}\n"
        f"亏损股票数: {loss_stocks}\n"
        f"\n平均值统计:\n"
        f"平均交易次数: {avg_trades_per_stock:.1f}\n"
        f"平均收益: {avg_profit_per_stock:,.0f}\n"
        f"平均资金占用: {avg_capital_usage:,.0f}\n"
        f"平均持仓天数: {avg_holding_days:.1f}\n"
        f"平均胜率: {avg_win_rate:.1f}%"
    )

    plt.text(0.02, 0.6, stats_text,
            transform=plt.gca().transAxes,
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
            verticalalignment='top',
            fontsize=10)

    # Ranking box, drawn once (it used to be computed and drawn twice at the
    # same position).
    top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5]
    bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5]

    rank_text = (
        "Top 5 Stocks:\n"
        + "".join(f"{symbol}: {profit:,.0f}\n" for symbol, profit in top_5)
        + "\nBottom 5 Stocks:\n"
        + "".join(f"{symbol}: {profit:,.0f}\n" for symbol, profit in bottom_5)
    )

    plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes,
            bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top')

    plt.tight_layout()
    return plt, total_profits
| 
 | ||||
def plot_profit_distribution(self, total_profits):
    """Draw a violin plot of the final profit of all stocks.

    Args:
        total_profits: dict mapping symbol -> final cumulative profit.

    Returns:
        The ``plt`` module, with the new figure as the current figure.
    """
    plt.figure(figsize=(12, 8))

    # Final cumulative profit of every stock as a flat array.
    profits = np.array(list(total_profits.values()))

    # Descriptive statistics shown in the upper text box.
    summary = {
        'mean': np.mean(profits),
        'median': np.median(profits),
        'max': np.max(profits),
        'min': np.min(profits),
        'std': np.std(profits),
    }

    # Violin plot with mean and median markers.
    violin = plt.violinplot(profits, positions=[0], showmeans=True, showmedians=True)

    body = violin['bodies'][0]
    body.set_facecolor('lightblue')
    body.set_alpha(0.7)
    violin['cmeans'].set_color('red')
    violin['cmedians'].set_color('blue')

    box_style = dict(facecolor='white', alpha=0.8, edgecolor='gray')

    stats_text = (
        f"统计信息:\n"
        f"平均收益: {summary['mean']:,.0f}\n"
        f"中位数收益: {summary['median']:,.0f}\n"
        f"最大收益: {summary['max']:,.0f}\n"
        f"最小收益: {summary['min']:,.0f}\n"
        f"标准差: {summary['std']:,.0f}"
    )
    plt.text(0.65, 0.95, stats_text,
            transform=plt.gca().transAxes,
            bbox=box_style,
            verticalalignment='top',
            fontsize=10)

    # Titles, axis labels and grid.
    plt.title('股票收益分布图', fontsize=14, pad=20)
    plt.ylabel('收益 (元)', fontsize=12)
    plt.xticks([0], ['所有股票'], fontsize=10)
    plt.grid(True, axis='y', alpha=0.3)

    # Dashed reference line at zero profit.
    plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)

    # Count winners and losers; stocks at exactly zero are in neither group.
    winners = np.sum(profits > 0)
    losers = np.sum(profits < 0)
    n_stocks = len(profits)

    ratio_text = (
        f"盈亏比例:\n"
        f"盈利: {winners}支 ({winners/n_stocks*100:.1f}%)\n"
        f"亏损: {losers}支 ({losers/n_stocks*100:.1f}%)\n"
        f"总计: {n_stocks}支"
    )
    plt.text(0.65, 0.6, ratio_text,
            transform=plt.gca().transAxes,
            bbox=box_style,
            verticalalignment='top',
            fontsize=10)

    plt.tight_layout()
    return plt
| 
 | ||||
def plot_profit_matrix(self, df):
    """Plot the take-profit analysis matrix as a 2x2 grid of charts.

    Args:
        df: DataFrame indexed by take-profit percentage with the columns
            'total_profit', 'avg_profit_per_trade',
            'median_profit_per_trade', 'avg_holding_days',
            'trade_win_rate' and 'stock_win_rate'.

    Returns:
        The ``plt`` module, with the 2x2 figure as the current figure.
    """
    # plt.subplots() creates its own figure. The former extra plt.figure()
    # call left a second, empty figure open that was never closed.
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Total profit per take-profit percentage.
    ax1.plot(df.index, df['total_profit'], 'b-', linewidth=2)
    ax1.set_title('总收益 vs 止盈比例')
    ax1.set_xlabel('止盈比例 (%)')
    ax1.set_ylabel('总收益 (元)')
    ax1.grid(True)

    # Mark the take-profit percentage with the highest total profit.
    best_profit_pct = df['total_profit'].idxmax()
    best_profit = df['total_profit'].max()
    ax1.plot(best_profit_pct, best_profit, 'ro')
    ax1.annotate(f'最优: {best_profit_pct:.1f}%\n{best_profit:,.0f}元',
                (best_profit_pct, best_profit),
                xytext=(10, 10), textcoords='offset points')

    # 2. Average and median profit per trade.
    ax2.plot(df.index, df['avg_profit_per_trade'], 'g-', linewidth=2, label='平均收益')
    ax2.plot(df.index, df['median_profit_per_trade'], 'r--', linewidth=2, label='中位数收益')
    ax2.set_title('每笔交易收益 vs 止盈比例')
    ax2.set_xlabel('止盈比例 (%)')
    ax2.set_ylabel('收益 (元)')
    ax2.legend()
    ax2.grid(True)

    # 3. Average holding days.
    ax3.plot(df.index, df['avg_holding_days'], 'r-', linewidth=2)
    ax3.set_title('平均持仓天数 vs 止盈比例')
    ax3.set_xlabel('止盈比例 (%)')
    ax3.set_ylabel('持仓天数')
    ax3.grid(True)

    # 4. Trade-level and stock-level win rates.
    ax4.plot(df.index, df['trade_win_rate'], 'c-', linewidth=2, label='交易胜率')
    ax4.plot(df.index, df['stock_win_rate'], 'm--', linewidth=2, label='股票胜率')
    ax4.set_title('胜率 vs 止盈比例')
    ax4.set_xlabel('止盈比例 (%)')
    ax4.set_ylabel('胜率 (%)')
    ax4.legend()
    ax4.grid(True)

    plt.tight_layout()
    return plt
| 
 | ||||
def analyze_profit_matrix(self, start_date, end_date, stop_loss_pct, pullback_entry_pct):
    """Compare strategy performance across take-profit ratios from 30% to 5%.

    For every ratio (in 1% steps) ``analyze_all_stocks`` is run and the
    aggregated statistics are written to
    ``results/temp/profit_<pct>.csv``. A ratio whose file already exists is
    skipped, so an interrupted run can be resumed. Afterwards all per-ratio
    files are merged, saved as ``results/profit_matrix_analysis.csv``,
    printed, and plotted to ``results/profit_matrix_analysis.png``.

    Args:
        start_date: first day of the analysis period.
        end_date: last day of the analysis period.
        stop_loss_pct: stop-loss ratio forwarded to ``analyze_all_stocks``.
        pullback_entry_pct: pull-back entry ratio forwarded to
            ``analyze_all_stocks``.

    Returns:
        DataFrame indexed by take-profit percentage (descending), or an
        empty DataFrame when no per-ratio result is available.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('results/temp', exist_ok=True)

    # From 30% down to 5% in 1% steps.
    profit_percentages = list(np.arange(0.30, 0.04, -0.01))

    for take_profit_pct in profit_percentages:
        # Resume support: skip ratios that were analysed before.
        temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv'
        if os.path.exists(temp_file):
            print(f"\n止盈比例 {take_profit_pct*100:.1f}% 已分析,跳过")
            continue

        print(f"\n分析止盈比例: {take_profit_pct*100:.1f}%")

        # A failure for one ratio is reported and must not abort the sweep.
        try:
            results = self.analyze_all_stocks(
                start_date,
                end_date,
                take_profit_pct,
                stop_loss_pct,
                pullback_entry_pct
            )

            # Drop missing or empty results.
            valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty}
            if not valid_results:
                continue

            stock_count = len(valid_results)
            total_trades = 0
            profitable_trades = 0
            total_holding_days = 0
            total_profit = 0
            profitable_stocks = 0
            all_trade_profits = []

            # Each result keeps its statistics dict in the last row.
            for result in valid_results.values():
                last_row = result.iloc[-1]
                if 'stats' not in last_row:
                    continue
                stats = last_row['stats']

                total_trades += stats['total_trades']
                profitable_trades += stats['profitable_trades']
                # avg * count restores the stock's total holding days.
                total_holding_days += stats['avg_holding_days'] * stats['total_trades']
                total_profit += stats['final_profit']

                if stats['final_profit'] > 0:
                    profitable_stocks += 1

                # stats is a dict, so key membership is the right test;
                # hasattr() on a dict never sees its keys and was always
                # False, which forced the median below to 0.
                if 'trade_profits' in stats:
                    all_trade_profits.extend(stats['trade_profits'])

            avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0
            avg_profit_per_trade = total_profit / total_trades if total_trades > 0 else 0
            median_profit_per_trade = np.median(all_trade_profits) if all_trade_profits else 0
            trade_win_rate = (profitable_trades / total_trades * 100) if total_trades > 0 else 0
            stock_win_rate = profitable_stocks / stock_count * 100

            # One-row frame for this take-profit ratio.
            result_df = pd.DataFrame([{
                'take_profit_pct': take_profit_pct * 100,
                'total_profit': total_profit,
                'avg_profit_per_trade': avg_profit_per_trade,
                'median_profit_per_trade': median_profit_per_trade,
                'avg_holding_days': avg_holding_days,
                'trade_win_rate': trade_win_rate,
                'stock_win_rate': stock_win_rate,
                'total_trades': total_trades,
                'stock_count': stock_count
            }])

            result_df.to_csv(temp_file, index=False)
            print(f"保存止盈点 {take_profit_pct*100:.1f}% 的分析结果")

        except Exception as e:
            print(f"处理止盈点 {take_profit_pct*100:.1f}% 时出错: {str(e)}")
            continue

    # Merge the per-ratio files.
    print("\n合并所有分析结果...")
    all_results = []
    for take_profit_pct in profit_percentages:
        temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv'
        if os.path.exists(temp_file):
            try:
                all_results.append(pd.read_csv(temp_file))
            except Exception as e:
                print(f"读取文件 {temp_file} 时出错: {str(e)}")

    if not all_results:
        print("没有找到任何分析结果")
        return pd.DataFrame()

    df = pd.concat(all_results, ignore_index=True)
    df = df.round({
        'take_profit_pct': 1,
        'total_profit': 0,
        'avg_profit_per_trade': 0,
        'median_profit_per_trade': 0,
        'avg_holding_days': 1,
        'trade_win_rate': 1,
        'stock_win_rate': 1,
        'total_trades': 0,
        'stock_count': 0
    })

    # Index by take-profit percentage, highest first.
    df.set_index('take_profit_pct', inplace=True)
    df.sort_index(ascending=False, inplace=True)

    df.to_csv('results/profit_matrix_analysis.csv')

    print("\n止盈比例分析矩阵:")
    print("=" * 120)
    print(df.to_string())
    print("=" * 120)

    # Plotting is best effort: the table is returned even if it fails.
    try:
        # Named 'chart' so the module-level plt is not shadowed.
        chart = self.plot_profit_matrix(df)
        chart.savefig('results/profit_matrix_analysis.png', bbox_inches='tight', dpi=300)
        chart.close()
    except Exception as e:
        print(f"绘制图表时出错: {str(e)}")

    return df
| 
 | ||||
def main():
    """Run the take-profit sweep, then chart the best-performing ratio.

    Steps:
        1. Build the database connection string and a ``StockAnalyzer``.
        2. Run ``analyze_profit_matrix`` over the fixed analysis period.
        3. Re-run ``analyze_all_stocks`` with the ratio that gave the highest
           total profit and save the overview chart to
           ``results/all_stocks_analysis.png``.

    Connection settings may be overridden with the environment variables
    STOCK_DB_HOST, STOCK_DB_PORT, STOCK_DB_USER, STOCK_DB_PASSWORD and
    STOCK_DB_NAME; without them the original values are used.

    Raises:
        Exception: any error is printed and then re-raised.
    """
    from urllib.parse import quote

    # SECURITY NOTE(review): these defaults are real credentials committed to
    # source control. They are kept only so existing deployments keep
    # working; rotate the password and supply it via STOCK_DB_PASSWORD.
    db_config = {
        'host': os.environ.get('STOCK_DB_HOST', '192.168.18.199'),
        'port': int(os.environ.get('STOCK_DB_PORT', 3306)),
        'user': os.environ.get('STOCK_DB_USER', 'root'),
        'password': os.environ.get('STOCK_DB_PASSWORD', 'Chlry#$.8'),
        'database': os.environ.get('STOCK_DB_NAME', 'db_gp_cj')
    }

    try:
        # Percent-encode user and password: characters such as '#', '@' or
        # '/' are URL delimiters and would otherwise corrupt the connection
        # URL. Presumably StockAnalyzer hands this string to SQLAlchemy,
        # which decodes it again - confirm against its constructor.
        user = quote(db_config['user'], safe='')
        password = quote(db_config['password'], safe='')
        connection_string = f"mysql+pymysql://{user}:{password}@{db_config['host']}:{db_config['port']}/{db_config['database']}"

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs('results', exist_ok=True)

        analyzer = StockAnalyzer(connection_string)

        # Analysis parameters.
        start_date = '2022-05-05'
        end_date = '2023-08-28'
        stop_loss_pct = -0.99   # -99% stop loss
        pullback_entry_pct = -0.05  # open a position on a -5% pull-back

        # Sweep over the take-profit ratios.
        profit_matrix = analyzer.analyze_profit_matrix(
            start_date,
            end_date,
            stop_loss_pct,
            pullback_entry_pct
        )

        if not profit_matrix.empty:
            # The index is in percent; convert the best one back to a ratio.
            take_profit_pct = profit_matrix['total_profit'].idxmax() / 100
            print(f"\n使用最优止盈比例 {take_profit_pct*100:.1f}% 运行详细分析")

            results = analyzer.analyze_all_stocks(
                start_date,
                end_date,
                take_profit_pct,
                stop_loss_pct,
                pullback_entry_pct
            )

            # Drop missing or empty results.
            valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty}

            if valid_results:
                # Named 'chart' so the module-level plt is not shadowed.
                chart, _ = analyzer.plot_all_stocks(valid_results)
                chart.savefig('results/all_stocks_analysis.png', bbox_inches='tight', dpi=300)
                chart.close()

    except Exception as e:
        print(f"程序运行出错: {str(e)}")
        raise


if __name__ == "__main__":
    main()
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,81 @@ | |||
| # 用前须知 | ||||
| 
 | ||||
| ## xtdata提供和MiniQmt的交互接口,本质是和MiniQmt建立连接,由MiniQmt处理行情数据请求,再把结果回传返回到python层。使用的行情服务器以及能获取到的行情数据和MiniQmt是一致的,要检查数据或者切换连接时直接操作MiniQmt即可。 | ||||
| 
 | ||||
| ## 对于数据获取接口,使用时需要先确保MiniQmt已有所需要的数据,如果不足可以通过补充数据接口补充,再调用数据获取接口获取。 | ||||
| 
 | ||||
| ## 对于订阅接口,直接设置数据回调,数据到来时会由回调返回。订阅接收到的数据一般会保存下来,同种数据不需要再单独补充。 | ||||
| 
 | ||||
| # 代码讲解 | ||||
| 
 | ||||
| # 从本地python导入xtquant库,如果出现报错则说明安装失败 | ||||
| from xtquant import xtdata | ||||
| import time | ||||
| 
 | ||||
| # 设定一个标的列表 | ||||
| code_list = ["000001.SZ"] | ||||
| # 设定获取数据的周期 | ||||
| period = "1d" | ||||
| 
 | ||||
| # 下载标的行情数据 | ||||
| if 1: | ||||
|     ## 为了方便用户进行数据管理,xtquant的大部分历史数据都是以压缩形式存储在本地的 | ||||
|     ## 比如行情数据,需要通过download_history_data下载,财务数据需要通过download_financial_data下载 | ||||
|     ## 所以在取历史数据之前,我们需要调用数据下载接口,将数据下载到本地 | ||||
|     for i in code_list: | ||||
|         xtdata.download_history_data(i, period=period, incrementally=True)  # 增量下载行情数据(开高低收,等等)到本地 | ||||
| 
 | ||||
|     xtdata.download_financial_data(code_list)  # 下载财务数据到本地 | ||||
|     xtdata.download_sector_data()  # 下载板块数据到本地 | ||||
|     # 更多数据的下载方式可以通过数据字典查询 | ||||
| 
 | ||||
| # 读取本地历史行情数据 | ||||
| history_data = xtdata.get_market_data_ex([], code_list, period=period, count=-1) | ||||
| print(history_data) | ||||
| print("=" * 20) | ||||
| 
 | ||||
| # 如果需要盘中的实时行情,需要向服务器进行订阅后才能获取 | ||||
| # 订阅后,get_market_data函数与get_market_data_ex函数将会自动拼接本地历史行情与服务器实时行情 | ||||
| 
 | ||||
| # 向服务器订阅数据 | ||||
| for i in code_list: | ||||
|     xtdata.subscribe_quote(i, period=period, count=-1)  # 设置count = -1来取到当天所有实时行情 | ||||
| 
 | ||||
| # 等待订阅完成 | ||||
| time.sleep(1) | ||||
| 
 | ||||
| # 获取订阅后的行情 | ||||
| kline_data = xtdata.get_market_data_ex([], code_list, period=period) | ||||
| print(kline_data) | ||||
| 
 | ||||
| # 获取订阅后的行情,并以固定间隔进行刷新,预期会循环打印10次 | ||||
| for i in range(10): | ||||
|     # 这边做演示,就用for来循环了,实际使用中可以用while True | ||||
|     kline_data = xtdata.get_market_data_ex([], code_list, period=period) | ||||
|     print(kline_data) | ||||
|     time.sleep(3)  # 三秒后再次获取行情 | ||||
| 
 | ||||
| 
 | ||||
| # 如果不想用固定间隔触发,可以用订阅后的回调来执行 | ||||
| # 这种模式下,订阅的callback回调函数将会异步地执行,每当订阅的标的tick发生变化更新,callback回调函数就会被调用一次 | ||||
| # 本地已有的数据不会触发callback | ||||
| 
 | ||||
| # 定义回调函数 | ||||
| ## 回调函数中,data是本次触发回调的数据,只有一条 | ||||
| def f(data): | ||||
|     # print(data) | ||||
| 
 | ||||
|     code_list = list(data.keys())  # 获取到本次触发的标的代码 | ||||
| 
 | ||||
|     kline_in_callabck = xtdata.get_market_data_ex([], code_list, period=period)  # 在回调中获取klines数据 | ||||
|     print(kline_in_callabck) | ||||
| 
 | ||||
| 
 | ||||
| for i in code_list: | ||||
|     xtdata.subscribe_quote(i, period=period, count=-1, callback=f)  # 订阅时设定回调函数 | ||||
| 
 | ||||
| # 使用回调时,必须要同时使用xtdata.run()来阻塞程序,否则程序运行到最后一行就直接结束退出了。 | ||||
| xtdata.run() | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
		Loading…
	
		Reference in New Issue