commit;
This commit is contained in:
		
							parent
							
								
									ae9b2f2fcf
								
							
						
					
					
						commit
						aadef5f0fa
					
				
							
								
								
									
										129
									
								
								README.md
								
								
								
								
							
							
						
						
									
										129
									
								
								README.md
								
								
								
								
							|  | @ -211,6 +211,127 @@ pip install -r requirements.txt | ||||||
| 
 | 
 | ||||||
| 筛选逻辑:选择行业环境稳定、有高质量新业务合作、财务状况良好、券商看好且上涨空间大的企业。 | 筛选逻辑:选择行业环境稳定、有高质量新业务合作、财务状况良好、券商看好且上涨空间大的企业。 | ||||||
| 
 | 
 | ||||||
|  | ## Docker 部署说明 | ||||||
|  | 
 | ||||||
|  | 本系统支持通过 Docker 进行部署,可以单实例部署或多实例部署。 | ||||||
|  | 
 | ||||||
|  | ### 1. 单实例部署 | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | # 构建镜像 | ||||||
|  | docker-compose build | ||||||
|  | 
 | ||||||
|  | # 启动服务 | ||||||
|  | docker-compose up -d | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ### 2. 多实例部署 | ||||||
|  | 
 | ||||||
|  | 提供了三个脚本用于管理多实例部署,每个实例将在不同端口上运行(从5088开始递增): | ||||||
|  | 
 | ||||||
|  | #### 使用 docker run 部署多实例(推荐) | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | # 赋予脚本执行权限 | ||||||
|  | chmod +x deploy-multiple.sh | ||||||
|  | 
 | ||||||
|  | # 部署3个实例,端口将为5088、5089、5090 | ||||||
|  | ./deploy-multiple.sh 3 | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | 每个实例将有自己独立的数据目录:`./instances/instance-N/`,包含独立的日志和报告文件。 | ||||||
|  | 
 | ||||||
|  | #### 使用 docker-compose 部署多实例 | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | # 赋予脚本执行权限 | ||||||
|  | chmod +x deploy-compose.sh | ||||||
|  | 
 | ||||||
|  | # 部署3个实例,端口将为5088、5089、5090 | ||||||
|  | ./deploy-compose.sh 3 | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | 每个实例将使用独立的项目名称 `stock-app-N`。 | ||||||
|  | 
 | ||||||
|  | #### 管理多实例 | ||||||
|  | 
 | ||||||
|  | 使用 `manage-instances.sh` 脚本可以方便地管理已部署的实例: | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | # 赋予脚本执行权限 | ||||||
|  | chmod +x manage-instances.sh | ||||||
|  | 
 | ||||||
|  | # 查看所有实例状态 | ||||||
|  | ./manage-instances.sh status | ||||||
|  | 
 | ||||||
|  | # 列出正在运行的实例 | ||||||
|  | ./manage-instances.sh list | ||||||
|  | 
 | ||||||
|  | # 启动指定实例 | ||||||
|  | ./manage-instances.sh start 2 | ||||||
|  | 
 | ||||||
|  | # 启动所有实例 | ||||||
|  | ./manage-instances.sh start all | ||||||
|  | 
 | ||||||
|  | # 停止指定实例 | ||||||
|  | ./manage-instances.sh stop 1 | ||||||
|  | 
 | ||||||
|  | # 停止所有实例 | ||||||
|  | ./manage-instances.sh stop all | ||||||
|  | 
 | ||||||
|  | # 重启指定实例 | ||||||
|  | ./manage-instances.sh restart 3 | ||||||
|  | 
 | ||||||
|  | # 查看指定实例的日志 | ||||||
|  | ./manage-instances.sh logs 1 | ||||||
|  | 
 | ||||||
|  | # 删除指定实例 | ||||||
|  | ./manage-instances.sh remove 2 | ||||||
|  | 
 | ||||||
|  | # 删除所有实例 | ||||||
|  | ./manage-instances.sh remove all | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ### 3. 中文字体安装 | ||||||
|  | 
 | ||||||
|  | 系统需要中文字体以正常生成PDF报告。如果部署时遇到 `找不到中文字体文件` 错误,可以使用以下方法解决: | ||||||
|  | 
 | ||||||
|  | #### 自动安装字体(推荐) | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | # 进入容器内部 | ||||||
|  | docker exec -it [容器ID或名称] bash | ||||||
|  | 
 | ||||||
|  | # 在容器内执行字体安装脚本 | ||||||
|  | cd /app | ||||||
|  | python src/fundamentals_llm/setup_fonts.py | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | #### 手动安装字体 | ||||||
|  | 
 | ||||||
|  | 1. 在宿主机上下载中文字体文件(如simhei.ttf、wqy-microhei.ttc等) | ||||||
|  | 2. 创建字体目录并复制字体文件: | ||||||
|  | ```bash | ||||||
|  | mkdir -p src/fundamentals_llm/fonts | ||||||
|  | cp 你下载的字体文件.ttf src/fundamentals_llm/fonts/simhei.ttf | ||||||
|  | ``` | ||||||
|  | 3. 重新构建镜像: | ||||||
|  | ```bash | ||||||
|  | docker-compose build | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ### 4. 访问服务 | ||||||
|  | 
 | ||||||
|  | 部署完成后,可以通过以下URL访问服务: | ||||||
|  | 
 | ||||||
|  | - 单实例:`http://服务器IP:5000` | ||||||
|  | - 多实例:`http://服务器IP:508X`(X为实例序号,从8开始) | ||||||
|  | 
 | ||||||
|  | 例如: | ||||||
|  | - 实例1:`http://服务器IP:5088` | ||||||
|  | - 实例2:`http://服务器IP:5089` | ||||||
|  | - 实例3:`http://服务器IP:5090` | ||||||
|  | 
 | ||||||
| ## 数据库查询示例 | ## 数据库查询示例 | ||||||
| 
 | 
 | ||||||
| 以下SQL查询示例展示了如何从数据库中筛选特定类型的企业: | 以下SQL查询示例展示了如何从数据库中筛选特定类型的企业: | ||||||
|  | @ -251,3 +372,11 @@ AND stock_code IN ( | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| 注意:实际查询时可能需要根据数据库表结构调整SQL语句。建议根据具体需求组合使用不同的筛选条件。  | 注意:实际查询时可能需要根据数据库表结构调整SQL语句。建议根据具体需求组合使用不同的筛选条件。  | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # 重新构建镜像 | ||||||
|  | docker-compose build | ||||||
|  | 
 | ||||||
|  | # 重启所有实例 | ||||||
|  | ./manage-instances.sh restart all | ||||||
|  | @ -13,5 +13,5 @@ volcengine-python-sdk[ark] | ||||||
| openai>=1.0 | openai>=1.0 | ||||||
| reportlab>=4.3.1 | reportlab>=4.3.1 | ||||||
| markdown2>=2.5.3 | markdown2>=2.5.3 | ||||||
| # 1. 按下 Win+R ,输入 regedit 打开注册表编辑器。 | google-genai | ||||||
| # 2. 设置 \HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem 路径下的变量 LongPathsEnabled 为 1 即可。 | redis==5.2.1 | ||||||
							
								
								
									
										570
									
								
								src/app.py
								
								
								
								
							
							
						
						
									
										570
									
								
								src/app.py
								
								
								
								
							|  | @ -1,27 +1,47 @@ | ||||||
| import sys | import sys | ||||||
| import os | import os | ||||||
|  | from datetime import datetime, timedelta | ||||||
|  | import pandas as pd | ||||||
|  | import uuid | ||||||
|  | import json | ||||||
|  | from threading import Thread | ||||||
| 
 | 
 | ||||||
| from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db | from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db | ||||||
| 
 | 
 | ||||||
| # 添加项目根目录到 Python 路径 | # 添加项目根目录到 Python 路径 | ||||||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||||
| 
 | 
 | ||||||
| from flask import Flask, jsonify, request | from flask import Flask, jsonify, request, send_from_directory | ||||||
| from flask_cors import CORS | from flask_cors import CORS | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| # 导入企业筛选器 | # 导入企业筛选器 | ||||||
| from src.fundamentals_llm.enterprise_screener import EnterpriseScreener | from src.fundamentals_llm.enterprise_screener import EnterpriseScreener | ||||||
| 
 | 
 | ||||||
|  | # 导入股票回测器 | ||||||
|  | from src.stock_analysis_v2 import run_backtest, StockBacktester | ||||||
|  | 
 | ||||||
| # 设置日志 | # 设置日志 | ||||||
| logging.basicConfig( | logging.basicConfig( | ||||||
|     level=logging.INFO, |     level=logging.INFO, | ||||||
|     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | ||||||
|  |     handlers=[ | ||||||
|  |         logging.StreamHandler(),  # 输出到控制台 | ||||||
|  |         logging.FileHandler(f'logs/app_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8')  # 输出到文件 | ||||||
|  |     ] | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | # 确保logs和results目录存在 | ||||||
|  | os.makedirs('logs', exist_ok=True) | ||||||
|  | os.makedirs('results', exist_ok=True) | ||||||
|  | os.makedirs('results/tasks', exist_ok=True) | ||||||
|  | os.makedirs('static/results', exist_ok=True) | ||||||
|  | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  | logger.info("Flask应用启动") | ||||||
| 
 | 
 | ||||||
| # 创建 Flask 应用 | # 创建 Flask 应用 | ||||||
| app = Flask(__name__) | app = Flask(__name__, static_folder='static') | ||||||
| CORS(app)  # 启用跨域请求支持 | CORS(app)  # 启用跨域请求支持 | ||||||
| 
 | 
 | ||||||
| # 创建企业筛选器实例 | # 创建企业筛选器实例 | ||||||
|  | @ -30,6 +50,329 @@ screener = EnterpriseScreener() | ||||||
| # 获取数据库连接 | # 获取数据库连接 | ||||||
| db = next(get_db()) | db = next(get_db()) | ||||||
| 
 | 
 | ||||||
|  | # 获取项目根目录 | ||||||
|  | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||||
|  | REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports') | ||||||
|  | 
 | ||||||
|  | # 确保reports目录存在 | ||||||
|  | os.makedirs(REPORTS_DIR, exist_ok=True) | ||||||
|  | logger.info(f"报告目录路径: {REPORTS_DIR}") | ||||||
|  | 
 | ||||||
|  | # 存储回测任务状态的字典 | ||||||
|  | backtest_tasks = {} | ||||||
|  | 
 | ||||||
|  | def run_backtest_task(task_id, stocks_buy_dates, end_date): | ||||||
|  |     """ | ||||||
|  |     在后台运行回测任务 | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         logger.info(f"开始执行回测任务 {task_id}") | ||||||
|  |         # 更新任务状态为进行中 | ||||||
|  |         backtest_tasks[task_id]['status'] = 'running' | ||||||
|  |          | ||||||
|  |         # 运行回测 | ||||||
|  |         results, stats_list = run_backtest(stocks_buy_dates, end_date) | ||||||
|  |          | ||||||
|  |         # 如果回测成功 | ||||||
|  |         if results and stats_list: | ||||||
|  |             # 计算总体统计 | ||||||
|  |             stats_df = pd.DataFrame(stats_list) | ||||||
|  |             total_profit = stats_df['final_profit'].sum() | ||||||
|  |             avg_win_rate = stats_df['win_rate'].mean() | ||||||
|  |             avg_holding_days = stats_df['avg_holding_days'].mean() | ||||||
|  |              | ||||||
|  |             # 找出最佳止盈比例(假设为0.15,实际应从回测结果中分析得出) | ||||||
|  |             best_take_profit_pct = 0.15 | ||||||
|  |              | ||||||
|  |             # 获取图表URL | ||||||
|  |             chart_urls = { | ||||||
|  |                 "all_stocks": f"/static/results/{task_id}/all_stocks_analysis.png", | ||||||
|  |                 "profit_matrix": f"/static/results/{task_id}/profit_matrix_analysis.png" | ||||||
|  |             } | ||||||
|  |              | ||||||
|  |             # 保存股票详细统计 | ||||||
|  |             stock_stats = [] | ||||||
|  |             for stat in stats_list: | ||||||
|  |                 stock_stats.append({ | ||||||
|  |                     "symbol": stat['symbol'], | ||||||
|  |                     "total_trades": int(stat['total_trades']), | ||||||
|  |                     "profitable_trades": int(stat['profitable_trades']), | ||||||
|  |                     "loss_trades": int(stat['loss_trades']), | ||||||
|  |                     "win_rate": float(stat['win_rate']), | ||||||
|  |                     "avg_holding_days": float(stat['avg_holding_days']), | ||||||
|  |                     "final_profit": float(stat['final_profit']), | ||||||
|  |                     "entry_count": int(stat['entry_count']) | ||||||
|  |                 }) | ||||||
|  |              | ||||||
|  |             # 构建结果数据 | ||||||
|  |             result_data = { | ||||||
|  |                 "task_id": task_id, | ||||||
|  |                 "status": "completed", | ||||||
|  |                 "results": { | ||||||
|  |                     "total_profit": float(total_profit), | ||||||
|  |                     "win_rate": float(avg_win_rate), | ||||||
|  |                     "avg_holding_days": float(avg_holding_days), | ||||||
|  |                     "best_take_profit_pct": best_take_profit_pct, | ||||||
|  |                     "stock_stats": stock_stats, | ||||||
|  |                     "chart_urls": chart_urls | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |              | ||||||
|  |             # 保存结果到文件 | ||||||
|  |             task_result_path = os.path.join('results', 'tasks', f"{task_id}.json") | ||||||
|  |             with open(task_result_path, 'w', encoding='utf-8') as f: | ||||||
|  |                 json.dump(result_data, f, ensure_ascii=False, indent=2) | ||||||
|  |              | ||||||
|  |             # 更新任务状态为已完成 | ||||||
|  |             backtest_tasks[task_id].update({ | ||||||
|  |                 'status': 'completed', | ||||||
|  |                 'results': result_data['results'] | ||||||
|  |             }) | ||||||
|  |              | ||||||
|  |             logger.info(f"回测任务 {task_id} 已完成") | ||||||
|  |         else: | ||||||
|  |             # 更新任务状态为失败 | ||||||
|  |             backtest_tasks[task_id]['status'] = 'failed' | ||||||
|  |             backtest_tasks[task_id]['error'] = "回测未产生有效结果" | ||||||
|  |             logger.error(f"回测任务 {task_id} 失败:回测未产生有效结果") | ||||||
|  |     except Exception as e: | ||||||
|  |         # 更新任务状态为失败 | ||||||
|  |         backtest_tasks[task_id]['status'] = 'failed' | ||||||
|  |         backtest_tasks[task_id]['error'] = str(e) | ||||||
|  |         logger.error(f"回测任务 {task_id} 失败:{str(e)}") | ||||||
|  | 
 | ||||||
|  | @app.route('/api/backtest/run', methods=['POST']) | ||||||
|  | def start_backtest(): | ||||||
|  |     """启动回测任务 | ||||||
|  |      | ||||||
|  |     请求体格式: | ||||||
|  |     { | ||||||
|  |         "stocks_buy_dates": { | ||||||
|  |             "SH600522": ["2022-05-10", "2022-06-10"],  // 股票代码: [买入日期列表] | ||||||
|  |             "SZ002340": ["2022-06-15"], | ||||||
|  |             "SH601615": ["2022-07-20", "2022-08-01"] | ||||||
|  |         }, | ||||||
|  |         "end_date": "2022-10-20"  // 所有股票共同的结束日期 | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     返回内容: | ||||||
|  |     { | ||||||
|  |         "task_id": "backtask-khkerhy4u237y489237489truiy8432" | ||||||
|  |     } | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # 从请求体获取参数 | ||||||
|  |         data = request.get_json() | ||||||
|  |          | ||||||
|  |         if not data: | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": "请求格式错误: 需要提供JSON数据" | ||||||
|  |             }), 400 | ||||||
|  |              | ||||||
|  |         stocks_buy_dates = data.get('stocks_buy_dates') | ||||||
|  |         end_date = data.get('end_date') | ||||||
|  |          | ||||||
|  |         if not stocks_buy_dates or not isinstance(stocks_buy_dates, dict): | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": "请求格式错误: 需要提供stocks_buy_dates字典" | ||||||
|  |             }), 400 | ||||||
|  |              | ||||||
|  |         if not end_date or not isinstance(end_date, str): | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": "请求格式错误: 需要提供有效的end_date" | ||||||
|  |             }), 400 | ||||||
|  |              | ||||||
|  |         # 验证日期格式 | ||||||
|  |         try: | ||||||
|  |             datetime.strptime(end_date, '%Y-%m-%d') | ||||||
|  |             for stock_code, buy_dates in stocks_buy_dates.items(): | ||||||
|  |                 if not isinstance(buy_dates, list): | ||||||
|  |                     return jsonify({ | ||||||
|  |                         "status": "error",  | ||||||
|  |                         "message": f"请求格式错误: 股票 {stock_code} 的买入日期必须是列表" | ||||||
|  |                     }), 400 | ||||||
|  |                 for buy_date in buy_dates: | ||||||
|  |                     datetime.strptime(buy_date, '%Y-%m-%d') | ||||||
|  |         except ValueError as e: | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": f"日期格式错误: {str(e)}" | ||||||
|  |             }), 400 | ||||||
|  |          | ||||||
|  |         # 生成任务ID | ||||||
|  |         task_id = f"backtask-{uuid.uuid4().hex[:16]}" | ||||||
|  |          | ||||||
|  |         # 创建任务目录 | ||||||
|  |         task_dir = os.path.join('static', 'results', task_id) | ||||||
|  |         os.makedirs(task_dir, exist_ok=True) | ||||||
|  |          | ||||||
|  |         # 记录任务信息 | ||||||
|  |         backtest_tasks[task_id] = { | ||||||
|  |             'status': 'pending',  | ||||||
|  |             'created_at': datetime.now().isoformat(), | ||||||
|  |             'stocks_buy_dates': stocks_buy_dates, | ||||||
|  |             'end_date': end_date | ||||||
|  |         } | ||||||
|  |          | ||||||
|  |         # 创建线程运行回测 | ||||||
|  |         thread = Thread(target=run_backtest_task, args=(task_id, stocks_buy_dates, end_date)) | ||||||
|  |         thread.daemon = True | ||||||
|  |         thread.start() | ||||||
|  |          | ||||||
|  |         logger.info(f"已创建回测任务: {task_id}") | ||||||
|  |          | ||||||
|  |         return jsonify({ | ||||||
|  |             "task_id": task_id | ||||||
|  |         }) | ||||||
|  |          | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"创建回测任务失败: {str(e)}") | ||||||
|  |         return jsonify({ | ||||||
|  |             "status": "error",  | ||||||
|  |             "message": f"创建回测任务失败: {str(e)}" | ||||||
|  |         }), 500 | ||||||
|  | 
 | ||||||
|  | @app.route('/api/backtest/status', methods=['GET']) | ||||||
|  | def check_backtest_status(): | ||||||
|  |     """查询回测任务状态 | ||||||
|  |      | ||||||
|  |     参数: | ||||||
|  |     - task_id: 回测任务ID | ||||||
|  |      | ||||||
|  |     返回内容: | ||||||
|  |     { | ||||||
|  |         "task_id": "backtask-khkerhy4u237y489237489truiy8432", | ||||||
|  |         "status": "running" | "completed" | "failed", | ||||||
|  |         "created_at": "2023-12-01T10:30:45", | ||||||
|  |         "error": "错误信息(如有)" | ||||||
|  |     } | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         task_id = request.args.get('task_id') | ||||||
|  |          | ||||||
|  |         if not task_id: | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": "请求格式错误: 需要提供task_id参数" | ||||||
|  |             }), 400 | ||||||
|  |          | ||||||
|  |         # 检查任务是否存在 | ||||||
|  |         if task_id not in backtest_tasks: | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": f"任务不存在: {task_id}" | ||||||
|  |             }), 404 | ||||||
|  |          | ||||||
|  |         # 获取任务信息 | ||||||
|  |         task_info = backtest_tasks[task_id] | ||||||
|  |          | ||||||
|  |         # 构建响应 | ||||||
|  |         response = { | ||||||
|  |             "task_id": task_id, | ||||||
|  |             "status": task_info['status'], | ||||||
|  |             "created_at": task_info['created_at'] | ||||||
|  |         } | ||||||
|  |          | ||||||
|  |         # 如果任务失败,添加错误信息 | ||||||
|  |         if task_info['status'] == 'failed' and 'error' in task_info: | ||||||
|  |             response['error'] = task_info['error'] | ||||||
|  |          | ||||||
|  |         return jsonify(response) | ||||||
|  |          | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"查询任务状态失败: {str(e)}") | ||||||
|  |         return jsonify({ | ||||||
|  |             "status": "error",  | ||||||
|  |             "message": f"查询任务状态失败: {str(e)}" | ||||||
|  |         }), 500 | ||||||
|  | 
 | ||||||
|  | @app.route('/api/backtest/result', methods=['GET']) | ||||||
|  | def get_backtest_result(): | ||||||
|  |     """获取回测任务结果 | ||||||
|  |      | ||||||
|  |     参数: | ||||||
|  |     - task_id: 回测任务ID | ||||||
|  |      | ||||||
|  |     返回内容: | ||||||
|  |     { | ||||||
|  |         "task_id": "backtask-2023121-001", | ||||||
|  |         "status": "completed", | ||||||
|  |         "results": { | ||||||
|  |             "total_profit": 123456.78, | ||||||
|  |             "win_rate": 75.5, | ||||||
|  |             "avg_holding_days": 12.3, | ||||||
|  |             "best_take_profit_pct": 0.15, | ||||||
|  |             "stock_stats": [...], | ||||||
|  |             "chart_urls": { | ||||||
|  |                 "all_stocks": "/static/results/backtask-2023121-001/all_stocks_analysis.png", | ||||||
|  |                 "profit_matrix": "/static/results/backtask-2023121-001/profit_matrix_analysis.png" | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         task_id = request.args.get('task_id') | ||||||
|  |          | ||||||
|  |         if not task_id: | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": "请求格式错误: 需要提供task_id参数" | ||||||
|  |             }), 400 | ||||||
|  |          | ||||||
|  |         # 检查任务是否存在 | ||||||
|  |         if task_id not in backtest_tasks: | ||||||
|  |             # 尝试从文件加载 | ||||||
|  |             task_result_path = os.path.join('results', 'tasks', f"{task_id}.json") | ||||||
|  |             if os.path.exists(task_result_path): | ||||||
|  |                 with open(task_result_path, 'r', encoding='utf-8') as f: | ||||||
|  |                     result_data = json.load(f) | ||||||
|  |                 return jsonify(result_data) | ||||||
|  |             else: | ||||||
|  |                 return jsonify({ | ||||||
|  |                     "status": "error",  | ||||||
|  |                     "message": f"任务不存在: {task_id}" | ||||||
|  |                 }), 404 | ||||||
|  |          | ||||||
|  |         # 获取任务信息 | ||||||
|  |         task_info = backtest_tasks[task_id] | ||||||
|  |          | ||||||
|  |         # 检查任务是否完成 | ||||||
|  |         if task_info['status'] != 'completed': | ||||||
|  |             return jsonify({ | ||||||
|  |                 "status": "error",  | ||||||
|  |                 "message": f"任务尚未完成或已失败: {task_info['status']}" | ||||||
|  |             }), 400 | ||||||
|  |          | ||||||
|  |         # 构建响应 | ||||||
|  |         response = { | ||||||
|  |             "task_id": task_id, | ||||||
|  |             "status": "completed", | ||||||
|  |             "results": task_info['results'] | ||||||
|  |         } | ||||||
|  |          | ||||||
|  |         return jsonify(response) | ||||||
|  |          | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"获取任务结果失败: {str(e)}") | ||||||
|  |         return jsonify({ | ||||||
|  |             "status": "error",  | ||||||
|  |             "message": f"获取任务结果失败: {str(e)}" | ||||||
|  |         }), 500 | ||||||
|  | 
 | ||||||
|  | @app.route('/api/reports/<path:filename>') | ||||||
|  | def serve_report(filename): | ||||||
|  |     """提供PDF报告的访问""" | ||||||
|  |     try: | ||||||
|  |         logger.info(f"请求文件: {filename}") | ||||||
|  |         logger.info(f"从目录: {REPORTS_DIR} 提供文件") | ||||||
|  |         return send_from_directory(REPORTS_DIR, filename, as_attachment=True) | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"提供报告文件失败: {str(e)}") | ||||||
|  |         return jsonify({"error": "文件不存在"}), 404 | ||||||
|  | 
 | ||||||
| @app.route('/api/health', methods=['GET']) | @app.route('/api/health', methods=['GET']) | ||||||
| def health_check(): | def health_check(): | ||||||
|     """健康检查接口""" |     """健康检查接口""" | ||||||
|  | @ -185,10 +528,10 @@ def generate_reports(): | ||||||
|          |          | ||||||
|         # 导入 PDF 生成器模块 |         # 导入 PDF 生成器模块 | ||||||
|         try: |         try: | ||||||
|             from src.fundamentals_llm.pdf_generator import generate_investment_report |             from src.fundamentals_llm.pdf_generator import PDFGenerator | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             try: |             try: | ||||||
|                 from fundamentals_llm.pdf_generator import generate_investment_report |                 from fundamentals_llm.pdf_generator import PDFGenerator | ||||||
|             except ImportError as e: |             except ImportError as e: | ||||||
|                 logger.error(f"无法导入 PDF 生成器模块: {str(e)}") |                 logger.error(f"无法导入 PDF 生成器模块: {str(e)}") | ||||||
|                 return jsonify({ |                 return jsonify({ | ||||||
|  | @ -200,8 +543,14 @@ def generate_reports(): | ||||||
|         generated_reports = [] |         generated_reports = [] | ||||||
|         for stock_code, stock_name in stocks: |         for stock_code, stock_name in stocks: | ||||||
|             try: |             try: | ||||||
|  |                 # 创建 PDF 生成器实例 | ||||||
|  |                 generator = PDFGenerator() | ||||||
|                 # 调用 PDF 生成器 |                 # 调用 PDF 生成器 | ||||||
|                 report_path = generate_investment_report(stock_code, stock_name) |                 report_path = generator.generate_pdf( | ||||||
|  |                     title=f"{stock_name}({stock_code}) 基本面分析报告", | ||||||
|  |                     content_dict={},  # 这里需要传入实际的内容字典 | ||||||
|  |                     filename=f"{stock_name}_{stock_code}_analysis.pdf" | ||||||
|  |                 ) | ||||||
|                 generated_reports.append({ |                 generated_reports.append({ | ||||||
|                     "code": stock_code, |                     "code": stock_code, | ||||||
|                     "name": stock_name, |                     "name": stock_name, | ||||||
|  | @ -505,9 +854,10 @@ def analyze_and_recommend(): | ||||||
|             "message": f"分析和推荐股票失败: {str(e)}" |             "message": f"分析和推荐股票失败: {str(e)}" | ||||||
|         }), 500 |         }), 500 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @app.route('/api/comprehensive_analysis', methods=['POST']) | @app.route('/api/comprehensive_analysis', methods=['POST']) | ||||||
| def comprehensive_analysis(): | def comprehensive_analysis(): | ||||||
|     """综合分析接口 - 组合多种功能和参数 |     """综合分析接口 - 使用队列方式处理被锁定的股票 | ||||||
|      |      | ||||||
|     请求体格式: |     请求体格式: | ||||||
|     { |     { | ||||||
|  | @ -558,36 +908,9 @@ def comprehensive_analysis(): | ||||||
|         if not isinstance(limit, int) or limit <= 0: |         if not isinstance(limit, int) or limit <= 0: | ||||||
|             limit = 10 |             limit = 10 | ||||||
|          |          | ||||||
|         # 导入必要的聊天机器人模块 |         # 导入必要的模块 | ||||||
|         try: |         try: | ||||||
|             # 首先尝试导入聊天机器人模块 |             # 先导入基本面分析器 | ||||||
|             try: |  | ||||||
|                 from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot |  | ||||||
|                 logger.info("成功从 src.fundamentals_llm.chat_bot 导入 ChatBot") |  | ||||||
|             except ImportError as e1: |  | ||||||
|                 try: |  | ||||||
|                     from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot |  | ||||||
|                     logger.info("成功从 fundamentals_llm.chat_bot 导入 ChatBot") |  | ||||||
|                 except ImportError as e2: |  | ||||||
|                     logger.error(f"无法导入在线聊天机器人模块: {str(e1)}, {str(e2)}") |  | ||||||
|                     return jsonify({ |  | ||||||
|                         "status": "error",  |  | ||||||
|                         "message": f"服务器配置错误: 聊天机器人模块不可用,错误详情: {str(e2)}" |  | ||||||
|                     }), 500 |  | ||||||
|                      |  | ||||||
|             # 然后尝试导入离线聊天机器人模块 |  | ||||||
|             try: |  | ||||||
|                 from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot |  | ||||||
|                 logger.info("成功从 src.fundamentals_llm.chat_bot_with_offline 导入 ChatBot") |  | ||||||
|             except ImportError as e1: |  | ||||||
|                 try: |  | ||||||
|                     from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot |  | ||||||
|                     logger.info("成功从 fundamentals_llm.chat_bot_with_offline 导入 ChatBot") |  | ||||||
|                 except ImportError as e2: |  | ||||||
|                     logger.warning(f"无法导入离线聊天机器人模块: {str(e1)}, {str(e2)}") |  | ||||||
|                     # 这里可以继续执行,因为某些功能可能不需要离线模型 |  | ||||||
|              |  | ||||||
|             # 最后导入基本面分析器 |  | ||||||
|             try: |             try: | ||||||
|                 from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer |                 from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer | ||||||
|                 logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer") |                 logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer") | ||||||
|  | @ -601,6 +924,19 @@ def comprehensive_analysis(): | ||||||
|                         "status": "error",  |                         "status": "error",  | ||||||
|                         "message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}" |                         "message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}" | ||||||
|                     }), 500 |                     }), 500 | ||||||
|  |              | ||||||
|  |             # 再导入其他可能需要的模块 | ||||||
|  |             try: | ||||||
|  |                 from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot | ||||||
|  |                 from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
|  |             except ImportError: | ||||||
|  |                 try: | ||||||
|  |                     from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot | ||||||
|  |                     from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
|  |                 except ImportError: | ||||||
|  |                     # 这些模块不是必须的,所以继续执行 | ||||||
|  |                     logger.warning("无法导入聊天机器人模块,但这不会影响基本功能") | ||||||
|  |                      | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.error(f"导入必要模块时出错: {str(e)}") |             logger.error(f"导入必要模块时出错: {str(e)}") | ||||||
|             return jsonify({ |             return jsonify({ | ||||||
|  | @ -611,64 +947,141 @@ def comprehensive_analysis(): | ||||||
|         # 创建基本面分析器实例 |         # 创建基本面分析器实例 | ||||||
|         analyzer = FundamentalAnalyzer() |         analyzer = FundamentalAnalyzer() | ||||||
|          |          | ||||||
|         # 为每个股票生成投资建议 |         # 准备结果容器 | ||||||
|         investment_advices = [] |         investment_advices = {}  # 使用字典,股票代码作为键 | ||||||
|         for stock_code, stock_name in stocks: |         processing_queue = list(stocks)  # 初始处理队列 | ||||||
|  |         max_attempts = 5  # 最大重试次数 | ||||||
|  |         total_attempts = 0 | ||||||
|  |          | ||||||
|  |         # 导入数据库模块 | ||||||
|  |         from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db | ||||||
|  |          | ||||||
|  |         # 开始处理队列 | ||||||
|  |         while processing_queue and total_attempts < max_attempts: | ||||||
|  |             total_attempts += 1 | ||||||
|  |             logger.info(f"开始第 {total_attempts} 轮处理,队列中有 {len(processing_queue)} 只股票") | ||||||
|  |              | ||||||
|  |             # 暂存下一轮需要处理的股票 | ||||||
|  |             next_round_queue = [] | ||||||
|  |              | ||||||
|  |             # 处理当前队列中的所有股票 | ||||||
|  |             for stock_code, stock_name in processing_queue: | ||||||
|                 try: |                 try: | ||||||
|                 # 生成投资建议 |                     # 检查是否已有分析结果 | ||||||
|  |                     db = next(get_db()) | ||||||
|  |                     existing_result = get_analysis_result(db, stock_code, "investment_advice") | ||||||
|  |                      | ||||||
|  |                     # 如果已有近期结果,直接使用 | ||||||
|  |                     if existing_result and existing_result.update_time > datetime.now() - timedelta(hours=12): | ||||||
|  |                         investment_advices[stock_code] = { | ||||||
|  |                             "code": stock_code, | ||||||
|  |                             "name": stock_name, | ||||||
|  |                             "advice": existing_result.ai_response, | ||||||
|  |                             "reasoning": existing_result.reasoning_process, | ||||||
|  |                             "references": existing_result.references, | ||||||
|  |                             "status": "success", | ||||||
|  |                             "from_cache": True | ||||||
|  |                         } | ||||||
|  |                         logger.info(f"使用缓存的 {stock_name}({stock_code}) 分析结果") | ||||||
|  |                         continue | ||||||
|  |                      | ||||||
|  |                     # 检查是否被锁定 | ||||||
|  |                     if analyzer.is_stock_locked(stock_code, "investment_advice"): | ||||||
|  |                         # 已被锁定,放到下一轮队列 | ||||||
|  |                         next_round_queue.append([stock_code, stock_name]) | ||||||
|  |                         # 记录状态 | ||||||
|  |                         if stock_code not in investment_advices: | ||||||
|  |                             investment_advices[stock_code] = { | ||||||
|  |                                 "code": stock_code, | ||||||
|  |                                 "name": stock_name, | ||||||
|  |                                 "status": "pending", | ||||||
|  |                                 "message": f"股票 {stock_code} 正在被其他请求分析中,已加入等待队列" | ||||||
|  |                             } | ||||||
|  |                         logger.info(f"股票 {stock_name}({stock_code}) 已被锁定,放入下一轮队列") | ||||||
|  |                         continue | ||||||
|  |                      | ||||||
|  |                     # 尝试锁定并分析 | ||||||
|  |                     analyzer.lock_stock(stock_code, "investment_advice") | ||||||
|  |                     try: | ||||||
|  |                         # 执行分析 | ||||||
|                         success, advice, reasoning, references = analyzer.query_analysis( |                         success, advice, reasoning, references = analyzer.query_analysis( | ||||||
|                             stock_code, stock_name, "investment_advice" |                             stock_code, stock_name, "investment_advice" | ||||||
|                         ) |                         ) | ||||||
|                          |                          | ||||||
|  |                         # 记录结果 | ||||||
|                         if success: |                         if success: | ||||||
|                     investment_advices.append({ |                             investment_advices[stock_code] = { | ||||||
|                                 "code": stock_code, |                                 "code": stock_code, | ||||||
|                                 "name": stock_name, |                                 "name": stock_name, | ||||||
|                                 "advice": advice, |                                 "advice": advice, | ||||||
|                                 "reasoning": reasoning, |                                 "reasoning": reasoning, | ||||||
|                                 "references": references, |                                 "references": references, | ||||||
|                                 "status": "success" |                                 "status": "success" | ||||||
|                     }) |                             } | ||||||
|                     logger.info(f"成功生成 {stock_name}({stock_code}) 的投资建议") |                             logger.info(f"成功分析 {stock_name}({stock_code})") | ||||||
|                         else: |                         else: | ||||||
|                     investment_advices.append({ |                             investment_advices[stock_code] = { | ||||||
|                                 "code": stock_code, |                                 "code": stock_code, | ||||||
|                                 "name": stock_name, |                                 "name": stock_name, | ||||||
|                                 "status": "error", |                                 "status": "error", | ||||||
|                         "error": advice |                                 "error": advice or "分析失败,无详细信息" | ||||||
|                     }) |                             } | ||||||
|                     logger.error(f"生成 {stock_name}({stock_code}) 的投资建议失败: {advice}") |                             logger.error(f"分析 {stock_name}({stock_code}) 失败: {advice}") | ||||||
|  |                     finally: | ||||||
|  |                         # 确保释放锁 | ||||||
|  |                         analyzer.unlock_stock(stock_code, "investment_advice") | ||||||
|  |                          | ||||||
|                 except Exception as e: |                 except Exception as e: | ||||||
|                 logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}") |                     # 处理异常 | ||||||
|                 investment_advices.append({ |                     investment_advices[stock_code] = { | ||||||
|                         "code": stock_code, |                         "code": stock_code, | ||||||
|                         "name": stock_name, |                         "name": stock_name, | ||||||
|                         "status": "error", |                         "status": "error", | ||||||
|                         "error": str(e) |                         "error": str(e) | ||||||
|  |                     } | ||||||
|  |                     logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}") | ||||||
|  |                     # 确保释放锁 | ||||||
|  |                     try: | ||||||
|  |                         analyzer.unlock_stock(stock_code, "investment_advice") | ||||||
|  |                     except: | ||||||
|  |                         pass | ||||||
|  |              | ||||||
|  |             # 如果还有下一轮要处理的股票,等待一段时间后继续 | ||||||
|  |             if next_round_queue: | ||||||
|  |                 logger.info(f"本轮结束,还有 {len(next_round_queue)} 只股票等待下一轮处理") | ||||||
|  |                 # 等待30秒再处理下一轮 | ||||||
|  |                 import time | ||||||
|  |                 time.sleep(30) | ||||||
|  |                 processing_queue = next_round_queue | ||||||
|  |             else: | ||||||
|  |                 # 所有股票都已处理完,退出循环 | ||||||
|  |                 logger.info("所有股票处理完毕") | ||||||
|  |                 processing_queue = [] | ||||||
|  |          | ||||||
|  |         # 处理仍在队列中的股票(达到最大重试次数但仍未处理的) | ||||||
|  |         for stock_code, stock_name in processing_queue: | ||||||
|  |             if stock_code in investment_advices and investment_advices[stock_code]["status"] == "pending": | ||||||
|  |                 investment_advices[stock_code].update({ | ||||||
|  |                     "status": "timeout", | ||||||
|  |                     "message": f"等待超时,股票 {stock_code} 可能正在被长时间分析" | ||||||
|                 }) |                 }) | ||||||
|  |                 logger.warning(f"股票 {stock_name}({stock_code}) 分析超时") | ||||||
|  |          | ||||||
|  |         # 将字典转换为列表 | ||||||
|  |         investment_advice_list = list(investment_advices.values()) | ||||||
|          |          | ||||||
|         # 生成PDF报告(如果需要) |         # 生成PDF报告(如果需要) | ||||||
|         pdf_results = [] |         pdf_results = [] | ||||||
|         if generate_pdf: |         if generate_pdf: | ||||||
|             try: |             try: | ||||||
|                 # 导入PDF生成器模块 |                 # 针对已成功分析的股票生成PDF | ||||||
|  |                 for stock_info in investment_advice_list: | ||||||
|  |                     if stock_info["status"] == "success": | ||||||
|  |                         stock_code = stock_info["code"] | ||||||
|  |                         stock_name = stock_info["name"] | ||||||
|                         try: |                         try: | ||||||
|                     from src.fundamentals_llm.pdf_generator import generate_investment_report |                             report_path = analyzer.generate_pdf_report(stock_code, stock_name) | ||||||
|                 except ImportError: |                             if report_path: | ||||||
|                     try: |  | ||||||
|                         from fundamentals_llm.pdf_generator import generate_investment_report |  | ||||||
|                     except ImportError as e: |  | ||||||
|                         logger.error(f"无法导入 PDF 生成器模块: {str(e)}") |  | ||||||
|                         return jsonify({ |  | ||||||
|                             "status": "error",  |  | ||||||
|                             "message": f"服务器配置错误: PDF 生成器模块不可用, 错误详情: {str(e)}" |  | ||||||
|                         }), 500 |  | ||||||
|                  |  | ||||||
|                 # 生成报告 |  | ||||||
|                 for stock_code, stock_name in stocks: |  | ||||||
|                     try: |  | ||||||
|                         # 调用 PDF 生成器 |  | ||||||
|                         report_path = generate_investment_report(stock_code, stock_name) |  | ||||||
|                                 pdf_results.append({ |                                 pdf_results.append({ | ||||||
|                                     "code": stock_code, |                                     "code": stock_code, | ||||||
|                                     "name": stock_name, |                                     "name": stock_name, | ||||||
|  | @ -676,6 +1089,14 @@ def comprehensive_analysis(): | ||||||
|                                     "status": "success" |                                     "status": "success" | ||||||
|                                 }) |                                 }) | ||||||
|                                 logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}") |                                 logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}") | ||||||
|  |                             else: | ||||||
|  |                                 pdf_results.append({ | ||||||
|  |                                     "code": stock_code, | ||||||
|  |                                     "name": stock_name, | ||||||
|  |                                     "status": "error", | ||||||
|  |                                     "error": "生成报告失败" | ||||||
|  |                                 }) | ||||||
|  |                                 logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败") | ||||||
|                         except Exception as e: |                         except Exception as e: | ||||||
|                             logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}") |                             logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}") | ||||||
|                             pdf_results.append({ |                             pdf_results.append({ | ||||||
|  | @ -784,11 +1205,24 @@ def comprehensive_analysis(): | ||||||
|                 logger.error(f"应用企业画像筛选失败: {str(e)}") |                 logger.error(f"应用企业画像筛选失败: {str(e)}") | ||||||
|                 filtered_stocks = [] |                 filtered_stocks = [] | ||||||
|          |          | ||||||
|  |         # 统计各种状态的股票数量 | ||||||
|  |         success_count = sum(1 for item in investment_advice_list if item["status"] == "success") | ||||||
|  |         pending_count = sum(1 for item in investment_advice_list if item["status"] == "pending") | ||||||
|  |         timeout_count = sum(1 for item in investment_advice_list if item["status"] == "timeout") | ||||||
|  |         error_count = sum(1 for item in investment_advice_list if item["status"] == "error") | ||||||
|  |          | ||||||
|         # 返回结果 |         # 返回结果 | ||||||
|         response = { |         response = { | ||||||
|             "status": "success", |             "status": "success" if success_count > 0 else "partial_success" if success_count + pending_count > 0 else "failed", | ||||||
|             "total_input_stocks": len(stocks), |             "total_input_stocks": len(stocks), | ||||||
|             "investment_advices": investment_advices |             "stats": { | ||||||
|  |                 "success": success_count, | ||||||
|  |                 "pending": pending_count, | ||||||
|  |                 "timeout": timeout_count, | ||||||
|  |                 "error": error_count | ||||||
|  |             }, | ||||||
|  |             "rounds_attempted": total_attempts, | ||||||
|  |             "investment_advices": investment_advice_list | ||||||
|         } |         } | ||||||
|          |          | ||||||
|         if profile_filter: |         if profile_filter: | ||||||
|  |  | ||||||
|  | @ -150,11 +150,15 @@ class ChatBot: | ||||||
|             logger.error(f"格式化参考资料时出错: {str(e)}") |             logger.error(f"格式化参考资料时出错: {str(e)}") | ||||||
|             return str(ref) |             return str(ref) | ||||||
|          |          | ||||||
|     def chat(self, user_input: str) -> Dict[str, Any]: |     def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]: | ||||||
|         """与AI进行对话 |         """与AI进行对话 | ||||||
|          |          | ||||||
|         Args: |         Args: | ||||||
|             user_input: 用户输入的问题 |             user_input: 用户输入的问题 | ||||||
|  |             temperature: 控制输出的随机性,范围0-2,默认1.0 | ||||||
|  |             top_p: 控制输出的多样性,范围0-1,默认0.7 | ||||||
|  |             max_tokens: 控制输出的最大长度,默认4096 | ||||||
|  |             frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 | ||||||
|              |              | ||||||
|         Returns: |         Returns: | ||||||
|             Dict[str, Any]: 包含以下字段的字典: |             Dict[str, Any]: 包含以下字段的字典: | ||||||
|  | @ -173,6 +177,10 @@ class ChatBot: | ||||||
|             stream = self.client.chat.completions.create( |             stream = self.client.chat.completions.create( | ||||||
|                 model=self.model, |                 model=self.model, | ||||||
|                 messages=self.conversation_history, |                 messages=self.conversation_history, | ||||||
|  |                 temperature=temperature, | ||||||
|  |                 top_p=top_p, | ||||||
|  |                 max_tokens=max_tokens, | ||||||
|  |                 frequency_penalty=frequency_penalty, | ||||||
|                 stream=True |                 stream=True | ||||||
|             ) |             ) | ||||||
|              |              | ||||||
|  |  | ||||||
|  | @ -0,0 +1,420 @@ | ||||||
|  | import logging | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | import json | ||||||
|  | from typing import Dict, List, Optional, Any, Union | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | from google import genai | ||||||
|  | from google.genai.types import Tool, GenerateContentConfig, GoogleSearch | ||||||
|  | 
 | ||||||
|  | # 设置日志记录 | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | # 定义基础数据结构 | ||||||
|  | class TextAnalysisResult(BaseModel): | ||||||
|  |     """文本分析结果,包含分析正文、推理过程和引用URL""" | ||||||
|  |     analysis_text: str = Field(description="详细的分析文本") | ||||||
|  |     reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None) | ||||||
|  |     references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None) | ||||||
|  | 
 | ||||||
|  | class NumericalAnalysisResult(BaseModel): | ||||||
|  |     """数值分析结果,包含数值和分析描述""" | ||||||
|  |     value: str = Field(description="评估值") | ||||||
|  |     description: str = Field(description="评估描述") | ||||||
|  | 
 | ||||||
|  | # 分析类型到模型类的映射 | ||||||
|  | ANALYSIS_TYPE_MODEL_MAP = { | ||||||
|  |     # 文本分析类型 | ||||||
|  |     "text_analysis_result": TextAnalysisResult, | ||||||
|  |      | ||||||
|  |     # 数值分析类型 | ||||||
|  |     "numerical_analysis_result": NumericalAnalysisResult | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | class JsonOutputParser: | ||||||
|  |     """解析JSON输出的工具类""" | ||||||
|  |     def __init__(self, pydantic_object): | ||||||
|  |         """初始化解析器 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             pydantic_object: 需要解析成的Pydantic模型类 | ||||||
|  |         """ | ||||||
|  |         self.pydantic_object = pydantic_object | ||||||
|  | 
 | ||||||
|  |     def get_format_instructions(self) -> str: | ||||||
|  |         """获取格式化指令""" | ||||||
|  |         schema = self.pydantic_object.schema() | ||||||
|  |         schema_str = json.dumps(schema, ensure_ascii=False, indent=2) | ||||||
|  |          | ||||||
|  |         return f"""请将你的回答格式化为符合以下JSON结构的格式: | ||||||
|  | {schema_str} | ||||||
|  | 
 | ||||||
|  | 你的回答应该只包含一个有效的JSON对象,而不包含其他内容。请不要包含任何解释或前缀,仅输出JSON本身。 | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  |     def parse(self, text: str) -> Any: | ||||||
|  |         """解析文本为对象""" | ||||||
|  |         try: | ||||||
|  |             # 尝试解析JSON | ||||||
|  |             text = text.strip() | ||||||
|  |             # 如果文本以```json开头且以```结尾,则提取中间部分 | ||||||
|  |             if text.startswith("```json") and text.endswith("```"): | ||||||
|  |                 text = text[7:-3].strip() | ||||||
|  |             # 如果文本以```开头且以```结尾,则提取中间部分 | ||||||
|  |             elif text.startswith("```") and text.endswith("```"): | ||||||
|  |                 text = text[3:-3].strip() | ||||||
|  |                  | ||||||
|  |             json_obj = json.loads(text) | ||||||
|  |             return self.pydantic_object.parse_obj(json_obj) | ||||||
|  |         except Exception as e: | ||||||
|  |             raise ValueError(f"无法解析为JSON或无法匹配模式: {e}") | ||||||
|  | 
 | ||||||
|  | # 获取项目根目录的绝对路径 | ||||||
|  | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||||
|  | 
 | ||||||
|  | # 尝试导入配置 | ||||||
|  | try: | ||||||
|  |     # 添加项目根目录到 Python 路径 | ||||||
|  |     sys.path.append(os.path.dirname(ROOT_DIR)) | ||||||
|  |     sys.path.append(ROOT_DIR) | ||||||
|  | 
 | ||||||
|  |     # 导入配置 | ||||||
|  |     try: | ||||||
|  |         from scripts.config import get_model_config | ||||||
|  | 
 | ||||||
|  |         logger.info("成功从scripts.config导入配置") | ||||||
|  |     except ImportError: | ||||||
|  |         try: | ||||||
|  |             from src.scripts.config import get_model_config | ||||||
|  | 
 | ||||||
|  |             logger.info("成功从src.scripts.config导入配置") | ||||||
|  |         except ImportError: | ||||||
|  |             logger.warning("无法导入配置模块,使用默认配置") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |             # 使用默认配置的实现 | ||||||
|  |             def get_model_config(platform: str, model_type: str): | ||||||
|  |                 return { | ||||||
|  |                     "api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU", | ||||||
|  |                     "model": "gemini-2.0-flash" | ||||||
|  |                 } | ||||||
|  | except Exception as e: | ||||||
|  |     logger.error(f"导入配置时出错: {str(e)},使用默认配置") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     # 使用默认配置的实现 | ||||||
|  |     def get_model_config(platform: str, model_type: str): | ||||||
|  |         return { | ||||||
|  |             "api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU", | ||||||
|  |             "model": "gemini-2.0-flash" | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  | class ChatBot: | ||||||
|  |     def __init__(self, platform: str = "Gemini", model_type: str = "gemini-2.0-flash"): | ||||||
|  |         """初始化聊天机器人 | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             platform: 平台名称,默认为Gemini | ||||||
|  |             model_type: 要使用的模型类型,默认为gemini-2.0-flash | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 从配置获取API配置 | ||||||
|  |             config = get_model_config(platform, model_type) | ||||||
|  |             self.api_key = config["api_key"] | ||||||
|  |             self.model = config["model"]  # 直接使用配置中的模型名称 | ||||||
|  | 
 | ||||||
|  |             logger.info(f"初始化ChatBot,使用平台: {platform}, 模型: {self.model}") | ||||||
|  | 
 | ||||||
|  |             google_search_tool = Tool( | ||||||
|  |                 google_search=GoogleSearch() | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             self.client = genai.Client(api_key=self.api_key) | ||||||
|  |              | ||||||
|  |             # 系统提示语 | ||||||
|  |             self.system_instruction = """你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下: | ||||||
|  | 
 | ||||||
|  |                             1. 分析风格: | ||||||
|  |                             - 专业、客观、理性 | ||||||
|  |                             - 注重数据支撑 | ||||||
|  |                             - 关注风险控制 | ||||||
|  |                             - 重视投资性价比 | ||||||
|  | 
 | ||||||
|  |                             2. 分析框架: | ||||||
|  |                             - 公司基本面分析 | ||||||
|  |                             - 行业竞争格局 | ||||||
|  |                             - 估值水平评估 | ||||||
|  |                             - 风险因素识别 | ||||||
|  |                             - 投资机会判断 | ||||||
|  | 
 | ||||||
|  |                             3. 输出要求: | ||||||
|  |                             - 简明扼要,重点突出 | ||||||
|  |                             - 逻辑清晰,层次分明 | ||||||
|  |                             - 数据准确,论据充分 | ||||||
|  |                             - 结论明确,建议具体 | ||||||
|  | 
 | ||||||
|  |                             请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。""" | ||||||
|  |              | ||||||
|  |             # 对话历史 | ||||||
|  |             self.conversation_history = [] | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"初始化ChatBot时出错: {str(e)}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def chat(self, user_input: str, analysis_type: Optional[str] = None, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096) -> dict: | ||||||
|  |         """处理用户输入并返回AI回复 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             user_input: 用户输入的问题 | ||||||
|  |             analysis_type: 分析类型,如company_profile或business_scale等 | ||||||
|  |             temperature: 控制输出的随机性,范围0-2,默认0.7 | ||||||
|  |             top_p: 控制输出的多样性,范围0-1,默认0.7 | ||||||
|  |             max_tokens: 控制输出的最大长度,默认4096 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             dict: 包含AI回复的字典 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             from google.genai import types | ||||||
|  |             google_search_tool = Tool( | ||||||
|  |                 google_search=GoogleSearch() | ||||||
|  |             ) | ||||||
|  |              | ||||||
|  |             # 确定分析类型和模型 | ||||||
|  |             model_class = None | ||||||
|  |             if analysis_type and analysis_type in ANALYSIS_TYPE_MODEL_MAP: | ||||||
|  |                 model_class = ANALYSIS_TYPE_MODEL_MAP[analysis_type] | ||||||
|  |                 parser = JsonOutputParser(pydantic_object=model_class) | ||||||
|  |                 format_instructions = parser.get_format_instructions() | ||||||
|  |                  | ||||||
|  |                 # 在原有基础上添加格式指令 | ||||||
|  |                 combined_instruction = f"{self.system_instruction}\n\n{format_instructions}" | ||||||
|  |             else: | ||||||
|  |                 combined_instruction = self.system_instruction | ||||||
|  |                 parser = None | ||||||
|  |              | ||||||
|  |             response = self.client.models.generate_content_stream( | ||||||
|  |                 model=self.model, | ||||||
|  |                 config=GenerateContentConfig( | ||||||
|  |                     system_instruction=combined_instruction, | ||||||
|  |                     tools=[google_search_tool], | ||||||
|  |                     response_modalities=["TEXT"], | ||||||
|  |                     temperature=temperature, | ||||||
|  |                     top_p=top_p, | ||||||
|  |                     max_output_tokens=max_tokens | ||||||
|  |                 ), | ||||||
|  |                 contents=user_input | ||||||
|  |             ) | ||||||
|  |              | ||||||
|  |             full_response = "" | ||||||
|  |             for chunk in response: | ||||||
|  |                 full_response += chunk.text | ||||||
|  |                 print(chunk.text, end="") | ||||||
|  |              | ||||||
|  |             # 如果指定了分析类型,尝试解析输出 | ||||||
|  |             if parser: | ||||||
|  |                 try: | ||||||
|  |                     # 检查分析类型是文本分析还是数值分析 | ||||||
|  |                     is_text_analysis = model_class == TextAnalysisResult | ||||||
|  |                     is_numerical_analysis = model_class == NumericalAnalysisResult | ||||||
|  |                      | ||||||
|  |                     # 使用解析方法尝试处理JSON响应 | ||||||
|  |                     json_result = self.parse_json_response(full_response, model_class) | ||||||
|  |                     if json_result: | ||||||
|  |                         return json_result | ||||||
|  |                      | ||||||
|  |                     # 如果使用解析方法失败,尝试使用parser解析 | ||||||
|  |                     parsed_result = parser.parse(full_response) | ||||||
|  |                     structured_result = parsed_result.dict() | ||||||
|  |                      | ||||||
|  |                     if isinstance(parsed_result, TextAnalysisResult): | ||||||
|  |                         # 文本分析结果 | ||||||
|  |                         return { | ||||||
|  |                             "response": structured_result["analysis_text"], | ||||||
|  |                             "reasoning_process": structured_result.get("reasoning_process"), | ||||||
|  |                             "references": structured_result.get("references") | ||||||
|  |                         } | ||||||
|  |                     elif isinstance(parsed_result, NumericalAnalysisResult): | ||||||
|  |                         # 数值分析结果 | ||||||
|  |                         return { | ||||||
|  |                             "response": structured_result["value"], | ||||||
|  |                             "description": structured_result.get("description", "") | ||||||
|  |                         } | ||||||
|  |                 except Exception as e: | ||||||
|  |                     logger.warning(f"解析响应失败: {str(e)},返回原始响应") | ||||||
|  |                      | ||||||
|  |                     # 最后尝试根据分析类型直接构建结果 | ||||||
|  |                     if is_text_analysis: | ||||||
|  |                         return { | ||||||
|  |                             "response": full_response, | ||||||
|  |                             "reasoning_process": None, | ||||||
|  |                             "references": None | ||||||
|  |                         } | ||||||
|  |                     elif is_numerical_analysis: | ||||||
|  |                         try: | ||||||
|  |                             # 尝试从文本中提取数字 | ||||||
|  |                             import re | ||||||
|  |                             number_match = re.search(r'\b(\d+)\b', full_response) | ||||||
|  |                             value = int(number_match.group(1)) if number_match else 0 | ||||||
|  |                             return { | ||||||
|  |                                 "response": value, | ||||||
|  |                                 "description": full_response | ||||||
|  |                             } | ||||||
|  |                         except: | ||||||
|  |                             return { | ||||||
|  |                                 "response": 0, | ||||||
|  |                                 "description": full_response | ||||||
|  |                             } | ||||||
|  |                     else: | ||||||
|  |                         return { | ||||||
|  |                             "response": full_response, | ||||||
|  |                             "reasoning_process": None, | ||||||
|  |                             "references": None | ||||||
|  |                         } | ||||||
|  |             else: | ||||||
|  |                 # 普通响应 | ||||||
|  |                 return { | ||||||
|  |                     "response": full_response, | ||||||
|  |                     "reasoning_process": None, | ||||||
|  |                     "references": None | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"聊天失败: {str(e)}") | ||||||
|  |             error_msg = f"抱歉,发生错误: {str(e)}" | ||||||
|  |             print(f"\n{error_msg}") | ||||||
|  |             return {"response": error_msg, "reasoning_process": None, "references": None} | ||||||
|  | 
 | ||||||
|  |     def clear_history(self): | ||||||
|  |         """清除对话历史""" | ||||||
|  |         self.conversation_history = [] | ||||||
|  |         # 重置会话 | ||||||
|  |         if hasattr(self, 'model_instance'): | ||||||
|  |             self.chat_session = self.model_instance.start_chat() | ||||||
|  |         print("对话历史已清除") | ||||||
|  | 
 | ||||||
|  |     def run(self): | ||||||
|  |         """运行聊天机器人""" | ||||||
|  |         print("欢迎使用Gemini AI助手!输入 'quit' 退出,输入 'clear' 清除对话历史。") | ||||||
|  |         print("-" * 50) | ||||||
|  | 
 | ||||||
|  |         while True: | ||||||
|  |             try: | ||||||
|  |                 # 获取用户输入 | ||||||
|  |                 user_input = input("\n你: ").strip() | ||||||
|  | 
 | ||||||
|  |                 # 检查退出命令 | ||||||
|  |                 if user_input.lower() == 'quit': | ||||||
|  |                     print("感谢使用,再见!") | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|  |                 # 检查清除历史命令 | ||||||
|  |                 if user_input.lower() == 'clear': | ||||||
|  |                     self.clear_history() | ||||||
|  |                     continue | ||||||
|  | 
 | ||||||
|  |                 # 如果输入为空,继续下一轮 | ||||||
|  |                 if not user_input: | ||||||
|  |                     continue | ||||||
|  | 
 | ||||||
|  |                 # 获取AI回复 | ||||||
|  |                 self.chat(user_input) | ||||||
|  |                 print("-" * 50) | ||||||
|  | 
 | ||||||
|  |             except KeyboardInterrupt: | ||||||
|  |                 print("\n感谢使用,再见!") | ||||||
|  |                 break | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"运行错误: {str(e)}") | ||||||
|  |                 print(f"发生错误: {str(e)}") | ||||||
|  | 
 | ||||||
|  |     def parse_json_response(self, full_response: str, model_class) -> dict: | ||||||
|  |         """解析JSON响应,处理不同的JSON结构 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             full_response: 完整的响应文本 | ||||||
|  |             model_class: 模型类(TextAnalysisResult或NumericalAnalysisResult) | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             dict: 包含解析结果的字典 | ||||||
|  |         """ | ||||||
|  |         # 检查分析类型 | ||||||
|  |         is_text_analysis = model_class == TextAnalysisResult | ||||||
|  |         is_numerical_analysis = model_class == NumericalAnalysisResult | ||||||
|  |          | ||||||
|  |         try: | ||||||
|  |             # 尝试提取JSON内容 | ||||||
|  |             json_text = full_response | ||||||
|  |             # 如果文本以```json开头且以```结尾,则提取中间部分 | ||||||
|  |             if "```json" in json_text and "```" in json_text.split("```json", 1)[1]: | ||||||
|  |                 json_text = json_text.split("```json", 1)[1].split("```", 1)[0].strip() | ||||||
|  |             # 如果文本以```开头且以```结尾,则提取中间部分 | ||||||
|  |             elif "```" in json_text and "```" in json_text.split("```", 1)[1]: | ||||||
|  |                 json_text = json_text.split("```", 1)[1].split("```", 1)[0].strip() | ||||||
|  |              | ||||||
|  |             # 尝试解析JSON | ||||||
|  |             json_obj = json.loads(json_text) | ||||||
|  |              | ||||||
|  |             # 情况1: 直接包含字段的格式 {"analysis_text": "...", "reasoning_process": null, "references": [...]} | ||||||
|  |             if is_text_analysis and "analysis_text" in json_obj: | ||||||
|  |                 return { | ||||||
|  |                     "response": json_obj["analysis_text"], | ||||||
|  |                     "reasoning_process": json_obj.get("reasoning_process"), | ||||||
|  |                     "references": json_obj.get("references", []) | ||||||
|  |                 } | ||||||
|  |             elif is_numerical_analysis and "value" in json_obj: | ||||||
|  |                 value = json_obj["value"] | ||||||
|  |                 if isinstance(value, str) and value.isdigit(): | ||||||
|  |                     value = int(value) | ||||||
|  |                 elif isinstance(value, (int, float)): | ||||||
|  |                     value = int(value) | ||||||
|  |                 else: | ||||||
|  |                     value = 0 | ||||||
|  |                      | ||||||
|  |                 return { | ||||||
|  |                     "response": value, | ||||||
|  |                     "description": json_obj.get("description", "") | ||||||
|  |                 } | ||||||
|  |              | ||||||
|  |             # 情况2: properties中包含字段的格式 | ||||||
|  |             if "properties" in json_obj and isinstance(json_obj["properties"], dict): | ||||||
|  |                 properties = json_obj["properties"] | ||||||
|  |                  | ||||||
|  |                 if is_text_analysis and "analysis_text" in properties: | ||||||
|  |                     return { | ||||||
|  |                         "response": properties["analysis_text"], | ||||||
|  |                         "reasoning_process": properties.get("reasoning_process"), | ||||||
|  |                         "references": properties.get("references", []) | ||||||
|  |                     } | ||||||
|  |                 elif is_numerical_analysis and "value" in properties: | ||||||
|  |                     value = properties["value"] | ||||||
|  |                     if isinstance(value, str) and value.isdigit(): | ||||||
|  |                         value = int(value) | ||||||
|  |                     elif isinstance(value, (int, float)): | ||||||
|  |                         value = int(value) | ||||||
|  |                     else: | ||||||
|  |                         value = 0 | ||||||
|  |                          | ||||||
|  |                     return { | ||||||
|  |                         "response": value, | ||||||
|  |                         "description": properties.get("description", "") | ||||||
|  |                     } | ||||||
|  |              | ||||||
|  |             # 如果无法识别结构,使用parser解析 | ||||||
|  |             raise ValueError("无法识别的JSON结构,尝试使用parser解析") | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             # 记录错误但不抛出,让调用方法继续尝试其他解析方式 | ||||||
|  |             logger.warning(f"JSON解析失败: {str(e)}") | ||||||
|  |             return None | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # 设置日志级别 | ||||||
|  |     logging.basicConfig( | ||||||
|  |         level=logging.INFO, | ||||||
|  |         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # 创建并运行聊天机器人 | ||||||
|  |     bot = ChatBot() | ||||||
|  |     bot.run()  | ||||||
|  | @ -3,7 +3,7 @@ from openai import OpenAI | ||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
| import time | import time | ||||||
| import random | 
 | ||||||
| 
 | 
 | ||||||
| # 设置日志记录 | # 设置日志记录 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  | @ -102,8 +102,19 @@ class ChatBot: | ||||||
|             logger.error(f"初始化ChatBot时出错: {str(e)}") |             logger.error(f"初始化ChatBot时出错: {str(e)}") | ||||||
|             raise |             raise | ||||||
| 
 | 
 | ||||||
|     def chat(self, user_input: str) -> str: |     def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> str: | ||||||
|         """处理用户输入并返回AI回复""" |         """处理用户输入并返回AI回复 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             user_input: 用户输入的问题 | ||||||
|  |             temperature: 控制输出的随机性,范围0-2,默认1.0 | ||||||
|  |             top_p: 控制输出的多样性,范围0-1,默认0.7 | ||||||
|  |             max_tokens: 控制输出的最大长度,默认4096 | ||||||
|  |             frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             str: AI的回答内容 | ||||||
|  |         """ | ||||||
|         try: |         try: | ||||||
|             # # 添加用户消息到对话历史-多轮 |             # # 添加用户消息到对话历史-多轮 | ||||||
|             self.conversation_history.append({ |             self.conversation_history.append({ | ||||||
|  | @ -116,7 +127,10 @@ class ChatBot: | ||||||
|             stream = self.client.chat.completions.create( |             stream = self.client.chat.completions.create( | ||||||
|                 model=self.model, |                 model=self.model, | ||||||
|                 messages=self.conversation_history, |                 messages=self.conversation_history, | ||||||
|                 temperature=0, |                 temperature=temperature, | ||||||
|  |                 top_p=top_p, | ||||||
|  |                 max_tokens=max_tokens, | ||||||
|  |                 frequency_penalty=frequency_penalty, | ||||||
|                 stream=True, |                 stream=True, | ||||||
|                 timeout=600 |                 timeout=600 | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  | @ -130,9 +130,15 @@ class EnterpriseScreener: | ||||||
|             { |             { | ||||||
|                 'dimension': 'investment_advice', |                 'dimension': 'investment_advice', | ||||||
|                 'field': 'investment_advice_type', |                 'field': 'investment_advice_type', | ||||||
|                 'operator': '!=', |                 'operator': 'in', | ||||||
|                 'value': '不建议' |                 'value': ["'短期'","'中期'","'长期'"] | ||||||
|             } |             }, | ||||||
|  |             { | ||||||
|  |                 'dimension': 'financial_report', | ||||||
|  |                 'field': 'financial_report_level', | ||||||
|  |                 'operator': '>=', | ||||||
|  |                 'value': -1 | ||||||
|  |             }, | ||||||
|         ] |         ] | ||||||
|         return self._screen_stocks_by_conditions(conditions, limit) |         return self._screen_stocks_by_conditions(conditions, limit) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,7 +1,11 @@ | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| from datetime import datetime | import sys | ||||||
| from typing import Dict, List, Optional, Tuple, Callable | from datetime import datetime, timedelta | ||||||
|  | import time | ||||||
|  | import redis | ||||||
|  | from typing import Dict, List, Optional, Tuple, Callable, Any | ||||||
|  | 
 | ||||||
| # 修改导入路径,使用相对导入 | # 修改导入路径,使用相对导入 | ||||||
| try: | try: | ||||||
|     # 尝试相对导入 |     # 尝试相对导入 | ||||||
|  | @ -9,6 +13,7 @@ try: | ||||||
|     from .chat_bot_with_offline import ChatBot as OfflineChatBot |     from .chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
|     from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result |     from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result | ||||||
|     from .pdf_generator import PDFGenerator |     from .pdf_generator import PDFGenerator | ||||||
|  |     from .text_processor import TextProcessor | ||||||
| except ImportError: | except ImportError: | ||||||
|     # 如果相对导入失败,尝试尝试绝对导入 |     # 如果相对导入失败,尝试尝试绝对导入 | ||||||
|     try: |     try: | ||||||
|  | @ -16,18 +21,17 @@ except ImportError: | ||||||
|         from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot |         from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
|         from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result |         from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result | ||||||
|         from src.fundamentals_llm.pdf_generator import PDFGenerator |         from src.fundamentals_llm.pdf_generator import PDFGenerator | ||||||
|  |         from src.fundamentals_llm.text_processor import TextProcessor | ||||||
|     except ImportError: |     except ImportError: | ||||||
|         # 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况) |         # 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况) | ||||||
|         from chat_bot import ChatBot |         from chat_bot import ChatBot | ||||||
|         from chat_bot_with_offline import ChatBot as OfflineChatBot |         from chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
|         from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result |         from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result | ||||||
|         from pdf_generator import PDFGenerator |         from pdf_generator import PDFGenerator | ||||||
|  |         from text_processor import TextProcessor | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
| 
 | 
 | ||||||
| # 设置日志记录 |  | ||||||
| logger = logging.getLogger(__name__) |  | ||||||
| 
 |  | ||||||
| # 获取项目根目录的绝对路径 | # 获取项目根目录的绝对路径 | ||||||
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||||
| 
 | 
 | ||||||
|  | @ -58,6 +62,34 @@ logging.basicConfig( | ||||||
|     datefmt=date_format |     datefmt=date_format | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | logger.info("测试日志输出 - 程序启动") | ||||||
|  | 
 | ||||||
|  | from typing import Dict, List, Optional, Any, Union | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | 
 | ||||||
|  | # 定义基础数据结构 | ||||||
|  | class TextAnalysisResult(BaseModel): | ||||||
|  |     """文本分析结果,包含分析正文、推理过程和引用URL""" | ||||||
|  |     analysis_text: str = Field(description="详细的分析文本") | ||||||
|  |     reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None) | ||||||
|  |     references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None) | ||||||
|  | 
 | ||||||
|  | class NumericalAnalysisResult(BaseModel): | ||||||
|  |     """数值分析结果,包含数值和分析描述""" | ||||||
|  |     value: str = Field(description="评估值") | ||||||
|  |     description: str = Field(description="评估描述") | ||||||
|  | 
 | ||||||
|  | # 添加Redis客户端 | ||||||
|  | redis_client = redis.Redis( | ||||||
|  |     host='192.168.18.208',  # Redis服务器地址,根据实际情况调整 | ||||||
|  |     port=6379, | ||||||
|  |     password='wlkj2018', | ||||||
|  |     db=14, | ||||||
|  |     socket_timeout=5, | ||||||
|  |     decode_responses=True | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| class FundamentalAnalyzer: | class FundamentalAnalyzer: | ||||||
|     """基本面分析器""" |     """基本面分析器""" | ||||||
|      |      | ||||||
|  | @ -66,9 +98,10 @@ class FundamentalAnalyzer: | ||||||
|         # 使用联网模型进行基本面分析 |         # 使用联网模型进行基本面分析 | ||||||
|         self.chat_bot = ChatBot(model_type="online_bot") |         self.chat_bot = ChatBot(model_type="online_bot") | ||||||
|         # 使用离线模型进行其他分析 |         # 使用离线模型进行其他分析 | ||||||
|         self.offline_bot = OfflineChatBot(platform="tl_private", model_type="ds-v1") |         self.offline_bot = OfflineChatBot(platform="volc", model_type="offline_model") | ||||||
|         # 千问打杂 |         # 千问打杂 | ||||||
|         self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") |         # self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq") | ||||||
|  |         self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="GLM") | ||||||
|         self.db = next(get_db()) |         self.db = next(get_db()) | ||||||
|          |          | ||||||
|         # 定义维度映射 |         # 定义维度映射 | ||||||
|  | @ -161,6 +194,8 @@ class FundamentalAnalyzer: | ||||||
|             - 关键战略决策和转型 |             - 关键战略决策和转型 | ||||||
| 
 | 
 | ||||||
|             请提供专业、客观的分析,突出关键信息,避免冗长描述。""" |             请提供专业、客观的分析,突出关键信息,避免冗长描述。""" | ||||||
|  |             #开头清上下文缓存 | ||||||
|  |             self.chat_bot.clear_history() | ||||||
|             # 获取AI分析结果 |             # 获取AI分析结果 | ||||||
|             result = self.chat_bot.chat(prompt) |             result = self.chat_bot.chat(prompt) | ||||||
|              |              | ||||||
|  | @ -554,8 +589,8 @@ class FundamentalAnalyzer: | ||||||
|                     请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。""" |                     请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。""" | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             # 使用离线模型进行分析 |             # 使用离线模型进行分析 | ||||||
|             space_value_str = self.offline_bot_tl_qw.chat(prompt) |             space_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|             space_value_str = self._clean_model_output(space_value_str) |             space_value_str = TextProcessor.clean_thought_process(space_value_str) | ||||||
|             # 提取数值 |             # 提取数值 | ||||||
|             space_value = 0  # 默认值 |             space_value = 0  # 默认值 | ||||||
|              |              | ||||||
|  | @ -658,9 +693,9 @@ class FundamentalAnalyzer: | ||||||
|                     请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" |                     请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             # 使用离线模型进行分析 |             # 使用离线模型进行分析 | ||||||
|             events_value_str = self.offline_bot_tl_qw.chat(prompt) |             events_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|             # 数据清洗 |             # 数据清洗 | ||||||
|             events_value_str = self._clean_model_output(events_value_str) |             events_value_str = TextProcessor.clean_thought_process(events_value_str) | ||||||
|             # 提取数值 |             # 提取数值 | ||||||
|             events_value = 0  # 默认值 |             events_value = 0  # 默认值 | ||||||
|              |              | ||||||
|  | @ -700,14 +735,14 @@ class FundamentalAnalyzer: | ||||||
|     def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool: |     def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool: | ||||||
|         """分析股吧讨论内容""" |         """分析股吧讨论内容""" | ||||||
|         try: |         try: | ||||||
|             prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出: |             prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在400字以内(主要讨论话题200字,重要信息汇总200字),请严格按照以下格式输出: | ||||||
| 
 | 
 | ||||||
|                     1. 主要讨论话题(150字左右): |                     1. 主要讨论话题: | ||||||
|                     - 近期热点事件 |                     - 近期热点事件 | ||||||
|                     - 投资者关注焦点 |                     - 投资者关注焦点 | ||||||
|                     - 市场情绪倾向 |                     - 市场情绪倾向 | ||||||
|                      |                      | ||||||
|                     2. 重要信息汇总(150字左右): |                     2. 重要信息汇总: | ||||||
|                     - 公司相关动态 |                     - 公司相关动态 | ||||||
|                     - 行业政策变化 |                     - 行业政策变化 | ||||||
|                     - 市场预期变化 |                     - 市场预期变化 | ||||||
|  | @ -762,10 +797,10 @@ class FundamentalAnalyzer: | ||||||
|                     请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" |                     请仅返回一个数值:1、0或-1,不要包含任何解释或说明。""" | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             # 使用离线模型进行分析 |             # 使用离线模型进行分析 | ||||||
|             emotion_value_str = self.offline_bot_tl_qw.chat(prompt) |             emotion_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
| 
 | 
 | ||||||
|             # 数据清洗 |             # 数据清洗 | ||||||
|             emotion_value_str = self._clean_model_output(emotion_value_str) |             emotion_value_str = TextProcessor.clean_thought_process(emotion_value_str) | ||||||
|             # 提取数值 |             # 提取数值 | ||||||
|             emotion_value = 0  # 默认值 |             emotion_value = 0  # 默认值 | ||||||
|              |              | ||||||
|  | @ -1015,10 +1050,10 @@ class FundamentalAnalyzer: | ||||||
|                      |                      | ||||||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:2,1,0,-1或-2。""" |                     只需要输出一个数值,不要输出任何说明或解释。只输出:2,1,0,-1或-2。""" | ||||||
|              |              | ||||||
|             response = self.chat_bot.chat(prompt) |             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|              |              | ||||||
|             # 提取数值 |             # 提取数值 | ||||||
|             rating_str = self._extract_numeric_value_from_response(response) |             rating_str = TextProcessor.extract_numeric_value_from_response(response) | ||||||
|              |              | ||||||
|             # 尝试将响应转换为整数 |             # 尝试将响应转换为整数 | ||||||
|             try: |             try: | ||||||
|  | @ -1057,10 +1092,10 @@ class FundamentalAnalyzer: | ||||||
|                          |                          | ||||||
|                         只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" |                         只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。""" | ||||||
|              |              | ||||||
|             response = self.chat_bot.chat(prompt) |             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|              |              | ||||||
|             # 提取数值 |             # 提取数值 | ||||||
|             odds_str = self._extract_numeric_value_from_response(response) |             odds_str = TextProcessor.extract_numeric_value_from_response(response) | ||||||
|              |              | ||||||
|             # 尝试将响应转换为整数 |             # 尝试将响应转换为整数 | ||||||
|             try: |             try: | ||||||
|  | @ -1200,8 +1235,8 @@ class FundamentalAnalyzer: | ||||||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" |                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||||
|              |              | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             response = self.offline_bot_tl_qw.chat(prompt) |             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|             pe_hist_str = self._clean_model_output(response) |             pe_hist_str =  TextProcessor.clean_thought_process(response) | ||||||
|              |              | ||||||
|             try: |             try: | ||||||
|                 pe_hist = int(pe_hist_str) |                 pe_hist = int(pe_hist_str) | ||||||
|  | @ -1240,8 +1275,8 @@ class FundamentalAnalyzer: | ||||||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" |                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||||
|              |              | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             response = self.offline_bot_tl_qw.chat(prompt) |             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|             pb_hist_str = self._clean_model_output(response) |             pb_hist_str = TextProcessor.clean_thought_process(response) | ||||||
|              |              | ||||||
|             try: |             try: | ||||||
|                 pb_hist = int(pb_hist_str) |                 pb_hist = int(pb_hist_str) | ||||||
|  | @ -1280,8 +1315,8 @@ class FundamentalAnalyzer: | ||||||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" |                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||||
|              |              | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             response = self.offline_bot_tl_qw.chat(prompt) |             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|             pe_ind_str = self._clean_model_output(response) |             pe_ind_str = TextProcessor.clean_thought_process(response) | ||||||
|              |              | ||||||
|             try: |             try: | ||||||
|                 pe_ind = int(pe_ind_str) |                 pe_ind = int(pe_ind_str) | ||||||
|  | @ -1320,8 +1355,8 @@ class FundamentalAnalyzer: | ||||||
|                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" |                     只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。""" | ||||||
|              |              | ||||||
|             self.offline_bot_tl_qw.clear_history() |             self.offline_bot_tl_qw.clear_history() | ||||||
|             response = self.offline_bot_tl_qw.chat(prompt) |             response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|             pb_ind_str = self._clean_model_output(response) |             pb_ind_str = TextProcessor.clean_thought_process(response) | ||||||
|              |              | ||||||
|             try: |             try: | ||||||
|                 pb_ind = int(pb_ind_str) |                 pb_ind = int(pb_ind_str) | ||||||
|  | @ -1373,9 +1408,9 @@ class FundamentalAnalyzer: | ||||||
|             {json.dumps(all_results, ensure_ascii=False, indent=2)}""" |             {json.dumps(all_results, ensure_ascii=False, indent=2)}""" | ||||||
|             self.offline_bot.clear_history() |             self.offline_bot.clear_history() | ||||||
|             # 使用离线模型生成建议 |             # 使用离线模型生成建议 | ||||||
|             result = self.offline_bot.chat(prompt) |             result = self.offline_bot.chat(prompt,max_tokens=20000) | ||||||
|             # 清理模型输出 |             # 清理模型输出 | ||||||
|             result = self._clean_model_output(result) |             result = TextProcessor.clean_thought_process(result) | ||||||
|             # 保存到数据库 |             # 保存到数据库 | ||||||
|             success = save_analysis_result( |             success = save_analysis_result( | ||||||
|                 self.db, |                 self.db, | ||||||
|  | @ -1441,93 +1476,6 @@ class FundamentalAnalyzer: | ||||||
|             logger.error(f"提取投资建议类型失败: {str(e)}") |             logger.error(f"提取投资建议类型失败: {str(e)}") | ||||||
|             return None |             return None | ||||||
|          |          | ||||||
|     def _clean_model_output(self, output: str) -> str: |  | ||||||
|         """清理模型输出,移除推理过程,只保留最终结果 |  | ||||||
|          |  | ||||||
|         Args: |  | ||||||
|             output: 模型原始输出文本 |  | ||||||
|              |  | ||||||
|         Returns: |  | ||||||
|             str: 清理后的输出文本 |  | ||||||
|         """ |  | ||||||
|         try: |  | ||||||
|             # 找到</think>标签的位置 |  | ||||||
|             think_end = output.find('</think>') |  | ||||||
|             if think_end != -1: |  | ||||||
|                 # 移除</think>标签及其之前的所有内容 |  | ||||||
|                 output = output[think_end + len('</think>'):] |  | ||||||
|              |  | ||||||
|             # 处理可能存在的空行 |  | ||||||
|             lines = output.split('\n') |  | ||||||
|             cleaned_lines = [] |  | ||||||
|             for line in lines: |  | ||||||
|                 line = line.strip() |  | ||||||
|                 if line:  # 只保留非空行 |  | ||||||
|                     cleaned_lines.append(line) |  | ||||||
|              |  | ||||||
|             # 重新组合文本 |  | ||||||
|             output = '\n'.join(cleaned_lines) |  | ||||||
|              |  | ||||||
|             return output.strip() |  | ||||||
|              |  | ||||||
|         except Exception as e: |  | ||||||
|             logger.error(f"清理模型输出失败: {str(e)}") |  | ||||||
|             return output.strip() |  | ||||||
| 
 |  | ||||||
|     def _extract_numeric_value_from_response(self, response: str) -> str: |  | ||||||
|         """从模型响应中提取数值,移除参考资料和推理过程 |  | ||||||
|          |  | ||||||
|         Args: |  | ||||||
|             response: 模型原始响应文本或响应对象 |  | ||||||
|              |  | ||||||
|         Returns: |  | ||||||
|             str: 提取的数值字符串 |  | ||||||
|         """ |  | ||||||
|         try: |  | ||||||
|             # 处理响应对象(包含response字段的字典) |  | ||||||
|             if isinstance(response, dict) and "response" in response: |  | ||||||
|                 response = response["response"] |  | ||||||
|                  |  | ||||||
|             # 确保响应是字符串 |  | ||||||
|             if not isinstance(response, str): |  | ||||||
|                 logger.warning(f"响应不是字符串类型: {type(response)}") |  | ||||||
|                 return "0" |  | ||||||
|                  |  | ||||||
|             # 移除推理过程部分 |  | ||||||
|             reasoning_start = response.find("推理过程:") |  | ||||||
|             if reasoning_start != -1: |  | ||||||
|                 response = response[:reasoning_start].strip() |  | ||||||
|                  |  | ||||||
|             # 移除参考资料部分(通常以 [数字] 开头的行) |  | ||||||
|             lines = response.split("\n") |  | ||||||
|             cleaned_lines = [] |  | ||||||
|              |  | ||||||
|             for line in lines: |  | ||||||
|                 # 跳过参考资料行(通常以 [数字] 开头) |  | ||||||
|                 if re.match(r'\[\d+\]', line.strip()): |  | ||||||
|                     continue |  | ||||||
|                 cleaned_lines.append(line) |  | ||||||
|                  |  | ||||||
|             response = "\n".join(cleaned_lines).strip() |  | ||||||
|              |  | ||||||
|             # 提取数值 |  | ||||||
|             # 先尝试直接将整个响应转换为数值 |  | ||||||
|             if response.strip() in ["-2", "-1", "0", "1", "2"]: |  | ||||||
|                 return response.strip() |  | ||||||
|                  |  | ||||||
|             # 如果整个响应不是数值,尝试匹配第一个数值 |  | ||||||
|             match = re.search(r'([-]?[0-9])', response) |  | ||||||
|             if match: |  | ||||||
|                 return match.group(1) |  | ||||||
|                  |  | ||||||
|             # 如果没有找到数值,返回默认值 |  | ||||||
|             logger.warning(f"未能从响应中提取数值: {response}") |  | ||||||
|             return "0" |  | ||||||
|              |  | ||||||
|         except Exception as e: |  | ||||||
|             logger.error(f"从响应中提取数值失败: {str(e)}") |  | ||||||
|             return "0" |  | ||||||
| 
 |  | ||||||
|     def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]: |     def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]: | ||||||
|         """尝试多次从投资建议中提取建议类型 |         """尝试多次从投资建议中提取建议类型 | ||||||
|          |          | ||||||
|  | @ -1549,7 +1497,7 @@ class FundamentalAnalyzer: | ||||||
|                  |                  | ||||||
|                 # 使用千问离线模型提取建议类型 |                 # 使用千问离线模型提取建议类型 | ||||||
| 
 | 
 | ||||||
|                 result = self.offline_bot_tl_qw.chat(prompt) |                 result = self.offline_bot_tl_qw.chat(prompt,temperature=0.0) | ||||||
|                  |                  | ||||||
|                 # 检查是否是错误响应 |                 # 检查是否是错误响应 | ||||||
|                 if isinstance(result, str) and "抱歉,发生错误" in result: |                 if isinstance(result, str) and "抱歉,发生错误" in result: | ||||||
|  | @ -1557,7 +1505,7 @@ class FundamentalAnalyzer: | ||||||
|                     continue |                     continue | ||||||
|                  |                  | ||||||
|                 # 清理模型输出 |                 # 清理模型输出 | ||||||
|                 cleaned_result = self._clean_model_output(result) |                 cleaned_result = TextProcessor.clean_thought_process(result) | ||||||
|                  |                  | ||||||
|                 # 检查结果是否为有效类型 |                 # 检查结果是否为有效类型 | ||||||
|                 if cleaned_result in valid_types: |                 if cleaned_result in valid_types: | ||||||
|  | @ -1644,6 +1592,24 @@ class FundamentalAnalyzer: | ||||||
|             Optional[str]: 生成的PDF文件路径,如果失败则返回None |             Optional[str]: 生成的PDF文件路径,如果失败则返回None | ||||||
|         """ |         """ | ||||||
|         try: |         try: | ||||||
|  |             # 检查是否已存在PDF文件 | ||||||
|  |             reports_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'reports') | ||||||
|  |             os.makedirs(reports_dir, exist_ok=True) | ||||||
|  |              | ||||||
|  |             # 构建可能的文件名格式 | ||||||
|  |             possible_filenames = [ | ||||||
|  |                 f"{stock_name}_{stock_code}_analysis.pdf", | ||||||
|  |                 f"{stock_name}_{stock_code}.SZ_analysis.pdf", | ||||||
|  |                 f"{stock_name}_{stock_code}.SH_analysis.pdf" | ||||||
|  |             ] | ||||||
|  |              | ||||||
|  |             # 检查是否存在已生成的PDF文件 | ||||||
|  |             for filename in possible_filenames: | ||||||
|  |                 filepath = os.path.join(reports_dir, filename) | ||||||
|  |                 if os.path.exists(filepath): | ||||||
|  |                     logger.info(f"找到已存在的PDF报告: {filepath}") | ||||||
|  |                     return filepath | ||||||
|  |              | ||||||
|             # 维度名称映射 |             # 维度名称映射 | ||||||
|             dimension_names = { |             dimension_names = { | ||||||
|                 "company_profile": "公司简介", |                 "company_profile": "公司简介", | ||||||
|  | @ -1669,27 +1635,156 @@ class FundamentalAnalyzer: | ||||||
|                 logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果") |                 logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果") | ||||||
|                 return None |                 return None | ||||||
|              |              | ||||||
|  |             # 确保字体目录存在 | ||||||
|  |             fonts_dir = os.path.join(os.path.dirname(__file__), "fonts") | ||||||
|  |             os.makedirs(fonts_dir, exist_ok=True) | ||||||
|  |              | ||||||
|  |             # 检查是否存在字体文件,如果不存在则创建简单的默认字体标记文件 | ||||||
|  |             font_path = os.path.join(fonts_dir, "simhei.ttf") | ||||||
|  |             if not os.path.exists(font_path): | ||||||
|  |                 # 尝试从系统字体目录复制 | ||||||
|  |                 try: | ||||||
|  |                     import shutil | ||||||
|  |                     # 尝试常见的系统字体位置 | ||||||
|  |                     system_fonts = [ | ||||||
|  |                         "C:/Windows/Fonts/simhei.ttf",  # Windows | ||||||
|  |                         "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # Linux | ||||||
|  |                         "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",  # 其他Linux | ||||||
|  |                         "/System/Library/Fonts/PingFang.ttc"  # macOS | ||||||
|  |                     ] | ||||||
|  |                      | ||||||
|  |                     for system_font in system_fonts: | ||||||
|  |                         if os.path.exists(system_font): | ||||||
|  |                             shutil.copy2(system_font, font_path) | ||||||
|  |                             logger.info(f"已复制字体文件: {system_font} -> {font_path}") | ||||||
|  |                             break | ||||||
|  |                 except Exception as font_error: | ||||||
|  |                     logger.warning(f"复制字体文件失败: {str(font_error)}") | ||||||
|  |              | ||||||
|             # 创建PDF生成器实例 |             # 创建PDF生成器实例 | ||||||
|             generator = PDFGenerator() |             generator = PDFGenerator() | ||||||
|              |              | ||||||
|             # 生成PDF报告 |             # 生成PDF报告 | ||||||
|  |             try: | ||||||
|  |                 # 第一次尝试生成 | ||||||
|                 filepath = generator.generate_pdf( |                 filepath = generator.generate_pdf( | ||||||
|                     title=f"{stock_name}({stock_code}) 基本面分析报告", |                     title=f"{stock_name}({stock_code}) 基本面分析报告", | ||||||
|                     content_dict=content_dict, |                     content_dict=content_dict, | ||||||
|  |                     output_dir=reports_dir, | ||||||
|                     filename=f"{stock_name}_{stock_code}_analysis.pdf" |                     filename=f"{stock_name}_{stock_code}_analysis.pdf" | ||||||
|                 ) |                 ) | ||||||
|                  |                  | ||||||
|                 if filepath: |                 if filepath: | ||||||
|                     logger.info(f"PDF报告已生成: {filepath}") |                     logger.info(f"PDF报告已生成: {filepath}") | ||||||
|             else: |  | ||||||
|                 logger.error("PDF报告生成失败") |  | ||||||
|              |  | ||||||
|                     return filepath |                     return filepath | ||||||
|  |             except Exception as pdf_error: | ||||||
|  |                 logger.error(f"生成PDF报告第一次尝试失败: {str(pdf_error)}") | ||||||
|  |                  | ||||||
|  |                 # 如果是字体问题,可能需要使用备选方案 | ||||||
|  |                 if "找不到中文字体文件" in str(pdf_error): | ||||||
|  |                     # 导出为文本文件作为备选 | ||||||
|  |                     try: | ||||||
|  |                         txt_filename = f"{stock_name}_{stock_code}_analysis.txt" | ||||||
|  |                         txt_filepath = os.path.join(reports_dir, txt_filename) | ||||||
|  |                          | ||||||
|  |                         with open(txt_filepath, 'w', encoding='utf-8') as f: | ||||||
|  |                             f.write(f"{stock_name}({stock_code}) 基本面分析报告\n") | ||||||
|  |                             f.write(f"生成时间:{datetime.now().strftime('%Y年%m月%d日 %H:%M:%S')}\n\n") | ||||||
|  |                              | ||||||
|  |                             for section_title, content in content_dict.items(): | ||||||
|  |                                 if content: | ||||||
|  |                                     f.write(f"## {section_title}\n\n") | ||||||
|  |                                     f.write(f"{content}\n\n") | ||||||
|  |                          | ||||||
|  |                         logger.info(f"由于PDF生成失败,已生成文本报告: {txt_filepath}") | ||||||
|  |                         return txt_filepath | ||||||
|  |                     except Exception as txt_error: | ||||||
|  |                         logger.error(f"生成文本报告失败: {str(txt_error)}") | ||||||
|  |              | ||||||
|  |             logger.error("PDF报告生成失败") | ||||||
|  |             return None | ||||||
|              |              | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.error(f"生成PDF报告失败: {str(e)}") |             logger.error(f"生成PDF报告失败: {str(e)}") | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|  |     def is_stock_locked(self, stock_code: str, dimension: str) -> bool: | ||||||
|  |         """检查股票是否已被锁定 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             stock_code: 股票代码 | ||||||
|  |             dimension: 分析维度 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             bool: 是否被锁定 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" | ||||||
|  |              | ||||||
|  |             # 检查是否存在锁 | ||||||
|  |             existing_lock = redis_client.get(lock_key) | ||||||
|  |             if existing_lock: | ||||||
|  |                 lock_time = int(existing_lock) | ||||||
|  |                 current_time = int(time.time()) | ||||||
|  |                  | ||||||
|  |                 # 锁超过30分钟(1800秒)视为过期 | ||||||
|  |                 if current_time - lock_time > 1800: | ||||||
|  |                     # 锁已过期,可以释放 | ||||||
|  |                     redis_client.delete(lock_key) | ||||||
|  |                     logger.info(f"股票 {stock_code} 维度 {dimension} 的过期锁已释放") | ||||||
|  |                     return False | ||||||
|  |                 else: | ||||||
|  |                     # 锁未过期,被锁定 | ||||||
|  |                     return True | ||||||
|  |              | ||||||
|  |             # 不存在锁 | ||||||
|  |             return False | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"检查股票 {stock_code} 锁状态时出错: {str(e)}") | ||||||
|  |             # 出错时保守处理,返回未锁定 | ||||||
|  |             return False | ||||||
|  |      | ||||||
|  |     def lock_stock(self, stock_code: str, dimension: str) -> bool: | ||||||
|  |         """锁定股票 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             stock_code: 股票代码 | ||||||
|  |             dimension: 分析维度 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             bool: 是否成功锁定 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" | ||||||
|  |             current_time = int(time.time()) | ||||||
|  |              | ||||||
|  |             # 设置锁,过期时间1小时 | ||||||
|  |             redis_client.set(lock_key, current_time, ex=3600) | ||||||
|  |             logger.info(f"股票 {stock_code} 维度 {dimension} 已锁定") | ||||||
|  |             return True | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"锁定股票 {stock_code} 时出错: {str(e)}") | ||||||
|  |             return False | ||||||
|  |      | ||||||
|  |     def unlock_stock(self, stock_code: str, dimension: str) -> bool: | ||||||
|  |         """解锁股票 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             stock_code: 股票代码 | ||||||
|  |             dimension: 分析维度 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             bool: 是否成功解锁 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             lock_key = f"stock_analysis_lock:{stock_code}:{dimension}" | ||||||
|  |             redis_client.delete(lock_key) | ||||||
|  |             logger.info(f"股票 {stock_code} 维度 {dimension} 已解锁") | ||||||
|  |             return True | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"解锁股票 {stock_code} 时出错: {str(e)}") | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
| def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool: | def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool: | ||||||
|     """测试单个分析方法""" |     """测试单个分析方法""" | ||||||
|     try: |     try: | ||||||
|  | @ -1713,11 +1808,6 @@ def test_single_stock(analyzer: FundamentalAnalyzer, stock_code: str, stock_name | ||||||
| 
 | 
 | ||||||
| def main(): | def main(): | ||||||
|     """主函数""" |     """主函数""" | ||||||
|     # 设置日志级别 |  | ||||||
|     logging.basicConfig( |  | ||||||
|         level=logging.INFO, |  | ||||||
|         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|     # 测试股票列表 |     # 测试股票列表 | ||||||
|     test_stocks = [ |     test_stocks = [ | ||||||
|  |  | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -11,25 +11,25 @@ import markdown2 | ||||||
| from bs4 import BeautifulSoup | from bs4 import BeautifulSoup | ||||||
| import os | import os | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from fpdf import FPDF | import shutil | ||||||
| import matplotlib.pyplot as plt | 
 | ||||||
| import matplotlib | import matplotlib | ||||||
| matplotlib.use('Agg') | matplotlib.use('Agg') | ||||||
| 
 | 
 | ||||||
| # 修改导入路径,使用相对导入 | # 修改导入路径,使用相对导入 | ||||||
| try: | try: | ||||||
|     # 尝试相对导入 |     # 尝试相对导入 | ||||||
|     from .chat_bot_with_offline import ChatBot |     from .chat_bot import ChatBot | ||||||
|     from .fundamental_analysis_database import get_db, get_analysis_result |     from .chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
| except ImportError: | except ImportError: | ||||||
|     # 如果相对导入失败,尝试绝对导入 |     # 如果相对导入失败,尝试绝对导入 | ||||||
|     try: |     try: | ||||||
|         from src.fundamentals_llm.chat_bot_with_offline import ChatBot |         from src.fundamentals_llm.chat_bot import ChatBot | ||||||
|         from src.fundamentals_llm.fundamental_analysis_database import get_db, get_analysis_result |         from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
|     except ImportError: |     except ImportError: | ||||||
|         # 最后尝试直接导入 |         # 最后尝试直接导入 | ||||||
|         from chat_bot_with_offline import ChatBot |         from chat_bot import ChatBot | ||||||
|         from fundamental_analysis_database import get_db, get_analysis_result |         from chat_bot_with_offline import ChatBot as OfflineChatBot | ||||||
| 
 | 
 | ||||||
| # 设置日志记录 | # 设置日志记录 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  | @ -39,30 +39,9 @@ class PDFGenerator: | ||||||
|      |      | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         """初始化PDF生成器""" |         """初始化PDF生成器""" | ||||||
|         # 注册中文字体 |         # 尝试注册中文字体 | ||||||
|         try: |         self.font_name = self._register_chinese_font() | ||||||
|             # 尝试使用系统自带的中文字体 |         self.chat_bot = OfflineChatBot(platform="volc", model_type="offline_model")  # 使用离线模式 | ||||||
|             if os.name == 'nt':  # Windows |  | ||||||
|                 font_path = "C:/Windows/Fonts/simhei.ttf"  # 黑体 |  | ||||||
|             else:  # Linux/Mac |  | ||||||
|                 font_path = "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf" |  | ||||||
|              |  | ||||||
|             if os.path.exists(font_path): |  | ||||||
|                 pdfmetrics.registerFont(TTFont('SimHei', font_path)) |  | ||||||
|                 self.font_name = 'SimHei' |  | ||||||
|             else: |  | ||||||
|                 # 如果找不到系统字体,尝试使用当前目录下的字体 |  | ||||||
|                 font_path = os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf") |  | ||||||
|                 if os.path.exists(font_path): |  | ||||||
|                     pdfmetrics.registerFont(TTFont('SimHei', font_path)) |  | ||||||
|                     self.font_name = 'SimHei' |  | ||||||
|                 else: |  | ||||||
|                     raise FileNotFoundError("找不到中文字体文件") |  | ||||||
|         except Exception as e: |  | ||||||
|             logger.error(f"注册中文字体失败: {str(e)}") |  | ||||||
|             raise |  | ||||||
|          |  | ||||||
|         self.chat_bot = ChatBot() |  | ||||||
|         self.styles = getSampleStyleSheet() |         self.styles = getSampleStyleSheet() | ||||||
|          |          | ||||||
|         # 创建自定义样式 |         # 创建自定义样式 | ||||||
|  | @ -139,6 +118,64 @@ class PDFGenerator: | ||||||
|             textColor=colors.HexColor('#333333') |             textColor=colors.HexColor('#333333') | ||||||
|         )) |         )) | ||||||
|      |      | ||||||
|  |     def _register_chinese_font(self): | ||||||
|  |         """查找并注册中文字体 | ||||||
|  |          | ||||||
|  |         Returns: | ||||||
|  |             str: 注册的字体名称,如果失败则使用默认字体 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 可能的字体文件位置列表 | ||||||
|  |             font_locations = [ | ||||||
|  |                 # 当前目录/子目录字体位置 | ||||||
|  |                 os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf"), | ||||||
|  |                 os.path.join(os.path.dirname(__file__), "fonts", "wqy-microhei.ttc"), | ||||||
|  |                 os.path.join(os.path.dirname(__file__), "fonts", "wqy-zenhei.ttc"), | ||||||
|  |                 # Docker中可能的位置 | ||||||
|  |                 "/app/src/fundamentals_llm/fonts/simhei.ttf", | ||||||
|  |                 # Windows系统字体位置 | ||||||
|  |                 "C:/Windows/Fonts/simhei.ttf", | ||||||
|  |                 "C:/Windows/Fonts/simfang.ttf", | ||||||
|  |                 "C:/Windows/Fonts/simsun.ttc", | ||||||
|  |                 # Linux/Mac系统字体位置 | ||||||
|  |                 "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", | ||||||
|  |                 "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", | ||||||
|  |                 "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", | ||||||
|  |                 "/usr/share/fonts/wqy-zenhei/wqy-zenhei.ttc", | ||||||
|  |                 "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf", | ||||||
|  |                 "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", | ||||||
|  |                 "/System/Library/Fonts/PingFang.ttc"  # macOS | ||||||
|  |             ] | ||||||
|  |              | ||||||
|  |             # 尝试各个位置 | ||||||
|  |             for font_path in font_locations: | ||||||
|  |                 if os.path.exists(font_path): | ||||||
|  |                     logger.info(f"使用字体文件: {font_path}") | ||||||
|  |                     # 为防止字体文件名称不同,统一拷贝到字体目录并重命名 | ||||||
|  |                     font_dir = os.path.join(os.path.dirname(__file__), "fonts") | ||||||
|  |                     os.makedirs(font_dir, exist_ok=True) | ||||||
|  |                     target_path = os.path.join(font_dir, "simhei.ttf") | ||||||
|  |                      | ||||||
|  |                     # 如果字体文件不在目标位置,拷贝过去 | ||||||
|  |                     if font_path != target_path and not os.path.exists(target_path): | ||||||
|  |                         try: | ||||||
|  |                             shutil.copy2(font_path, target_path) | ||||||
|  |                             logger.info(f"已将字体文件复制到: {target_path}") | ||||||
|  |                         except Exception as e: | ||||||
|  |                             logger.warning(f"复制字体文件失败: {str(e)}") | ||||||
|  |                      | ||||||
|  |                     # 注册字体 | ||||||
|  |                     pdfmetrics.registerFont(TTFont('SimHei', target_path)) | ||||||
|  |                     return 'SimHei' | ||||||
|  |              | ||||||
|  |             # 如果所有位置都找不到,使用默认字体 | ||||||
|  |             logger.warning("找不到中文字体文件,将使用默认字体") | ||||||
|  |             return 'Helvetica' | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"注册中文字体失败: {str(e)}") | ||||||
|  |             return 'Helvetica'  # 使用默认字体 | ||||||
|  |      | ||||||
|     def _convert_markdown_to_flowables(self, markdown_text: str) -> List: |     def _convert_markdown_to_flowables(self, markdown_text: str) -> List: | ||||||
|         """将Markdown文本转换为PDF流对象 |         """将Markdown文本转换为PDF流对象 | ||||||
|          |          | ||||||
|  |  | ||||||
|  | @ -0,0 +1,190 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | """ | ||||||
|  | 中文字体安装工具 | ||||||
|  | 
 | ||||||
|  | 此脚本用于下载和安装中文字体,以支持PDF生成功能。 | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | import logging | ||||||
|  | import shutil | ||||||
|  | import tempfile | ||||||
|  | import platform | ||||||
|  | 
 | ||||||
|  | # 配置日志 | ||||||
|  | logging.basicConfig( | ||||||
|  |     level=logging.INFO, | ||||||
|  |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||||||
|  | ) | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | def download_wqy_font(): | ||||||
|  |     """ | ||||||
|  |     下载文泉驿微米黑字体 | ||||||
|  |      | ||||||
|  |     Returns: | ||||||
|  |         str: 下载的字体文件路径,如果失败则返回None | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         import requests | ||||||
|  |          | ||||||
|  |         # 创建临时目录 | ||||||
|  |         temp_dir = tempfile.mkdtemp() | ||||||
|  |         font_file = os.path.join(temp_dir, "wqy-microhei.ttc") | ||||||
|  |          | ||||||
|  |         # 文泉驿微米黑字体的下载链接 | ||||||
|  |         url = "https://mirrors.tuna.tsinghua.edu.cn/osdn/wqy/wqy-microhei-0.2.0-beta.tar.gz" | ||||||
|  |          | ||||||
|  |         logger.info(f"开始下载文泉驿微米黑字体: {url}") | ||||||
|  |         response = requests.get(url, stream=True) | ||||||
|  |         response.raise_for_status() | ||||||
|  |          | ||||||
|  |         # 下载压缩包 | ||||||
|  |         tar_file = os.path.join(temp_dir, "wqy-microhei.tar.gz") | ||||||
|  |         with open(tar_file, 'wb') as f: | ||||||
|  |             for chunk in response.iter_content(chunk_size=8192): | ||||||
|  |                 f.write(chunk) | ||||||
|  |          | ||||||
|  |         # 解压字体文件 | ||||||
|  |         import tarfile | ||||||
|  |         with tarfile.open(tar_file) as tar: | ||||||
|  |             # 提取字体文件 | ||||||
|  |             for member in tar.getmembers(): | ||||||
|  |                 if member.name.endswith(('.ttc', '.ttf')): | ||||||
|  |                     member.name = os.path.basename(member.name) | ||||||
|  |                     tar.extract(member, temp_dir) | ||||||
|  |                     extracted_file = os.path.join(temp_dir, member.name) | ||||||
|  |                     if os.path.exists(extracted_file): | ||||||
|  |                         return extracted_file | ||||||
|  |          | ||||||
|  |         logger.error("在压缩包中未找到字体文件") | ||||||
|  |         return None | ||||||
|  |      | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"下载字体失败: {str(e)}") | ||||||
|  |         return None | ||||||
|  |      | ||||||
|  | def find_system_fonts(): | ||||||
|  |     """ | ||||||
|  |     在系统中查找可用的中文字体 | ||||||
|  |      | ||||||
|  |     Returns: | ||||||
|  |         str: 找到的字体文件路径,如果找不到则返回None | ||||||
|  |     """ | ||||||
|  |     # 常见的中文字体位置 | ||||||
|  |     font_locations = [] | ||||||
|  |      | ||||||
|  |     system = platform.system() | ||||||
|  |     if system == "Windows": | ||||||
|  |         font_locations = [ | ||||||
|  |             "C:/Windows/Fonts/simhei.ttf", | ||||||
|  |             "C:/Windows/Fonts/simsun.ttc", | ||||||
|  |             "C:/Windows/Fonts/msyh.ttf", | ||||||
|  |             "C:/Windows/Fonts/simfang.ttf", | ||||||
|  |         ] | ||||||
|  |     elif system == "Darwin":  # macOS | ||||||
|  |         font_locations = [ | ||||||
|  |             "/System/Library/Fonts/PingFang.ttc", | ||||||
|  |             "/Library/Fonts/Microsoft/SimHei.ttf", | ||||||
|  |             "/Library/Fonts/Microsoft/SimSun.ttf", | ||||||
|  |         ] | ||||||
|  |     else:  # Linux | ||||||
|  |         font_locations = [ | ||||||
|  |             "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", | ||||||
|  |             "/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", | ||||||
|  |             "/usr/share/fonts/wenquanyi/wqy-microhei/wqy-microhei.ttc", | ||||||
|  |             "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", | ||||||
|  |             "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf", | ||||||
|  |         ] | ||||||
|  |      | ||||||
|  |     # 查找第一个存在的字体文件 | ||||||
|  |     for font_path in font_locations: | ||||||
|  |         if os.path.exists(font_path): | ||||||
|  |             logger.info(f"在系统中找到中文字体: {font_path}") | ||||||
|  |             return font_path | ||||||
|  |      | ||||||
|  |     logger.warning("在系统中未找到任何中文字体") | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  | def install_font(target_dir=None): | ||||||
|  |     """ | ||||||
|  |     安装中文字体 | ||||||
|  |      | ||||||
|  |     Args: | ||||||
|  |         target_dir: 目标目录,如果不提供则使用默认目录 | ||||||
|  |          | ||||||
|  |     Returns: | ||||||
|  |         str: 安装后的字体文件路径,如果失败则返回None | ||||||
|  |     """ | ||||||
|  |     # 如果没有指定目标目录,使用默认位置 | ||||||
|  |     if target_dir is None: | ||||||
|  |         # 获取当前脚本所在目录 | ||||||
|  |         current_dir = os.path.dirname(os.path.abspath(__file__)) | ||||||
|  |         target_dir = os.path.join(current_dir, "fonts") | ||||||
|  |      | ||||||
|  |     # 确保目标目录存在 | ||||||
|  |     os.makedirs(target_dir, exist_ok=True) | ||||||
|  |      | ||||||
|  |     # 目标字体文件路径 | ||||||
|  |     target_font = os.path.join(target_dir, "simhei.ttf") | ||||||
|  |      | ||||||
|  |     # 如果目标字体已存在,直接返回 | ||||||
|  |     if os.path.exists(target_font): | ||||||
|  |         logger.info(f"中文字体已存在: {target_font}") | ||||||
|  |         return target_font | ||||||
|  |      | ||||||
|  |     # 尝试从系统中复制字体 | ||||||
|  |     system_font = find_system_fonts() | ||||||
|  |     if system_font: | ||||||
|  |         try: | ||||||
|  |             shutil.copy2(system_font, target_font) | ||||||
|  |             logger.info(f"已从系统复制字体: {system_font} -> {target_font}") | ||||||
|  |             return target_font | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"复制系统字体失败: {str(e)}") | ||||||
|  |      | ||||||
|  |     # 如果系统中找不到字体,尝试下载 | ||||||
|  |     logger.info("尝试下载中文字体...") | ||||||
|  |     downloaded_font = download_wqy_font() | ||||||
|  |     if downloaded_font: | ||||||
|  |         try: | ||||||
|  |             shutil.copy2(downloaded_font, target_font) | ||||||
|  |             logger.info(f"已安装下载的字体: {downloaded_font} -> {target_font}") | ||||||
|  |             return target_font | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"复制下载的字体失败: {str(e)}") | ||||||
|  |      | ||||||
|  |     logger.error("安装中文字体失败") | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     """ | ||||||
|  |     主函数 | ||||||
|  |     """ | ||||||
|  |     print("中文字体安装工具") | ||||||
|  |     print("----------------") | ||||||
|  |      | ||||||
|  |     target_dir = None | ||||||
|  |     if len(sys.argv) > 1: | ||||||
|  |         target_dir = sys.argv[1] | ||||||
|  |         print(f"使用指定的目标目录: {target_dir}") | ||||||
|  |      | ||||||
|  |     result = install_font(target_dir) | ||||||
|  |      | ||||||
|  |     if result: | ||||||
|  |         print(f"\n✓ 成功安装中文字体: {result}") | ||||||
|  |         print("\n现在可以正常生成PDF报告了。") | ||||||
|  |     else: | ||||||
|  |         print("\n✗ 安装中文字体失败") | ||||||
|  |         print("\n请手动将中文字体(.ttf或.ttc文件)复制到以下目录:") | ||||||
|  |         if target_dir: | ||||||
|  |             print(f"  {target_dir}") | ||||||
|  |         else: | ||||||
|  |             print(f"  {os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fonts')}") | ||||||
|  |         print("\n并将其重命名为'simhei.ttf'") | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main()  | ||||||
|  | @ -0,0 +1,172 @@ | ||||||
|  | import re | ||||||
|  | import logging | ||||||
|  | from typing import Optional, List, Dict, Any | ||||||
|  | 
 | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | class TextProcessor: | ||||||
|  |     """大模型文本处理工具类""" | ||||||
|  |      | ||||||
|  |     @staticmethod | ||||||
|  |     def clean_model_output(output: str) -> str: | ||||||
|  |         """清理模型输出文本 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             output: 模型输出的原始文本 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             str: 清理后的文本 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 移除引用标记 | ||||||
|  |             output = re.sub(r'\[[^\]]*\]', '', output) | ||||||
|  |              | ||||||
|  |             # 移除多余的空白字符 | ||||||
|  |             output = re.sub(r'\s+', ' ', output) | ||||||
|  |              | ||||||
|  |             # 移除首尾的空白字符 | ||||||
|  |             output = output.strip() | ||||||
|  |              | ||||||
|  |             # 移除可能存在的HTML标签 | ||||||
|  |             output = re.sub(r'<[^>]+>', '', output) | ||||||
|  |              | ||||||
|  |             # 移除可能存在的Markdown格式 | ||||||
|  |             output = re.sub(r'[*_`~]', '', output) | ||||||
|  |              | ||||||
|  |             return output | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"清理模型输出失败: {str(e)}") | ||||||
|  |             return output | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def clean_thought_process(output: str) -> str: | ||||||
|  |         """清理模型输出,移除推理过程,只保留最终结果 | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             output: 模型原始输出文本 | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             str: 清理后的输出文本 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 找到</think>标签的位置 | ||||||
|  |             think_end = output.find('</think>') | ||||||
|  |             if think_end != -1: | ||||||
|  |                 # 移除</think>标签及其之前的所有内容 | ||||||
|  |                 output = output[think_end + len('</think>'):] | ||||||
|  | 
 | ||||||
|  |             # 处理可能存在的空行 | ||||||
|  |             lines = output.split('\n') | ||||||
|  |             cleaned_lines = [] | ||||||
|  |             for line in lines: | ||||||
|  |                 line = line.strip() | ||||||
|  |                 if line:  # 只保留非空行 | ||||||
|  |                     cleaned_lines.append(line) | ||||||
|  | 
 | ||||||
|  |             # 重新组合文本 | ||||||
|  |             output = '\n'.join(cleaned_lines) | ||||||
|  | 
 | ||||||
|  |             return output.strip() | ||||||
|  | 
 | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"清理模型输出失败: {str(e)}") | ||||||
|  |             return output.strip() | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def extract_numeric_value_from_response(response: str) -> str: | ||||||
|  |         """从模型响应中提取数值 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             response: 模型响应的文本 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             str: 提取出的数值 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 清理响应文本 | ||||||
|  |             cleaned_response = TextProcessor.clean_model_output(response) | ||||||
|  |              | ||||||
|  |             # 尝试提取数值 | ||||||
|  |             numeric_patterns = [ | ||||||
|  |                 r'[-+]?\d*\.\d+',  # 浮点数 | ||||||
|  |                 r'[-+]?\d+',       # 整数 | ||||||
|  |                 r'[-+]?\d+\.\d*'   # 可能带小数点的数 | ||||||
|  |             ] | ||||||
|  |              | ||||||
|  |             for pattern in numeric_patterns: | ||||||
|  |                 matches = re.findall(pattern, cleaned_response) | ||||||
|  |                 if matches: | ||||||
|  |                     # 返回第一个匹配的数值 | ||||||
|  |                     return matches[0] | ||||||
|  |              | ||||||
|  |             # 如果没有找到数值,返回原始响应 | ||||||
|  |             return cleaned_response | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"提取数值失败: {str(e)}") | ||||||
|  |             return response | ||||||
|  |      | ||||||
|  |     @staticmethod | ||||||
|  |     def extract_list_from_response(response: str) -> List[str]: | ||||||
|  |         """从模型响应中提取列表项 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             response: 模型响应的文本 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             List[str]: 提取出的列表项 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 清理响应文本 | ||||||
|  |             cleaned_response = TextProcessor.clean_model_output(response) | ||||||
|  |              | ||||||
|  |             # 尝试匹配列表项 | ||||||
|  |             # 匹配数字编号的列表项 | ||||||
|  |             numbered_items = re.findall(r'\d+\.\s*([^\n]+)', cleaned_response) | ||||||
|  |             if numbered_items: | ||||||
|  |                 return [item.strip() for item in numbered_items] | ||||||
|  |              | ||||||
|  |             # 匹配带符号的列表项 | ||||||
|  |             bullet_items = re.findall(r'[-*•]\s*([^\n]+)', cleaned_response) | ||||||
|  |             if bullet_items: | ||||||
|  |                 return [item.strip() for item in bullet_items] | ||||||
|  |              | ||||||
|  |             # 如果没有找到列表项,返回空列表 | ||||||
|  |             return [] | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"提取列表项失败: {str(e)}") | ||||||
|  |             return [] | ||||||
|  |      | ||||||
|  |     @staticmethod | ||||||
|  |     def extract_key_value_pairs(response: str) -> Dict[str, str]: | ||||||
|  |         """从模型响应中提取键值对 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             response: 模型响应的文本 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             Dict[str, str]: 提取出的键值对 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 清理响应文本 | ||||||
|  |             cleaned_response = TextProcessor.clean_model_output(response) | ||||||
|  |              | ||||||
|  |             # 尝试匹配键值对 | ||||||
|  |             # 匹配冒号分隔的键值对 | ||||||
|  |             pairs = re.findall(r'([^:]+):\s*([^\n]+)', cleaned_response) | ||||||
|  |             if pairs: | ||||||
|  |                 return {key.strip(): value.strip() for key, value in pairs} | ||||||
|  |              | ||||||
|  |             # 匹配等号分隔的键值对 | ||||||
|  |             pairs = re.findall(r'([^=]+)=\s*([^\n]+)', cleaned_response) | ||||||
|  |             if pairs: | ||||||
|  |                 return {key.strip(): value.strip() for key, value in pairs} | ||||||
|  |              | ||||||
|  |             # 如果没有找到键值对,返回空字典 | ||||||
|  |             return {} | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"提取键值对失败: {str(e)}") | ||||||
|  |             return {}  | ||||||
|  | @ -66,12 +66,22 @@ MODEL_CONFIGS = { | ||||||
|             "doubao": "doubao-1-5-pro-32k-250115" |             "doubao": "doubao-1-5-pro-32k-250115" | ||||||
|         } |         } | ||||||
|     }, |     }, | ||||||
|  |     # 谷歌Gemini | ||||||
|  |     "Gemini": { | ||||||
|  |         "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", | ||||||
|  |         "api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU", | ||||||
|  |         "models": { | ||||||
|  |             "offline_model": "gemini-2.0-flash" | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|     # 天链苹果 |     # 天链苹果 | ||||||
|     "tl_private": { |     "tl_private": { | ||||||
|         "base_url": "http://192.168.32.118:1234/v1/", |         "base_url": "http://192.168.16.174:1234/v1/", | ||||||
|         "api_key": "none", |         "api_key": "none", | ||||||
|         "models": { |         "models": { | ||||||
|             "ds-v1": "mlx-community/DeepSeek-R1-4bit" |             "glm-z1": "glm-z1-rumination-32b-0414", | ||||||
|  |             "glm-4": "glm-4-32b-0414-abliterated", | ||||||
|  |             "ds_v1": "mlx-community/DeepSeek-R1-4bit", | ||||||
|         } |         } | ||||||
|     }, |     }, | ||||||
|     # 天链-千问 |     # 天链-千问 | ||||||
|  | @ -79,7 +89,8 @@ MODEL_CONFIGS = { | ||||||
|         "base_url": "http://192.168.16.178:11434/v1", |         "base_url": "http://192.168.16.178:11434/v1", | ||||||
|         "api_key": "sk-WaVRJKkyhrFlH4ZV35B9Aa61759b400c9cA002D00f3f1019", |         "api_key": "sk-WaVRJKkyhrFlH4ZV35B9Aa61759b400c9cA002D00f3f1019", | ||||||
|         "models": { |         "models": { | ||||||
|             "qwq": "qwq:32b" |             "qwq": "qwq:32b", | ||||||
|  |             "GLM": "hf-mirror.com/Cobra4687/GLM-4-32B-0414-abliterated-Q4_K_M-GGUF:Q4_K_M" | ||||||
|         } |         } | ||||||
|     }, |     }, | ||||||
|     # Deepseek配置 |     # Deepseek配置 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,986 @@ | ||||||
|  | import pandas as pd | ||||||
|  | import numpy as np | ||||||
|  | from sqlalchemy import create_engine | ||||||
|  | from datetime import datetime, timedelta | ||||||
|  | from tqdm import tqdm | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | # v2版本只做表格 | ||||||
|  | 
 | ||||||
|  | # 添加中文字体支持 | ||||||
|  | plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签 | ||||||
|  | plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号 | ||||||
|  | 
 | ||||||
|  | class StockAnalyzer: | ||||||
|  |     def __init__(self, db_connection_string): | ||||||
|  |         """ | ||||||
|  |         Initialize the stock analyzer | ||||||
|  |         :param db_connection_string: Database connection string | ||||||
|  |         """ | ||||||
|  |         self.engine = create_engine(db_connection_string) | ||||||
|  |         self.initial_capital = 300000  # 30万本金 | ||||||
|  |         self.total_position = 1000000  # 100万建仓规模 | ||||||
|  |         self.borrowed_capital = 700000  # 70万借入资金 | ||||||
|  |         self.annual_interest_rate = 0.05  # 年化5%利息 | ||||||
|  |         self.daily_interest_rate = self.annual_interest_rate / 365 | ||||||
|  |         self.commission_rate = 0.05  # 5%手续费 | ||||||
|  |         self.min_holding_days = 7  # 最少持仓天数 | ||||||
|  | 
 | ||||||
|  |     def get_stock_list(self): | ||||||
|  |         """获取所有股票列表""" | ||||||
|  |         # query = "SELECT DISTINCT gp_code as symbol FROM gp_code_all where mark1 = '1'" | ||||||
|  |         # return pd.read_sql(query, self.engine)['symbol'].tolist() | ||||||
|  |         return ['SH600522', | ||||||
|  |                 'SZ002340', | ||||||
|  |                 'SZ000733', | ||||||
|  |                 'SH601615', | ||||||
|  |                 'SH600157', | ||||||
|  |                 'SH688005', | ||||||
|  |                 'SH600903', | ||||||
|  |                 'SH600956', | ||||||
|  |                 'SH601187', | ||||||
|  |                 'SH603983']  # 你可以在这里添加更多股票代码 | ||||||
|  | 
 | ||||||
|  |     def get_stock_data(self, symbol, start_date, end_date, include_history=False): | ||||||
|  |         """获取指定股票在时间范围内的数据""" | ||||||
|  |         try: | ||||||
|  |             if include_history: | ||||||
|  |                 # 如果需要历史数据,获取start_date之前240天的数据 | ||||||
|  |                 query = f""" | ||||||
|  |                 WITH history_data AS ( | ||||||
|  |                     SELECT  | ||||||
|  |                         symbol, | ||||||
|  |                         timestamp, | ||||||
|  |                         CAST(close as DECIMAL(10,2)) as close | ||||||
|  |                     FROM gp_day_data  | ||||||
|  |                     WHERE symbol = '{symbol}'  | ||||||
|  |                     AND timestamp < '{start_date}' | ||||||
|  |                     ORDER BY timestamp DESC | ||||||
|  |                     LIMIT 240 | ||||||
|  |                 ) | ||||||
|  |                 SELECT  | ||||||
|  |                     symbol, | ||||||
|  |                     timestamp, | ||||||
|  |                     CAST(open as DECIMAL(10,2)) as open, | ||||||
|  |                     CAST(high as DECIMAL(10,2)) as high, | ||||||
|  |                     CAST(low as DECIMAL(10,2)) as low, | ||||||
|  |                     CAST(close as DECIMAL(10,2)) as close, | ||||||
|  |                     CAST(volume as DECIMAL(20,0)) as volume | ||||||
|  |                 FROM gp_day_data  | ||||||
|  |                 WHERE symbol = '{symbol}'  | ||||||
|  |                 AND timestamp BETWEEN '{start_date}' AND '{end_date}' | ||||||
|  |                 UNION ALL | ||||||
|  |                 SELECT  | ||||||
|  |                     symbol, | ||||||
|  |                     timestamp, | ||||||
|  |                     NULL as open, | ||||||
|  |                     NULL as high, | ||||||
|  |                     NULL as low, | ||||||
|  |                     close, | ||||||
|  |                     NULL as volume | ||||||
|  |                 FROM history_data | ||||||
|  |                 ORDER BY timestamp | ||||||
|  |                 """ | ||||||
|  |             else: | ||||||
|  |                 query = f""" | ||||||
|  |                 SELECT  | ||||||
|  |                     symbol, | ||||||
|  |                     timestamp, | ||||||
|  |                     CAST(open as DECIMAL(10,2)) as open, | ||||||
|  |                     CAST(high as DECIMAL(10,2)) as high, | ||||||
|  |                     CAST(low as DECIMAL(10,2)) as low, | ||||||
|  |                     CAST(close as DECIMAL(10,2)) as close, | ||||||
|  |                     CAST(volume as DECIMAL(20,0)) as volume | ||||||
|  |                 FROM gp_day_data  | ||||||
|  |                 WHERE symbol = '{symbol}'  | ||||||
|  |                 AND timestamp BETWEEN '{start_date}' AND '{end_date}' | ||||||
|  |                 ORDER BY timestamp | ||||||
|  |                 """ | ||||||
|  |              | ||||||
|  |             # 使用chunksize分批读取数据 | ||||||
|  |             chunks = [] | ||||||
|  |             for chunk in pd.read_sql(query, self.engine, chunksize=1000): | ||||||
|  |                 chunks.append(chunk) | ||||||
|  |             df = pd.concat(chunks, ignore_index=True) | ||||||
|  |              | ||||||
|  |             # 确保数值列的类型正确 | ||||||
|  |             numeric_columns = ['open', 'high', 'low', 'close', 'volume'] | ||||||
|  |             for col in numeric_columns: | ||||||
|  |                 if col in df.columns: | ||||||
|  |                     df[col] = pd.to_numeric(df[col], errors='coerce') | ||||||
|  |              | ||||||
|  |             # 删除任何包含 NaN 的行,但只考虑分析期间需要的列 | ||||||
|  |             if include_history: | ||||||
|  |                 historical_mask = (df['timestamp'] < start_date) & df['close'].notna() | ||||||
|  |                 analysis_mask = (df['timestamp'] >= start_date) & df[['open', 'high', 'low', 'close', 'volume']].notna().all(axis=1) | ||||||
|  |                 df = df[historical_mask | analysis_mask] | ||||||
|  |             else: | ||||||
|  |                 df = df.dropna() | ||||||
|  |              | ||||||
|  |             return df | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"获取股票 {symbol} 数据时出错: {str(e)}") | ||||||
|  |             return pd.DataFrame() | ||||||
|  | 
 | ||||||
|  |     def calculate_trade_signals(self, df, take_profit_pct, stop_loss_pct, pullback_entry_pct, start_date): | ||||||
|  |         """计算交易信号和收益""" | ||||||
|  |         """计算交易信号和收益""" | ||||||
|  |         # 检查历史数据是否足够 | ||||||
|  |         historical_data = df[df['timestamp'] < start_date] | ||||||
|  |         if len(historical_data) < 240: | ||||||
|  |             # print(f"历史数据不足240天,跳过分析") | ||||||
|  |             return pd.DataFrame() | ||||||
|  | 
 | ||||||
|  |         # 计算240日均线基准价格(使用历史数据) | ||||||
|  |         ma240_base_price = historical_data['close'].tail(240).mean() | ||||||
|  |         # print(f"240日均线基准价格: {ma240_base_price:.2f}") | ||||||
|  | 
 | ||||||
|  |         # 只使用分析期间的数据进行交易信号计算 | ||||||
|  |         analysis_df = df[df['timestamp'] >= start_date].copy() | ||||||
|  |         if len(analysis_df) < 10:  # 确保分析期间有足够的数据 | ||||||
|  |             # print(f"分析期间数据不足10天,跳过分析") | ||||||
|  |             return pd.DataFrame() | ||||||
|  | 
 | ||||||
|  |         results = [] | ||||||
|  |         positions = []  # 记录所有持仓,每个元素是一个字典,包含入场价格和日期 | ||||||
|  |         max_positions = 5  # 最大建仓次数 | ||||||
|  |         cumulative_profit = 0  # 总累计收益 | ||||||
|  |         trades_count = 0 | ||||||
|  |         max_holding_period = timedelta(days=180)  # 最长持仓6个月 | ||||||
|  |         max_loss_amount = -300000  # 最大亏损金额为30万 | ||||||
|  |          | ||||||
|  |         # 添加统计变量 | ||||||
|  |         total_holding_days = 0  # 总持仓天数 | ||||||
|  |         max_capital_usage = 0  # 最大资金占用 | ||||||
|  |         total_trades = 0  # 总交易次数 | ||||||
|  |         profitable_trades = 0  # 盈利交易次数 | ||||||
|  |         loss_trades = 0  # 亏损交易次数 | ||||||
|  | 
 | ||||||
|  |         # print(f"\n{'='*50}") | ||||||
|  |         # print(f"开始分析股票: {df.iloc[0]['symbol']}") | ||||||
|  |         # print(f"{'='*50}") | ||||||
|  | 
 | ||||||
|  |         # 记录前一天的收盘价 | ||||||
|  |         prev_close = None | ||||||
|  | 
 | ||||||
|  |         for index, row in analysis_df.iterrows(): | ||||||
|  |             current_price = (float(row['high']) + float(row['low'])) / 2 | ||||||
|  |             day_result = { | ||||||
|  |                 'timestamp': row['timestamp'], | ||||||
|  |                 'profit': cumulative_profit,  # 记录当前的累计收益 | ||||||
|  |                 'position': False, | ||||||
|  |                 'action': 'hold' | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             # 更新所有持仓的收益 | ||||||
|  |             positions_to_remove = [] | ||||||
|  | 
 | ||||||
|  |             for pos in positions: | ||||||
|  |                 pos_days_held = (row['timestamp'] - pos['entry_date']).days | ||||||
|  |                 pos_size = self.total_position / pos['entry_price'] | ||||||
|  |                  | ||||||
|  |                 # 检查是否需要平仓 | ||||||
|  |                 should_close = False | ||||||
|  |                 close_type = None | ||||||
|  |                 close_price = current_price | ||||||
|  | 
 | ||||||
|  |                 if pos_days_held >= self.min_holding_days: | ||||||
|  |                     # 检查止盈 | ||||||
|  |                     if float(row['high']) >= pos['entry_price'] * (1 + take_profit_pct): | ||||||
|  |                         should_close = True | ||||||
|  |                         close_type = 'sell_profit' | ||||||
|  |                         close_price = pos['entry_price'] * (1 + take_profit_pct) | ||||||
|  |                     # 检查最大亏损 | ||||||
|  |                     elif (close_price - pos['entry_price']) * pos_size <= max_loss_amount: | ||||||
|  |                         should_close = True | ||||||
|  |                         close_type = 'sell_loss' | ||||||
|  |                     # 检查持仓时间 | ||||||
|  |                     elif (row['timestamp'] - pos['entry_date']) >= max_holding_period: | ||||||
|  |                         should_close = True | ||||||
|  |                         close_type = 'time_limit' | ||||||
|  | 
 | ||||||
|  |                 if should_close: | ||||||
|  |                     profit = (close_price - pos['entry_price']) * pos_size | ||||||
|  |                     commission = profit * self.commission_rate if profit > 0 else 0 | ||||||
|  |                     interest_cost = self.borrowed_capital * self.daily_interest_rate * pos_days_held | ||||||
|  |                     net_profit = profit - commission - interest_cost | ||||||
|  |                      | ||||||
|  |                     # 更新统计数据 | ||||||
|  |                     total_holding_days += pos_days_held | ||||||
|  |                     total_trades += 1 | ||||||
|  |                     if net_profit > 0: | ||||||
|  |                         profitable_trades += 1 | ||||||
|  |                     else: | ||||||
|  |                         loss_trades += 1 | ||||||
|  |                      | ||||||
|  |                     cumulative_profit += net_profit | ||||||
|  |                     positions_to_remove.append(pos) | ||||||
|  |                     day_result['action'] = close_type | ||||||
|  |                     day_result['profit'] = cumulative_profit | ||||||
|  |                      | ||||||
|  |                     # print(f"\n>>> {'止盈平仓' if close_type == 'sell_profit' else '止损平仓' if close_type == 'sell_loss' else '到期平仓'}") | ||||||
|  |                     # print(f"持仓天数: {pos_days_held}") | ||||||
|  |                     # print(f"持仓数量: {pos_size:.0f}股") | ||||||
|  |                     # print(f"平仓价格: {close_price:.2f}") | ||||||
|  |                     # print(f"{'盈利' if profit > 0 else '亏损'}: {profit:,.2f}") | ||||||
|  |                     # print(f"手续费: {commission:,.2f}") | ||||||
|  |                     # print(f"利息成本: {interest_cost:,.2f}") | ||||||
|  |                     # print(f"净{'盈利' if net_profit > 0 else '亏损'}: {net_profit:,.2f}") | ||||||
|  | 
 | ||||||
|  |             # 移除已平仓的持仓 | ||||||
|  |             for pos in positions_to_remove: | ||||||
|  |                 positions.remove(pos) | ||||||
|  | 
 | ||||||
|  |             # 检查是否可以建仓 | ||||||
|  |             if len(positions) < max_positions and prev_close is not None: | ||||||
|  |                 # 计算当日跌幅 | ||||||
|  |                 daily_drop = (prev_close - float(row['low'])) / prev_close | ||||||
|  |                  | ||||||
|  |                 # 检查价格是否在均线区间内 | ||||||
|  |                 price_in_range = ( | ||||||
|  |                     current_price >= ma240_base_price * 0.7 and  | ||||||
|  |                     current_price <= ma240_base_price * 1.3 | ||||||
|  |                 ) | ||||||
|  |                  | ||||||
|  |                 # 如果跌幅超过5%且价格在均线区间内 | ||||||
|  |                 if daily_drop >= 0.05 and price_in_range: | ||||||
|  |                     positions.append({ | ||||||
|  |                         'entry_price': current_price, | ||||||
|  |                         'entry_date': row['timestamp'] | ||||||
|  |                     }) | ||||||
|  |                     trades_count += 1 | ||||||
|  |                     # print(f"\n>>> 建仓信号 #{trades_count}") | ||||||
|  |                     # print(f"日期: {row['timestamp'].strftime('%Y-%m-%d')}") | ||||||
|  |                     # print(f"建仓价格: {current_price:.2f}") | ||||||
|  |                     # print(f"建仓数量: {self.total_position/current_price:.0f}股") | ||||||
|  |                     # print(f"当日跌幅: {daily_drop*100:.2f}%") | ||||||
|  |                     # print(f"距离均线: {((current_price/ma240_base_price)-1)*100:.2f}%") | ||||||
|  | 
 | ||||||
|  |             # 更新前一天收盘价 | ||||||
|  |             prev_close = float(row['close']) | ||||||
|  | 
 | ||||||
|  |             # 更新日结果 | ||||||
|  |             day_result['position'] = len(positions) > 0 | ||||||
|  |             results.append(day_result) | ||||||
|  | 
 | ||||||
|  |             # 计算当前资金占用 | ||||||
|  |             current_capital_usage = sum(self.total_position for pos in positions) | ||||||
|  |             max_capital_usage = max(max_capital_usage, current_capital_usage) | ||||||
|  | 
 | ||||||
|  |         # 在最后一天强制平仓所有持仓 | ||||||
|  |         if positions: | ||||||
|  |             final_price = (float(analysis_df.iloc[-1]['high']) + float(analysis_df.iloc[-1]['low'])) / 2 | ||||||
|  |             final_total_profit = 0 | ||||||
|  |              | ||||||
|  |             for pos in positions: | ||||||
|  |                 position_size = self.total_position / pos['entry_price'] | ||||||
|  |                 days_held = (analysis_df.iloc[-1]['timestamp'] - pos['entry_date']).days | ||||||
|  |                 final_profit = (final_price - pos['entry_price']) * position_size | ||||||
|  |                 commission = final_profit * self.commission_rate if final_profit > 0 else 0 | ||||||
|  |                 interest_cost = self.borrowed_capital * self.daily_interest_rate * days_held | ||||||
|  |                 net_profit = final_profit - commission - interest_cost | ||||||
|  |                  | ||||||
|  |                 # 更新统计数据 | ||||||
|  |                 total_holding_days += days_held | ||||||
|  |                 total_trades += 1 | ||||||
|  |                 if net_profit > 0: | ||||||
|  |                     profitable_trades += 1 | ||||||
|  |                 else: | ||||||
|  |                     loss_trades += 1 | ||||||
|  |                  | ||||||
|  |                 final_total_profit += net_profit | ||||||
|  |                  | ||||||
|  |                 # print(f"\n>>> 到期强制平仓") | ||||||
|  |                 # print(f"持仓天数: {days_held}") | ||||||
|  |                 # print(f"持仓数量: {position_size:.0f}股") | ||||||
|  |                 # print(f"平仓价格: {final_price:.2f}") | ||||||
|  |                 # print(f"毛利润: {final_profit:,.2f}") | ||||||
|  |                 # print(f"手续费: {commission:,.2f}") | ||||||
|  |                 # print(f"利息成本: {interest_cost:,.2f}") | ||||||
|  |                 # print(f"净利润: {net_profit:,.2f}") | ||||||
|  |              | ||||||
|  |             results[-1]['action'] = 'final_sell' | ||||||
|  |             cumulative_profit += final_total_profit  # 更新最终的累计收益 | ||||||
|  |             results[-1]['profit'] = cumulative_profit  # 更新最后一天的累计收益 | ||||||
|  | 
 | ||||||
|  |         # 计算统计数据 | ||||||
|  |         avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0 | ||||||
|  |         win_rate = profitable_trades / total_trades * 100 if total_trades > 0 else 0 | ||||||
|  | 
 | ||||||
|  |         # print(f"\n{'='*50}") | ||||||
|  |         # print(f"交易统计") | ||||||
|  |         # print(f"总交易次数: {trades_count}") | ||||||
|  |         # print(f"累计收益: {cumulative_profit:,.2f}") | ||||||
|  |         # print(f"最大资金占用: {max_capital_usage:,.2f}") | ||||||
|  |         # print(f"平均持仓天数: {avg_holding_days:.1f}") | ||||||
|  |         # print(f"胜率: {win_rate:.1f}%") | ||||||
|  |         # print(f"盈利交易: {profitable_trades}次") | ||||||
|  |         # print(f"亏损交易: {loss_trades}次") | ||||||
|  |         # print(f"{'='*50}\n") | ||||||
|  | 
 | ||||||
|  |         # 将统计数据添加到结果中 | ||||||
|  |         if len(results) > 0: | ||||||
|  |             results[-1]['stats'] = { | ||||||
|  |                 'total_trades': total_trades, | ||||||
|  |                 'profitable_trades': profitable_trades, | ||||||
|  |                 'loss_trades': loss_trades, | ||||||
|  |                 'win_rate': win_rate, | ||||||
|  |                 'avg_holding_days': avg_holding_days, | ||||||
|  |                 'max_capital_usage': max_capital_usage, | ||||||
|  |                 'final_profit': cumulative_profit | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         return pd.DataFrame(results) | ||||||
|  | 
 | ||||||
|  |     def analyze_stock(self, symbol, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct): | ||||||
|  |         """分析单个股票""" | ||||||
|  |         df = self.get_stock_data(symbol, start_date, end_date, include_history=True) | ||||||
|  |          | ||||||
|  |         if len(df) < 240:  # 确保总数据量足够 | ||||||
|  |             print(f"股票 {symbol} 总数据不足240天,跳过分析") | ||||||
|  |             return None | ||||||
|  |              | ||||||
|  |         return self.calculate_trade_signals( | ||||||
|  |             df, | ||||||
|  |             take_profit_pct=take_profit_pct, | ||||||
|  |             stop_loss_pct=stop_loss_pct, | ||||||
|  |             pullback_entry_pct=pullback_entry_pct, | ||||||
|  |             start_date=start_date | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def analyze_all_stocks(self, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct): | ||||||
|  |         """分析所有股票""" | ||||||
|  |         stocks = self.get_stock_list() | ||||||
|  |         all_results = {} | ||||||
|  |         analysis_summary = { | ||||||
|  |             'processed': [], | ||||||
|  |             'skipped': [], | ||||||
|  |             'error': [] | ||||||
|  |         } | ||||||
|  |          | ||||||
|  |         print(f"\n开始分析,共 {len(stocks)} 支股票") | ||||||
|  |          | ||||||
|  |         for symbol in tqdm(stocks, desc="Analyzing stocks"): | ||||||
|  |             try: | ||||||
|  |                 # 获取股票数据 | ||||||
|  |                 df = self.get_stock_data(symbol, start_date, end_date, include_history=True) | ||||||
|  |                  | ||||||
|  |                 if df.empty: | ||||||
|  |                     print(f"\n股票 {symbol} 数据为空,跳过分析") | ||||||
|  |                     analysis_summary['skipped'].append({ | ||||||
|  |                         'symbol': symbol, | ||||||
|  |                         'reason': 'empty_data', | ||||||
|  |                         'take_profit_pct': take_profit_pct | ||||||
|  |                     }) | ||||||
|  |                     continue | ||||||
|  |                  | ||||||
|  |                 # 检查历史数据 | ||||||
|  |                 historical_data = df[df['timestamp'] < start_date] | ||||||
|  |                 if len(historical_data) < 240: | ||||||
|  |                     # print(f"\n股票 {symbol} 历史数据不足240天({len(historical_data)}天),跳过分析") | ||||||
|  |                     analysis_summary['skipped'].append({ | ||||||
|  |                         'symbol': symbol, | ||||||
|  |                         'reason': f'insufficient_history_{len(historical_data)}days', | ||||||
|  |                         'take_profit_pct': take_profit_pct | ||||||
|  |                     }) | ||||||
|  |                     continue | ||||||
|  |                  | ||||||
|  |                 # 检查分析期数据 | ||||||
|  |                 analysis_data = df[df['timestamp'] >= start_date] | ||||||
|  |                 if len(analysis_data) < 10: | ||||||
|  |                     print(f"\n股票 {symbol} 分析期数据不足10天({len(analysis_data)}天),跳过分析") | ||||||
|  |                     analysis_summary['skipped'].append({ | ||||||
|  |                         'symbol': symbol, | ||||||
|  |                         'reason': f'insufficient_analysis_data_{len(analysis_data)}days', | ||||||
|  |                         'take_profit_pct': take_profit_pct | ||||||
|  |                     }) | ||||||
|  |                     continue | ||||||
|  |                  | ||||||
|  |                 # 计算交易信号 | ||||||
|  |                 result = self.calculate_trade_signals( | ||||||
|  |                     df, | ||||||
|  |                     take_profit_pct, | ||||||
|  |                     stop_loss_pct, | ||||||
|  |                     pullback_entry_pct, | ||||||
|  |                     start_date | ||||||
|  |                 ) | ||||||
|  |                  | ||||||
|  |                 if result is not None and not result.empty: | ||||||
|  |                     # 检查是否有交易记录 | ||||||
|  |                     has_trades = any(action in ['sell_profit', 'sell_loss', 'time_limit', 'final_sell']  | ||||||
|  |                                    for action in result['action'] if pd.notna(action)) | ||||||
|  |                      | ||||||
|  |                     if has_trades: | ||||||
|  |                         all_results[symbol] = result | ||||||
|  |                         analysis_summary['processed'].append({ | ||||||
|  |                             'symbol': symbol, | ||||||
|  |                             'take_profit_pct': take_profit_pct, | ||||||
|  |                             'trades': len([x for x in result['action'] if pd.notna(x) and x != 'hold']), | ||||||
|  |                             'profit': result['profit'].iloc[-1] | ||||||
|  |                         }) | ||||||
|  |                     else: | ||||||
|  |                         # print(f"\n股票 {symbol} 没有产生有效的交易信号") | ||||||
|  |                         analysis_summary['skipped'].append({ | ||||||
|  |                             'symbol': symbol, | ||||||
|  |                             'reason': 'no_valid_signals', | ||||||
|  |                             'take_profit_pct': take_profit_pct | ||||||
|  |                         }) | ||||||
|  |                  | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"\n处理股票 {symbol} 时出错: {str(e)}") | ||||||
|  |                 analysis_summary['error'].append({ | ||||||
|  |                     'symbol': symbol, | ||||||
|  |                     'error': str(e), | ||||||
|  |                     'take_profit_pct': take_profit_pct | ||||||
|  |                 }) | ||||||
|  |                 continue | ||||||
|  |          | ||||||
|  |         # 打印分析总结 | ||||||
|  |         print(f"\n分析完成:") | ||||||
|  |         print(f"成功分析: {len(analysis_summary['processed'])} 支股票") | ||||||
|  |         print(f"跳过分析: {len(analysis_summary['skipped'])} 支股票") | ||||||
|  |         print(f"出错股票: {len(analysis_summary['error'])} 支股票") | ||||||
|  |          | ||||||
|  |         # 保存分析总结 | ||||||
|  |         if not os.path.exists('results/analysis_summary'): | ||||||
|  |             os.makedirs('results/analysis_summary') | ||||||
|  |          | ||||||
|  |         # 保存处理成功的股票信息 | ||||||
|  |         if analysis_summary['processed']: | ||||||
|  |             pd.DataFrame(analysis_summary['processed']).to_csv( | ||||||
|  |                 f'results/analysis_summary/processed_stocks_{take_profit_pct*100:.1f}.csv', | ||||||
|  |                 index=False | ||||||
|  |             ) | ||||||
|  |          | ||||||
|  |         # 保存跳过的股票信息 | ||||||
|  |         if analysis_summary['skipped']: | ||||||
|  |             pd.DataFrame(analysis_summary['skipped']).to_csv( | ||||||
|  |                 f'results/analysis_summary/skipped_stocks_{take_profit_pct*100:.1f}.csv', | ||||||
|  |                 index=False | ||||||
|  |             ) | ||||||
|  |          | ||||||
|  |         # 保存出错的股票信息 | ||||||
|  |         if analysis_summary['error']: | ||||||
|  |             pd.DataFrame(analysis_summary['error']).to_csv( | ||||||
|  |                 f'results/analysis_summary/error_stocks_{take_profit_pct*100:.1f}.csv', | ||||||
|  |                 index=False | ||||||
|  |             ) | ||||||
|  |          | ||||||
|  |         return all_results | ||||||
|  | 
 | ||||||
|  |     def plot_results(self, results, symbol): | ||||||
|  |         """绘制单个股票的收益走势图""" | ||||||
|  |         plt.figure(figsize=(12, 6)) | ||||||
|  |         plt.plot(results['timestamp'], results['profit'].cumsum(), label='Cumulative Profit') | ||||||
|  |         plt.title(f'Stock {symbol} Trading Results') | ||||||
|  |         plt.xlabel('Date') | ||||||
|  |         plt.ylabel('Profit (CNY)') | ||||||
|  |         plt.legend() | ||||||
|  |         plt.grid(True) | ||||||
|  |         plt.xticks(rotation=45) | ||||||
|  |         plt.tight_layout() | ||||||
|  |         return plt | ||||||
|  | 
 | ||||||
|  |     def plot_all_stocks(self, all_results): | ||||||
|  |         """绘制所有股票的收益走势图""" | ||||||
|  |         plt.figure(figsize=(15, 8)) | ||||||
|  |          | ||||||
|  |         # 记录所有股票的累计收益 | ||||||
|  |         total_profits = {} | ||||||
|  |         max_profit = float('-inf') | ||||||
|  |         min_profit = float('inf') | ||||||
|  |          | ||||||
|  |         # 绘制每支股票的曲线 | ||||||
|  |         for symbol, results in all_results.items(): | ||||||
|  |             # 直接使用已经计算好的累计收益 | ||||||
|  |             profits = results['profit'] | ||||||
|  |             plt.plot(results['timestamp'], profits, label=symbol, alpha=0.5, linewidth=1) | ||||||
|  |              | ||||||
|  |             # 更新最大最小收益 | ||||||
|  |             max_profit = max(max_profit, profits.max()) | ||||||
|  |             min_profit = min(min_profit, profits.min()) | ||||||
|  |              | ||||||
|  |             # 记录总收益(使用最后一个累计收益值) | ||||||
|  |             total_profits[symbol] = profits.iloc[-1] | ||||||
|  |          | ||||||
|  |         # 计算所有股票的中位数收益曲线 | ||||||
|  |         # 首先找到共同的日期范围 | ||||||
|  |         all_dates = set() | ||||||
|  |         for results in all_results.values(): | ||||||
|  |             all_dates.update(results['timestamp']) | ||||||
|  |         all_dates = sorted(list(all_dates)) | ||||||
|  |          | ||||||
|  |         # 创建一个包含所有日期的DataFrame | ||||||
|  |         profits_df = pd.DataFrame(index=all_dates) | ||||||
|  |          | ||||||
|  |         # 对每支股票,填充所有日期的收益 | ||||||
|  |         for symbol, results in all_results.items(): | ||||||
|  |             daily_profits = pd.Series(index=all_dates, data=np.nan) | ||||||
|  |             # 直接使用已经计算好的累计收益 | ||||||
|  |             for date, profit in zip(results['timestamp'], results['profit']): | ||||||
|  |                 daily_profits[date] = profit | ||||||
|  |             # 对于没有交易的日期,使用最后一个累计收益填充 | ||||||
|  |             daily_profits.fillna(method='ffill', inplace=True) | ||||||
|  |             daily_profits.fillna(0, inplace=True)  # 对开始之前的日期填充0 | ||||||
|  |             profits_df[symbol] = daily_profits | ||||||
|  |          | ||||||
|  |         # 计算并绘制中位数收益曲线 | ||||||
|  |         median_line = profits_df.median(axis=1) | ||||||
|  |         plt.plot(all_dates, median_line, 'r-', label='Median', linewidth=2) | ||||||
|  |          | ||||||
|  |         plt.title('All Stocks Trading Results') | ||||||
|  |         plt.xlabel('Date') | ||||||
|  |         plt.ylabel('Profit (CNY)') | ||||||
|  |         plt.grid(True) | ||||||
|  |         plt.xticks(rotation=45) | ||||||
|  |          | ||||||
|  |         # 添加利润区间标注 | ||||||
|  |         plt.text(0.02, 0.98,  | ||||||
|  |                 f'Profit Range:\nMax: {max_profit:,.0f}\nMin: {min_profit:,.0f}\nMedian: {median_line.iloc[-1]:,.0f}',  | ||||||
|  |                 transform=plt.gca().transAxes,  | ||||||
|  |                 bbox=dict(facecolor='white', alpha=0.8)) | ||||||
|  |          | ||||||
|  |         # 计算所有股票的统计数据 | ||||||
|  |         total_stats = { | ||||||
|  |             'total_trades': 0, | ||||||
|  |             'profitable_trades': 0, | ||||||
|  |             'loss_trades': 0, | ||||||
|  |             'total_holding_days': 0, | ||||||
|  |             'max_capital_usage': 0, | ||||||
|  |             'total_profit': 0, | ||||||
|  |             'profitable_stocks': 0, | ||||||
|  |             'loss_stocks': 0, | ||||||
|  |             'total_capital_usage': 0,  # 添加总资金占用 | ||||||
|  |             'total_win_rate': 0,       # 添加总胜率 | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         # 获取前5名和后5名股票 | ||||||
|  |         top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5] | ||||||
|  |         bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5] | ||||||
|  |          | ||||||
|  |         # 添加排名信息 | ||||||
|  |         rank_text = "Top 5 Stocks:\n" | ||||||
|  |         for symbol, profit in top_5: | ||||||
|  |             rank_text += f"{symbol}: {profit:,.0f}\n" | ||||||
|  |         rank_text += "\nBottom 5 Stocks:\n" | ||||||
|  |         for symbol, profit in bottom_5: | ||||||
|  |             rank_text += f"{symbol}: {profit:,.0f}\n" | ||||||
|  |              | ||||||
|  |         plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes,  | ||||||
|  |                 bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top') | ||||||
|  | 
 | ||||||
|  |         # 记录每支股票的统计数据用于计算平均值 | ||||||
|  |         stock_stats = [] | ||||||
|  | 
 | ||||||
|  |         for symbol, results in all_results.items(): | ||||||
|  |             if 'stats' in results.iloc[-1]: | ||||||
|  |                 stats = results.iloc[-1]['stats'] | ||||||
|  |                 stock_stats.append(stats)  # 记录每支股票的统计数据 | ||||||
|  |                  | ||||||
|  |                 total_stats['total_trades'] += stats['total_trades'] | ||||||
|  |                 total_stats['profitable_trades'] += stats['profitable_trades'] | ||||||
|  |                 total_stats['loss_trades'] += stats['loss_trades'] | ||||||
|  |                 total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades'] | ||||||
|  |                 total_stats['max_capital_usage'] = max(total_stats['max_capital_usage'], stats['max_capital_usage']) | ||||||
|  |                 total_stats['total_capital_usage'] += stats['max_capital_usage'] | ||||||
|  |                 total_stats['total_profit'] += stats['final_profit'] | ||||||
|  |                 total_stats['total_win_rate'] += stats['win_rate'] | ||||||
|  |                 if stats['final_profit'] > 0: | ||||||
|  |                     total_stats['profitable_stocks'] += 1 | ||||||
|  |                 else: | ||||||
|  |                     total_stats['loss_stocks'] += 1 | ||||||
|  | 
 | ||||||
|  |         # 计算总体统计 | ||||||
|  |         total_stocks = total_stats['profitable_stocks'] + total_stats['loss_stocks'] | ||||||
|  |         avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 | ||||||
|  |         win_rate = total_stats['profitable_trades'] / total_stats['total_trades'] * 100 if total_stats['total_trades'] > 0 else 0 | ||||||
|  |         stock_win_rate = total_stats['profitable_stocks'] / total_stocks * 100 if total_stocks > 0 else 0 | ||||||
|  |          | ||||||
|  |         # 计算平均值统计 | ||||||
|  |         avg_trades_per_stock = total_stats['total_trades'] / total_stocks if total_stocks > 0 else 0 | ||||||
|  |         avg_profit_per_stock = total_stats['total_profit'] / total_stocks if total_stocks > 0 else 0 | ||||||
|  |         avg_capital_usage = total_stats['total_capital_usage'] / total_stocks if total_stocks > 0 else 0 | ||||||
|  |         avg_win_rate = total_stats['total_win_rate'] / total_stocks if total_stocks > 0 else 0 | ||||||
|  | 
 | ||||||
|  |         # 添加总体统计信息 | ||||||
|  |         stats_text = ( | ||||||
|  |             f"总体统计:\n" | ||||||
|  |             f"总交易次数: {total_stats['total_trades']}\n" | ||||||
|  |             # f"最大资金占用: {total_stats['max_capital_usage']:,.0f}\n" | ||||||
|  |             f"平均持仓天数: {avg_holding_days:.1f}\n" | ||||||
|  |             f"交易胜率: {win_rate:.1f}%\n" | ||||||
|  |             f"股票胜率: {stock_win_rate:.1f}%\n" | ||||||
|  |             f"盈利股票数: {total_stats['profitable_stocks']}\n" | ||||||
|  |             f"亏损股票数: {total_stats['loss_stocks']}\n" | ||||||
|  |             f"\n平均值统计:\n" | ||||||
|  |             f"平均交易次数: {avg_trades_per_stock:.1f}\n" | ||||||
|  |             f"平均收益: {avg_profit_per_stock:,.0f}\n" | ||||||
|  |             f"平均资金占用: {avg_capital_usage:,.0f}\n" | ||||||
|  |             f"平均持仓天数: {avg_holding_days:.1f}\n" | ||||||
|  |             f"平均胜率: {avg_win_rate:.1f}%" | ||||||
|  |         ) | ||||||
|  |          | ||||||
|  |         plt.text(0.02, 0.6, stats_text, | ||||||
|  |                 transform=plt.gca().transAxes, | ||||||
|  |                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), | ||||||
|  |                 verticalalignment='top', | ||||||
|  |                 fontsize=10) | ||||||
|  |          | ||||||
|  |         # 获取前5名和后5名股票 | ||||||
|  |         top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5] | ||||||
|  |         bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5] | ||||||
|  |          | ||||||
|  |         # 添加排名信息 | ||||||
|  |         rank_text = "Top 5 Stocks:\n" | ||||||
|  |         for symbol, profit in top_5: | ||||||
|  |             rank_text += f"{symbol}: {profit:,.0f}\n" | ||||||
|  |         rank_text += "\nBottom 5 Stocks:\n" | ||||||
|  |         for symbol, profit in bottom_5: | ||||||
|  |             rank_text += f"{symbol}: {profit:,.0f}\n" | ||||||
|  |              | ||||||
|  |         plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes,  | ||||||
|  |                 bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top') | ||||||
|  |          | ||||||
|  |         plt.tight_layout() | ||||||
|  |         return plt, total_profits | ||||||
|  | 
 | ||||||
|  |     def plot_profit_distribution(self, total_profits): | ||||||
|  |         """绘制所有股票的收益分布小提琴图""" | ||||||
|  |         plt.figure(figsize=(12, 8)) | ||||||
|  |          | ||||||
|  |         # 将收益数据转换为数组,确保使用最终累计收益 | ||||||
|  |         profits = np.array([profit for profit in total_profits.values()]) | ||||||
|  |          | ||||||
|  |         # 计算统计数据 | ||||||
|  |         mean_profit = np.mean(profits) | ||||||
|  |         median_profit = np.median(profits) | ||||||
|  |         max_profit = np.max(profits) | ||||||
|  |         min_profit = np.min(profits) | ||||||
|  |         std_profit = np.std(profits) | ||||||
|  |          | ||||||
|  |         # 绘制小提琴图 | ||||||
|  |         violin_parts = plt.violinplot(profits, positions=[0], showmeans=True, showmedians=True) | ||||||
|  |          | ||||||
|  |         # 设置小提琴图的颜色 | ||||||
|  |         violin_parts['bodies'][0].set_facecolor('lightblue') | ||||||
|  |         violin_parts['bodies'][0].set_alpha(0.7) | ||||||
|  |         violin_parts['cmeans'].set_color('red') | ||||||
|  |         violin_parts['cmedians'].set_color('blue') | ||||||
|  |          | ||||||
|  |         # 添加统计信息 | ||||||
|  |         stats_text = ( | ||||||
|  |             f"统计信息:\n" | ||||||
|  |             f"平均收益: {mean_profit:,.0f}\n" | ||||||
|  |             f"中位数收益: {median_profit:,.0f}\n" | ||||||
|  |             f"最大收益: {max_profit:,.0f}\n" | ||||||
|  |             f"最小收益: {min_profit:,.0f}\n" | ||||||
|  |             f"标准差: {std_profit:,.0f}" | ||||||
|  |         ) | ||||||
|  |          | ||||||
|  |         plt.text(0.65, 0.95, stats_text, | ||||||
|  |                 transform=plt.gca().transAxes, | ||||||
|  |                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), | ||||||
|  |                 verticalalignment='top', | ||||||
|  |                 fontsize=10) | ||||||
|  |          | ||||||
|  |         # 设置图表样式 | ||||||
|  |         plt.title('股票收益分布图', fontsize=14, pad=20) | ||||||
|  |         plt.ylabel('收益 (元)', fontsize=12) | ||||||
|  |         plt.xticks([0], ['所有股票'], fontsize=10) | ||||||
|  |         plt.grid(True, axis='y', alpha=0.3) | ||||||
|  |          | ||||||
|  |         # 添加零线 | ||||||
|  |         plt.axhline(y=0, color='r', linestyle='--', alpha=0.3) | ||||||
|  |          | ||||||
|  |         # 计算盈利和亏损的股票数量 | ||||||
|  |         profit_count = np.sum(profits > 0) | ||||||
|  |         loss_count = np.sum(profits < 0) | ||||||
|  |         total_count = len(profits) | ||||||
|  |          | ||||||
|  |         # 添加盈亏比例信息 | ||||||
|  |         ratio_text = ( | ||||||
|  |             f"盈亏比例:\n" | ||||||
|  |             f"盈利: {profit_count}支 ({profit_count/total_count*100:.1f}%)\n" | ||||||
|  |             f"亏损: {loss_count}支 ({loss_count/total_count*100:.1f}%)\n" | ||||||
|  |             f"总计: {total_count}支" | ||||||
|  |         ) | ||||||
|  |          | ||||||
|  |         plt.text(0.65, 0.6, ratio_text, | ||||||
|  |                 transform=plt.gca().transAxes, | ||||||
|  |                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'), | ||||||
|  |                 verticalalignment='top', | ||||||
|  |                 fontsize=10) | ||||||
|  |          | ||||||
|  |         plt.tight_layout() | ||||||
|  |         return plt | ||||||
|  | 
 | ||||||
|  |     def plot_profit_matrix(self, df): | ||||||
|  |         """绘制止盈比例分析矩阵的结果图""" | ||||||
|  |         plt.figure(figsize=(15, 10)) | ||||||
|  |          | ||||||
|  |         # 创建子图 | ||||||
|  |         fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) | ||||||
|  |          | ||||||
|  |         # 1. 总收益曲线 | ||||||
|  |         ax1.plot(df.index, df['total_profit'], 'b-', linewidth=2) | ||||||
|  |         ax1.set_title('总收益 vs 止盈比例') | ||||||
|  |         ax1.set_xlabel('止盈比例 (%)') | ||||||
|  |         ax1.set_ylabel('总收益 (元)') | ||||||
|  |         ax1.grid(True) | ||||||
|  |          | ||||||
|  |         # 标记最优止盈比例点 | ||||||
|  |         best_profit_pct = df['total_profit'].idxmax() | ||||||
|  |         best_profit = df['total_profit'].max() | ||||||
|  |         ax1.plot(best_profit_pct, best_profit, 'ro') | ||||||
|  |         ax1.annotate(f'最优: {best_profit_pct:.1f}%\n{best_profit:,.0f}元', | ||||||
|  |                     (best_profit_pct, best_profit), | ||||||
|  |                     xytext=(10, 10), textcoords='offset points') | ||||||
|  |          | ||||||
|  |         # 2. 平均每笔收益和中位数收益 | ||||||
|  |         ax2.plot(df.index, df['avg_profit_per_trade'], 'g-', linewidth=2, label='平均收益') | ||||||
|  |         ax2.plot(df.index, df['median_profit_per_trade'], 'r--', linewidth=2, label='中位数收益') | ||||||
|  |         ax2.set_title('每笔交易收益 vs 止盈比例') | ||||||
|  |         ax2.set_xlabel('止盈比例 (%)') | ||||||
|  |         ax2.set_ylabel('收益 (元)') | ||||||
|  |         ax2.legend() | ||||||
|  |         ax2.grid(True) | ||||||
|  |          | ||||||
|  |         # 3. 平均持仓天数曲线 | ||||||
|  |         ax3.plot(df.index, df['avg_holding_days'], 'r-', linewidth=2) | ||||||
|  |         ax3.set_title('平均持仓天数 vs 止盈比例') | ||||||
|  |         ax3.set_xlabel('止盈比例 (%)') | ||||||
|  |         ax3.set_ylabel('持仓天数') | ||||||
|  |         ax3.grid(True) | ||||||
|  |          | ||||||
|  |         # 4. 胜率对比 | ||||||
|  |         ax4.plot(df.index, df['trade_win_rate'], 'c-', linewidth=2, label='交易胜率') | ||||||
|  |         ax4.plot(df.index, df['stock_win_rate'], 'm--', linewidth=2, label='股票胜率') | ||||||
|  |         ax4.set_title('胜率 vs 止盈比例') | ||||||
|  |         ax4.set_xlabel('止盈比例 (%)') | ||||||
|  |         ax4.set_ylabel('胜率 (%)') | ||||||
|  |         ax4.legend() | ||||||
|  |         ax4.grid(True) | ||||||
|  |          | ||||||
|  |         # 调整布局 | ||||||
|  |         plt.tight_layout() | ||||||
|  |         return plt | ||||||
|  | 
 | ||||||
|  |     def analyze_profit_matrix(self, start_date, end_date, stop_loss_pct, pullback_entry_pct): | ||||||
|  |         """分析不同止盈比例的表现矩阵""" | ||||||
|  |         # 创建results目录和临时文件目录 | ||||||
|  |         if not os.path.exists('results/temp'): | ||||||
|  |             os.makedirs('results/temp') | ||||||
|  |              | ||||||
|  |         # 从30%到5%,每次减少1% | ||||||
|  |         profit_percentages = list(np.arange(0.30, 0.04, -0.01)) | ||||||
|  |          | ||||||
|  |         for take_profit_pct in profit_percentages: | ||||||
|  |             # 检查是否已经分析过这个止盈点 | ||||||
|  |             temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv' | ||||||
|  |             if os.path.exists(temp_file): | ||||||
|  |                 print(f"\n止盈比例 {take_profit_pct*100:.1f}% 已分析,跳过") | ||||||
|  |                 continue | ||||||
|  |                  | ||||||
|  |             print(f"\n分析止盈比例: {take_profit_pct*100:.1f}%") | ||||||
|  |              | ||||||
|  |             try: | ||||||
|  |                 # 运行分析 | ||||||
|  |                 results = self.analyze_all_stocks( | ||||||
|  |                     start_date, | ||||||
|  |                     end_date, | ||||||
|  |                     take_profit_pct, | ||||||
|  |                     stop_loss_pct, | ||||||
|  |                     pullback_entry_pct | ||||||
|  |                 ) | ||||||
|  |                  | ||||||
|  |                 # 过滤掉空结果 | ||||||
|  |                 valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty} | ||||||
|  |                  | ||||||
|  |                 if valid_results: | ||||||
|  |                     # 计算统计数据 | ||||||
|  |                     total_stats = { | ||||||
|  |                         'total_trades': 0, | ||||||
|  |                         'profitable_trades': 0, | ||||||
|  |                         'loss_trades': 0, | ||||||
|  |                         'total_holding_days': 0, | ||||||
|  |                         'total_profit': 0, | ||||||
|  |                         'total_capital_usage': 0, | ||||||
|  |                         'total_win_rate': 0, | ||||||
|  |                         'stock_count': len(valid_results), | ||||||
|  |                         'profitable_stocks': 0, | ||||||
|  |                         'all_trade_profits': [] | ||||||
|  |                     } | ||||||
|  |                      | ||||||
|  |                     # 收集每支股票的统计数据 | ||||||
|  |                     for symbol, result in valid_results.items(): | ||||||
|  |                         if 'stats' in result.iloc[-1]: | ||||||
|  |                             stats = result.iloc[-1]['stats'] | ||||||
|  |                             total_stats['total_trades'] += stats['total_trades'] | ||||||
|  |                             total_stats['profitable_trades'] += stats['profitable_trades'] | ||||||
|  |                             total_stats['loss_trades'] += stats['loss_trades'] | ||||||
|  |                             total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades'] | ||||||
|  |                             total_stats['total_profit'] += stats['final_profit'] | ||||||
|  |                             total_stats['total_capital_usage'] += stats['max_capital_usage'] | ||||||
|  |                             total_stats['total_win_rate'] += stats['win_rate'] | ||||||
|  |                              | ||||||
|  |                             if stats['final_profit'] > 0: | ||||||
|  |                                 total_stats['profitable_stocks'] += 1 | ||||||
|  |                                  | ||||||
|  |                             if hasattr(stats, 'trade_profits'): | ||||||
|  |                                 total_stats['all_trade_profits'].extend(stats['trade_profits']) | ||||||
|  |                      | ||||||
|  |                     # 计算统计数据 | ||||||
|  |                     avg_trades = total_stats['total_trades'] / total_stats['stock_count'] | ||||||
|  |                     avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 | ||||||
|  |                     avg_profit_per_trade = total_stats['total_profit'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0 | ||||||
|  |                     median_profit_per_trade = np.median(total_stats['all_trade_profits']) if total_stats['all_trade_profits'] else 0 | ||||||
|  |                     total_profit = total_stats['total_profit'] | ||||||
|  |                     trade_win_rate = (total_stats['profitable_trades'] / total_stats['total_trades'] * 100) if total_stats['total_trades'] > 0 else 0 | ||||||
|  |                     stock_win_rate = (total_stats['profitable_stocks'] / total_stats['stock_count'] * 100) if total_stats['stock_count'] > 0 else 0 | ||||||
|  |                      | ||||||
|  |                     # 创建单个止盈点的结果DataFrame | ||||||
|  |                     result_df = pd.DataFrame([{ | ||||||
|  |                         'take_profit_pct': take_profit_pct * 100, | ||||||
|  |                         'total_profit': total_profit, | ||||||
|  |                         'avg_profit_per_trade': avg_profit_per_trade, | ||||||
|  |                         'median_profit_per_trade': median_profit_per_trade, | ||||||
|  |                         'avg_holding_days': avg_holding_days, | ||||||
|  |                         'trade_win_rate': trade_win_rate, | ||||||
|  |                         'stock_win_rate': stock_win_rate, | ||||||
|  |                         'total_trades': total_stats['total_trades'], | ||||||
|  |                         'stock_count': total_stats['stock_count'] | ||||||
|  |                     }]) | ||||||
|  |                      | ||||||
|  |                     # 保存单个止盈点的结果 | ||||||
|  |                     result_df.to_csv(temp_file, index=False) | ||||||
|  |                     print(f"保存止盈点 {take_profit_pct*100:.1f}% 的分析结果") | ||||||
|  |                      | ||||||
|  |                 # 清理内存 | ||||||
|  |                 del valid_results | ||||||
|  |                 del results | ||||||
|  |                 if 'total_stats' in locals(): | ||||||
|  |                     del total_stats | ||||||
|  |                  | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"处理止盈点 {take_profit_pct*100:.1f}% 时出错: {str(e)}") | ||||||
|  |                 continue | ||||||
|  |          | ||||||
|  |         # 合并所有结果 | ||||||
|  |         print("\n合并所有分析结果...") | ||||||
|  |         all_results = [] | ||||||
|  |         for take_profit_pct in profit_percentages: | ||||||
|  |             temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv' | ||||||
|  |             if os.path.exists(temp_file): | ||||||
|  |                 try: | ||||||
|  |                     df = pd.read_csv(temp_file) | ||||||
|  |                     all_results.append(df) | ||||||
|  |                 except Exception as e: | ||||||
|  |                     print(f"读取文件 {temp_file} 时出错: {str(e)}") | ||||||
|  |          | ||||||
|  |         if all_results: | ||||||
|  |             # 合并所有结果 | ||||||
|  |             df = pd.concat(all_results, ignore_index=True) | ||||||
|  |             df = df.round({ | ||||||
|  |                 'take_profit_pct': 1, | ||||||
|  |                 'total_profit': 0, | ||||||
|  |                 'avg_profit_per_trade': 0, | ||||||
|  |                 'median_profit_per_trade': 0, | ||||||
|  |                 'avg_holding_days': 1, | ||||||
|  |                 'trade_win_rate': 1, | ||||||
|  |                 'stock_win_rate': 1, | ||||||
|  |                 'total_trades': 0, | ||||||
|  |                 'stock_count': 0 | ||||||
|  |             }) | ||||||
|  |              | ||||||
|  |             # 设置止盈比例为索引并排序 | ||||||
|  |             df.set_index('take_profit_pct', inplace=True) | ||||||
|  |             df.sort_index(ascending=False, inplace=True) | ||||||
|  |              | ||||||
|  |             # 保存最终结果 | ||||||
|  |             df.to_csv('results/profit_matrix_analysis.csv') | ||||||
|  |              | ||||||
|  |             # 打印结果 | ||||||
|  |             print("\n止盈比例分析矩阵:") | ||||||
|  |             print("=" * 120) | ||||||
|  |             print(df.to_string()) | ||||||
|  |             print("=" * 120) | ||||||
|  |              | ||||||
|  |             try: | ||||||
|  |                 # 绘制并保存矩阵分析图表 | ||||||
|  |                 plt = self.plot_profit_matrix(df) | ||||||
|  |                 plt.savefig('results/profit_matrix_analysis.png', bbox_inches='tight', dpi=300) | ||||||
|  |                 plt.close() | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"绘制图表时出错: {str(e)}") | ||||||
|  |              | ||||||
|  |             return df | ||||||
|  |         else: | ||||||
|  |             print("没有找到任何分析结果") | ||||||
|  |             return pd.DataFrame() | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     # 数据库连接配置 | ||||||
|  |     db_config = { | ||||||
|  |         'host': '192.168.18.199', | ||||||
|  |         'port': 3306, | ||||||
|  |         'user': 'root', | ||||||
|  |         'password': 'Chlry#$.8', | ||||||
|  |         'database': 'db_gp_cj' | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     try: | ||||||
|  |         connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}" | ||||||
|  |          | ||||||
|  |         # 创建results目录 | ||||||
|  | 
 | ||||||
|  |         if not os.path.exists('results'): | ||||||
|  |             os.makedirs('results') | ||||||
|  |          | ||||||
|  |         # 创建分析器实例 | ||||||
|  |         analyzer = StockAnalyzer(connection_string) | ||||||
|  |          | ||||||
|  |         # 设置分析参数 | ||||||
|  |         start_date = '2022-05-05' | ||||||
|  |         end_date = '2023-08-28' | ||||||
|  |         stop_loss_pct = -0.99   # -99%止损 | ||||||
|  |         pullback_entry_pct = -0.05  # -5%回调建仓 | ||||||
|  |          | ||||||
|  |         # 运行止盈比例分析 | ||||||
|  |         profit_matrix = analyzer.analyze_profit_matrix( | ||||||
|  |             start_date, | ||||||
|  |             end_date, | ||||||
|  |             stop_loss_pct, | ||||||
|  |             pullback_entry_pct | ||||||
|  |         ) | ||||||
|  |          | ||||||
|  |         if not profit_matrix.empty: | ||||||
|  |             # 使用最优止盈比例进行完整分析 | ||||||
|  |             take_profit_pct = profit_matrix['total_profit'].idxmax() / 100 | ||||||
|  |             print(f"\n使用最优止盈比例 {take_profit_pct*100:.1f}% 运行详细分析") | ||||||
|  |              | ||||||
|  |             # 运行分析并只生成所需的图表 | ||||||
|  |             results = analyzer.analyze_all_stocks( | ||||||
|  |                 start_date, | ||||||
|  |                 end_date, | ||||||
|  |                 take_profit_pct, | ||||||
|  |                 stop_loss_pct, | ||||||
|  |                 pullback_entry_pct | ||||||
|  |             ) | ||||||
|  |              | ||||||
|  |             # 过滤掉空结果 | ||||||
|  |             valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty} | ||||||
|  |              | ||||||
|  |             if valid_results: | ||||||
|  |                 # 只生成所需的两个图表 | ||||||
|  |                 plt, _ = analyzer.plot_all_stocks(valid_results) | ||||||
|  |                 plt.savefig('results/all_stocks_analysis.png', bbox_inches='tight', dpi=300) | ||||||
|  |                 plt.close() | ||||||
|  |          | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"程序运行出错: {str(e)}") | ||||||
|  |         raise | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main()  | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,81 @@ | ||||||
|  | # 用前须知 | ||||||
|  | 
 | ||||||
|  | ## xtdata提供和MiniQmt的交互接口,本质是和MiniQmt建立连接,由MiniQmt处理行情数据请求,再把结果回传返回到python层。使用的行情服务器以及能获取到的行情数据和MiniQmt是一致的,要检查数据或者切换连接时直接操作MiniQmt即可。 | ||||||
|  | 
 | ||||||
|  | ## 对于数据获取接口,使用时需要先确保MiniQmt已有所需要的数据,如果不足可以通过补充数据接口补充,再调用数据获取接口获取。 | ||||||
|  | 
 | ||||||
|  | ## 对于订阅接口,直接设置数据回调,数据到来时会由回调返回。订阅接收到的数据一般会保存下来,同种数据不需要再单独补充。 | ||||||
|  | 
 | ||||||
|  | # 代码讲解 | ||||||
|  | 
 | ||||||
|  | # 从本地python导入xtquant库,如果出现报错则说明安装失败 | ||||||
|  | from xtquant import xtdata | ||||||
|  | import time | ||||||
|  | 
 | ||||||
|  | # 设定一个标的列表 | ||||||
|  | code_list = ["000001.SZ"] | ||||||
|  | # 设定获取数据的周期 | ||||||
|  | period = "1d" | ||||||
|  | 
 | ||||||
|  | # 下载标的行情数据 | ||||||
|  | if 1: | ||||||
|  |     ## 为了方便用户进行数据管理,xtquant的大部分历史数据都是以压缩形式存储在本地的 | ||||||
|  |     ## 比如行情数据,需要通过download_history_data下载,财务数据需要通过 | ||||||
|  |     ## 所以在取历史数据之前,我们需要调用数据下载接口,将数据下载到本地 | ||||||
|  |     for i in code_list: | ||||||
|  |         xtdata.download_history_data(i, period=period, incrementally=True)  # 增量下载行情数据(开高低收,等等)到本地 | ||||||
|  | 
 | ||||||
|  |     xtdata.download_financial_data(code_list)  # 下载财务数据到本地 | ||||||
|  |     xtdata.download_sector_data()  # 下载板块数据到本地 | ||||||
|  |     # 更多数据的下载方式可以通过数据字典查询 | ||||||
|  | 
 | ||||||
|  | # 读取本地历史行情数据 | ||||||
|  | history_data = xtdata.get_market_data_ex([], code_list, period=period, count=-1) | ||||||
|  | print(history_data) | ||||||
|  | print("=" * 20) | ||||||
|  | 
 | ||||||
|  | # 如果需要盘中的实时行情,需要向服务器进行订阅后才能获取 | ||||||
|  | # 订阅后,get_market_data函数于get_market_data_ex函数将会自动拼接本地历史行情与服务器实时行情 | ||||||
|  | 
 | ||||||
|  | # 向服务器订阅数据 | ||||||
|  | for i in code_list: | ||||||
|  |     xtdata.subscribe_quote(i, period=period, count=-1)  # 设置count = -1来取到当天所有实时行情 | ||||||
|  | 
 | ||||||
|  | # 等待订阅完成 | ||||||
|  | time.sleep(1) | ||||||
|  | 
 | ||||||
|  | # 获取订阅后的行情 | ||||||
|  | kline_data = xtdata.get_market_data_ex([], code_list, period=period) | ||||||
|  | print(kline_data) | ||||||
|  | 
 | ||||||
|  | # 获取订阅后的行情,并以固定间隔进行刷新,预期会循环打印10次 | ||||||
|  | for i in range(10): | ||||||
|  |     # 这边做演示,就用for来循环了,实际使用中可以用while True | ||||||
|  |     kline_data = xtdata.get_market_data_ex([], code_list, period=period) | ||||||
|  |     print(kline_data) | ||||||
|  |     time.sleep(3)  # 三秒后再次获取行情 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # 如果不想用固定间隔触发,可以以用订阅后的回调来执行 | ||||||
|  | # 这种模式下当订阅的callback回调函数将会异步的执行,每当订阅的标的tick发生变化更新,callback回调函数就会被调用一次 | ||||||
|  | # 本地已有的数据不会触发callback | ||||||
|  | 
 | ||||||
|  | # 定义的回测函数 | ||||||
|  | ## 回调函数中,data是本次触发回调的数据,只有一条 | ||||||
|  | def f(data): | ||||||
|  |     # print(data) | ||||||
|  | 
 | ||||||
|  |     code_list = list(data.keys())  # 获取到本次触发的标的代码 | ||||||
|  | 
 | ||||||
|  |     kline_in_callabck = xtdata.get_market_data_ex([], code_list, period=period)  # 在回调中获取klines数据 | ||||||
|  |     print(kline_in_callabck) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | for i in code_list: | ||||||
|  |     xtdata.subscribe_quote(i, period=period, count=-1, callback=f)  # 订阅时设定回调函数 | ||||||
|  | 
 | ||||||
|  | # 使用回调时,必须要同时使用xtdata.run()来阻塞程序,否则程序运行到最后一行就直接结束退出了。 | ||||||
|  | xtdata.run() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
		Loading…
	
		Reference in New Issue