commit aadef5f0fa (parent ae9b2f2fcf)

README.md | 129

@@ -211,6 +211,127 @@ pip install -r requirements.txt
Screening logic: select companies with a stable industry environment, high-quality new business partnerships, sound financials, favorable broker coverage, and substantial upside.

## Docker Deployment

The system can be deployed with Docker, either as a single instance or as multiple instances.

### 1. Single-instance deployment

```bash
# Build the image
docker-compose build

# Start the service
docker-compose up -d
```

### 2. Multi-instance deployment

Three scripts are provided for managing multi-instance deployments. Each instance runs on its own port, starting from 5088 and counting up:

#### Deploying multiple instances with docker run (recommended)

```bash
# Make the script executable
chmod +x deploy-multiple.sh

# Deploy 3 instances on ports 5088, 5089, and 5090
./deploy-multiple.sh 3
```

Each instance gets its own data directory, `./instances/instance-N/`, with separate log and report files.

#### Deploying multiple instances with docker-compose

```bash
# Make the script executable
chmod +x deploy-compose.sh

# Deploy 3 instances on ports 5088, 5089, and 5090
./deploy-compose.sh 3
```

Each instance uses its own compose project name, `stock-app-N`.

#### Managing instances

Use the `manage-instances.sh` script to manage deployed instances:

```bash
# Make the script executable
chmod +x manage-instances.sh

# Show the status of all instances
./manage-instances.sh status

# List running instances
./manage-instances.sh list

# Start a specific instance
./manage-instances.sh start 2

# Start all instances
./manage-instances.sh start all

# Stop a specific instance
./manage-instances.sh stop 1

# Stop all instances
./manage-instances.sh stop all

# Restart a specific instance
./manage-instances.sh restart 3

# View the logs of a specific instance
./manage-instances.sh logs 1

# Remove a specific instance
./manage-instances.sh remove 2

# Remove all instances
./manage-instances.sh remove all
```

### 3. Installing Chinese fonts

The system needs Chinese fonts to generate PDF reports. If deployment fails with a `找不到中文字体文件` ("Chinese font file not found") error, resolve it in one of the following ways:

#### Automatic font installation (recommended)

```bash
# Enter the container
docker exec -it [container ID or name] bash

# Run the font setup script inside the container
cd /app
python src/fundamentals_llm/setup_fonts.py
```

#### Manual font installation

1. Download a Chinese font file on the host (e.g. simhei.ttf or wqy-microhei.ttc).
2. Create the font directory and copy the font file into it:

```bash
mkdir -p src/fundamentals_llm/fonts
cp <downloaded-font>.ttf src/fundamentals_llm/fonts/simhei.ttf
```

3. Rebuild the image:

```bash
docker-compose build
```

### 4. Accessing the service

Once deployment finishes, the service is reachable at:

- Single instance: `http://<server-IP>:5000`
- Multiple instances: `http://<server-IP>:508X` (X is the instance index, starting at 8)

For example:

- Instance 1: `http://<server-IP>:5088`
- Instance 2: `http://<server-IP>:5089`
- Instance 3: `http://<server-IP>:5090`
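Once the instances are up, it is worth confirming that each port actually responds before pointing clients at them. A minimal sketch using the `/api/health` route added in `src/app.py` later in this commit (the host and port list are placeholders for your own deployment):

```python
import requests  # third-party HTTP client; assumed to be installed

HOST = "127.0.0.1"          # placeholder server address
PORTS = [5088, 5089, 5090]  # default ports created by ./deploy-multiple.sh 3

for port in PORTS:
    url = f"http://{HOST}:{port}/api/health"
    try:
        resp = requests.get(url, timeout=5)
        print(f"{url} -> HTTP {resp.status_code}")
    except requests.RequestException as exc:
        print(f"{url} -> unreachable: {exc}")
```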
## Database Query Examples

The following SQL query examples show how to select specific types of companies from the database:
@@ -251,3 +372,11 @@ AND stock_code IN (
```

Note: in practice you may need to adjust the SQL statements to match the actual table structure. Combine the different filter conditions as needed for your use case.

# Rebuild the image
docker-compose build

# Restart all instances
./manage-instances.sh restart all
@@ -13,5 +13,5 @@ volcengine-python-sdk[ark]
 openai>=1.0
 reportlab>=4.3.1
 markdown2>=2.5.3
-# 1. 按下 Win+R ,输入 regedit 打开注册表编辑器。
-# 2. 设置 \HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem 路径下的变量 LongPathsEnabled 为 1 即可。
+google-genai
+redis==5.2.1
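The two new dependencies back features added later in this commit: `google-genai` powers the new Gemini-based chat bot, and `redis` backs the Redis client added in `fundamental_analysis.py`. A quick sanity check that both packages import and connect (connection parameters and API key are placeholders):

```python
import redis
from google import genai  # installed by the google-genai package

# Placeholder Redis connection; substitute your own host and credentials.
r = redis.Redis(host="127.0.0.1", port=6379, db=0, socket_timeout=5, decode_responses=True)
print("redis ping:", r.ping())

# Client construction only; real calls require a valid Gemini API key.
client = genai.Client(api_key="YOUR_API_KEY")
print("genai client ready:", client is not None)
```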
src/app.py | 570

@@ -1,27 +1,47 @@
 import sys
 import os
+from datetime import datetime, timedelta
+import pandas as pd
+import uuid
+import json
+from threading import Thread

 from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db

 # 添加项目根目录到 Python 路径
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

-from flask import Flask, jsonify, request
+from flask import Flask, jsonify, request, send_from_directory
 from flask_cors import CORS
 import logging

 # 导入企业筛选器
 from src.fundamentals_llm.enterprise_screener import EnterpriseScreener

+# 导入股票回测器
+from src.stock_analysis_v2 import run_backtest, StockBacktester

 # 设置日志
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(), # 输出到控制台
+        logging.FileHandler(f'logs/app_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8') # 输出到文件
+    ]
 )

+# 确保logs和results目录存在
+os.makedirs('logs', exist_ok=True)
+os.makedirs('results', exist_ok=True)
+os.makedirs('results/tasks', exist_ok=True)
+os.makedirs('static/results', exist_ok=True)

 logger = logging.getLogger(__name__)
+logger.info("Flask应用启动")

 # 创建 Flask 应用
-app = Flask(__name__)
+app = Flask(__name__, static_folder='static')
 CORS(app) # 启用跨域请求支持

 # 创建企业筛选器实例
@@ -30,6 +50,329 @@ screener = EnterpriseScreener()
|
||||||
# 获取数据库连接
|
# 获取数据库连接
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
|
|
||||||
|
# 获取项目根目录
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports')
|
||||||
|
|
||||||
|
# 确保reports目录存在
|
||||||
|
os.makedirs(REPORTS_DIR, exist_ok=True)
|
||||||
|
logger.info(f"报告目录路径: {REPORTS_DIR}")
|
||||||
|
|
||||||
|
# 存储回测任务状态的字典
|
||||||
|
backtest_tasks = {}
|
||||||
|
|
||||||
|
def run_backtest_task(task_id, stocks_buy_dates, end_date):
|
||||||
|
"""
|
||||||
|
在后台运行回测任务
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"开始执行回测任务 {task_id}")
|
||||||
|
# 更新任务状态为进行中
|
||||||
|
backtest_tasks[task_id]['status'] = 'running'
|
||||||
|
|
||||||
|
# 运行回测
|
||||||
|
results, stats_list = run_backtest(stocks_buy_dates, end_date)
|
||||||
|
|
||||||
|
# 如果回测成功
|
||||||
|
if results and stats_list:
|
||||||
|
# 计算总体统计
|
||||||
|
stats_df = pd.DataFrame(stats_list)
|
||||||
|
total_profit = stats_df['final_profit'].sum()
|
||||||
|
avg_win_rate = stats_df['win_rate'].mean()
|
||||||
|
avg_holding_days = stats_df['avg_holding_days'].mean()
|
||||||
|
|
||||||
|
# 找出最佳止盈比例(假设为0.15,实际应从回测结果中分析得出)
|
||||||
|
best_take_profit_pct = 0.15
|
||||||
|
|
||||||
|
# 获取图表URL
|
||||||
|
chart_urls = {
|
||||||
|
"all_stocks": f"/static/results/{task_id}/all_stocks_analysis.png",
|
||||||
|
"profit_matrix": f"/static/results/{task_id}/profit_matrix_analysis.png"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 保存股票详细统计
|
||||||
|
stock_stats = []
|
||||||
|
for stat in stats_list:
|
||||||
|
stock_stats.append({
|
||||||
|
"symbol": stat['symbol'],
|
||||||
|
"total_trades": int(stat['total_trades']),
|
||||||
|
"profitable_trades": int(stat['profitable_trades']),
|
||||||
|
"loss_trades": int(stat['loss_trades']),
|
||||||
|
"win_rate": float(stat['win_rate']),
|
||||||
|
"avg_holding_days": float(stat['avg_holding_days']),
|
||||||
|
"final_profit": float(stat['final_profit']),
|
||||||
|
"entry_count": int(stat['entry_count'])
|
||||||
|
})
|
||||||
|
|
||||||
|
# 构建结果数据
|
||||||
|
result_data = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": "completed",
|
||||||
|
"results": {
|
||||||
|
"total_profit": float(total_profit),
|
||||||
|
"win_rate": float(avg_win_rate),
|
||||||
|
"avg_holding_days": float(avg_holding_days),
|
||||||
|
"best_take_profit_pct": best_take_profit_pct,
|
||||||
|
"stock_stats": stock_stats,
|
||||||
|
"chart_urls": chart_urls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 保存结果到文件
|
||||||
|
task_result_path = os.path.join('results', 'tasks', f"{task_id}.json")
|
||||||
|
with open(task_result_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(result_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# 更新任务状态为已完成
|
||||||
|
backtest_tasks[task_id].update({
|
||||||
|
'status': 'completed',
|
||||||
|
'results': result_data['results']
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"回测任务 {task_id} 已完成")
|
||||||
|
else:
|
||||||
|
# 更新任务状态为失败
|
||||||
|
backtest_tasks[task_id]['status'] = 'failed'
|
||||||
|
backtest_tasks[task_id]['error'] = "回测未产生有效结果"
|
||||||
|
logger.error(f"回测任务 {task_id} 失败:回测未产生有效结果")
|
||||||
|
except Exception as e:
|
||||||
|
# 更新任务状态为失败
|
||||||
|
backtest_tasks[task_id]['status'] = 'failed'
|
||||||
|
backtest_tasks[task_id]['error'] = str(e)
|
||||||
|
logger.error(f"回测任务 {task_id} 失败:{str(e)}")
|
||||||
|
|
||||||
|
@app.route('/api/backtest/run', methods=['POST'])
|
||||||
|
def start_backtest():
|
||||||
|
"""启动回测任务
|
||||||
|
|
||||||
|
请求体格式:
|
||||||
|
{
|
||||||
|
"stocks_buy_dates": {
|
||||||
|
"SH600522": ["2022-05-10", "2022-06-10"], // 股票代码: [买入日期列表]
|
||||||
|
"SZ002340": ["2022-06-15"],
|
||||||
|
"SH601615": ["2022-07-20", "2022-08-01"]
|
||||||
|
},
|
||||||
|
"end_date": "2022-10-20" // 所有股票共同的结束日期
|
||||||
|
}
|
||||||
|
|
||||||
|
返回内容:
|
||||||
|
{
|
||||||
|
"task_id": "backtask-khkerhy4u237y489237489truiy8432"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从请求体获取参数
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": "请求格式错误: 需要提供JSON数据"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
stocks_buy_dates = data.get('stocks_buy_dates')
|
||||||
|
end_date = data.get('end_date')
|
||||||
|
|
||||||
|
if not stocks_buy_dates or not isinstance(stocks_buy_dates, dict):
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": "请求格式错误: 需要提供stocks_buy_dates字典"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
if not end_date or not isinstance(end_date, str):
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": "请求格式错误: 需要提供有效的end_date"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 验证日期格式
|
||||||
|
try:
|
||||||
|
datetime.strptime(end_date, '%Y-%m-%d')
|
||||||
|
for stock_code, buy_dates in stocks_buy_dates.items():
|
||||||
|
if not isinstance(buy_dates, list):
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"请求格式错误: 股票 {stock_code} 的买入日期必须是列表"
|
||||||
|
}), 400
|
||||||
|
for buy_date in buy_dates:
|
||||||
|
datetime.strptime(buy_date, '%Y-%m-%d')
|
||||||
|
except ValueError as e:
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"日期格式错误: {str(e)}"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 生成任务ID
|
||||||
|
task_id = f"backtask-{uuid.uuid4().hex[:16]}"
|
||||||
|
|
||||||
|
# 创建任务目录
|
||||||
|
task_dir = os.path.join('static', 'results', task_id)
|
||||||
|
os.makedirs(task_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 记录任务信息
|
||||||
|
backtest_tasks[task_id] = {
|
||||||
|
'status': 'pending',
|
||||||
|
'created_at': datetime.now().isoformat(),
|
||||||
|
'stocks_buy_dates': stocks_buy_dates,
|
||||||
|
'end_date': end_date
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建线程运行回测
|
||||||
|
thread = Thread(target=run_backtest_task, args=(task_id, stocks_buy_dates, end_date))
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
logger.info(f"已创建回测任务: {task_id}")
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"task_id": task_id
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建回测任务失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"创建回测任务失败: {str(e)}"
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/backtest/status', methods=['GET'])
|
||||||
|
def check_backtest_status():
|
||||||
|
"""查询回测任务状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- task_id: 回测任务ID
|
||||||
|
|
||||||
|
返回内容:
|
||||||
|
{
|
||||||
|
"task_id": "backtask-khkerhy4u237y489237489truiy8432",
|
||||||
|
"status": "running" | "completed" | "failed",
|
||||||
|
"created_at": "2023-12-01T10:30:45",
|
||||||
|
"error": "错误信息(如有)"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task_id = request.args.get('task_id')
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": "请求格式错误: 需要提供task_id参数"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 检查任务是否存在
|
||||||
|
if task_id not in backtest_tasks:
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"任务不存在: {task_id}"
|
||||||
|
}), 404
|
||||||
|
|
||||||
|
# 获取任务信息
|
||||||
|
task_info = backtest_tasks[task_id]
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": task_info['status'],
|
||||||
|
"created_at": task_info['created_at']
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果任务失败,添加错误信息
|
||||||
|
if task_info['status'] == 'failed' and 'error' in task_info:
|
||||||
|
response['error'] = task_info['error']
|
||||||
|
|
||||||
|
return jsonify(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询任务状态失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"查询任务状态失败: {str(e)}"
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/backtest/result', methods=['GET'])
|
||||||
|
def get_backtest_result():
|
||||||
|
"""获取回测任务结果
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- task_id: 回测任务ID
|
||||||
|
|
||||||
|
返回内容:
|
||||||
|
{
|
||||||
|
"task_id": "backtask-2023121-001",
|
||||||
|
"status": "completed",
|
||||||
|
"results": {
|
||||||
|
"total_profit": 123456.78,
|
||||||
|
"win_rate": 75.5,
|
||||||
|
"avg_holding_days": 12.3,
|
||||||
|
"best_take_profit_pct": 0.15,
|
||||||
|
"stock_stats": [...],
|
||||||
|
"chart_urls": {
|
||||||
|
"all_stocks": "/static/results/backtask-2023121-001/all_stocks_analysis.png",
|
||||||
|
"profit_matrix": "/static/results/backtask-2023121-001/profit_matrix_analysis.png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task_id = request.args.get('task_id')
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": "请求格式错误: 需要提供task_id参数"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 检查任务是否存在
|
||||||
|
if task_id not in backtest_tasks:
|
||||||
|
# 尝试从文件加载
|
||||||
|
task_result_path = os.path.join('results', 'tasks', f"{task_id}.json")
|
||||||
|
if os.path.exists(task_result_path):
|
||||||
|
with open(task_result_path, 'r', encoding='utf-8') as f:
|
||||||
|
result_data = json.load(f)
|
||||||
|
return jsonify(result_data)
|
||||||
|
else:
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"任务不存在: {task_id}"
|
||||||
|
}), 404
|
||||||
|
|
||||||
|
# 获取任务信息
|
||||||
|
task_info = backtest_tasks[task_id]
|
||||||
|
|
||||||
|
# 检查任务是否完成
|
||||||
|
if task_info['status'] != 'completed':
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"任务尚未完成或已失败: {task_info['status']}"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": "completed",
|
||||||
|
"results": task_info['results']
|
||||||
|
}
|
||||||
|
|
||||||
|
return jsonify(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取任务结果失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
"status": "error",
|
||||||
|
"message": f"获取任务结果失败: {str(e)}"
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/reports/<path:filename>')
|
||||||
|
def serve_report(filename):
|
||||||
|
"""提供PDF报告的访问"""
|
||||||
|
try:
|
||||||
|
logger.info(f"请求文件: {filename}")
|
||||||
|
logger.info(f"从目录: {REPORTS_DIR} 提供文件")
|
||||||
|
return send_from_directory(REPORTS_DIR, filename, as_attachment=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提供报告文件失败: {str(e)}")
|
||||||
|
return jsonify({"error": "文件不存在"}), 404
|
||||||
|
|
||||||
@app.route('/api/health', methods=['GET'])
def health_check():
    """健康检查接口"""
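The three `/api/backtest/*` routes added above form a submit/poll/fetch workflow. A minimal client sketch using the request and response shapes documented in the docstrings (host, port, and the sample stock code are placeholders):

```python
import time
import requests  # assumed to be installed

BASE = "http://127.0.0.1:5000"  # hypothetical single-instance address

# 1. Submit a backtest task
payload = {
    "stocks_buy_dates": {"SH600522": ["2022-05-10", "2022-06-10"]},
    "end_date": "2022-10-20",
}
task_id = requests.post(f"{BASE}/api/backtest/run", json=payload, timeout=30).json()["task_id"]

# 2. Poll the task status until it leaves the running state
while True:
    status = requests.get(f"{BASE}/api/backtest/status",
                          params={"task_id": task_id}, timeout=30).json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)

# 3. Fetch the results once the task has completed
if status["status"] == "completed":
    result = requests.get(f"{BASE}/api/backtest/result",
                          params={"task_id": task_id}, timeout=30).json()
    print(result["results"]["total_profit"], result["results"]["chart_urls"])
```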
@@ -185,10 +528,10 @@ def generate_reports():
|
||||||
|
|
||||||
# 导入 PDF 生成器模块
|
# 导入 PDF 生成器模块
|
||||||
try:
|
try:
|
||||||
from src.fundamentals_llm.pdf_generator import generate_investment_report
|
from src.fundamentals_llm.pdf_generator import PDFGenerator
|
||||||
except ImportError:
|
except ImportError:
|
||||||
try:
|
try:
|
||||||
from fundamentals_llm.pdf_generator import generate_investment_report
|
from fundamentals_llm.pdf_generator import PDFGenerator
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(f"无法导入 PDF 生成器模块: {str(e)}")
|
logger.error(f"无法导入 PDF 生成器模块: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
@ -200,8 +543,14 @@ def generate_reports():
|
||||||
generated_reports = []
|
generated_reports = []
|
||||||
for stock_code, stock_name in stocks:
|
for stock_code, stock_name in stocks:
|
||||||
try:
|
try:
|
||||||
|
# 创建 PDF 生成器实例
|
||||||
|
generator = PDFGenerator()
|
||||||
# 调用 PDF 生成器
|
# 调用 PDF 生成器
|
||||||
report_path = generate_investment_report(stock_code, stock_name)
|
report_path = generator.generate_pdf(
|
||||||
|
title=f"{stock_name}({stock_code}) 基本面分析报告",
|
||||||
|
content_dict={}, # 这里需要传入实际的内容字典
|
||||||
|
filename=f"{stock_name}_{stock_code}_analysis.pdf"
|
||||||
|
)
|
||||||
generated_reports.append({
|
generated_reports.append({
|
||||||
"code": stock_code,
|
"code": stock_code,
|
||||||
"name": stock_name,
|
"name": stock_name,
|
||||||
|
@ -505,9 +854,10 @@ def analyze_and_recommend():
|
||||||
"message": f"分析和推荐股票失败: {str(e)}"
|
"message": f"分析和推荐股票失败: {str(e)}"
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/comprehensive_analysis', methods=['POST'])
|
@app.route('/api/comprehensive_analysis', methods=['POST'])
|
||||||
def comprehensive_analysis():
|
def comprehensive_analysis():
|
||||||
"""综合分析接口 - 组合多种功能和参数
|
"""综合分析接口 - 使用队列方式处理被锁定的股票
|
||||||
|
|
||||||
请求体格式:
|
请求体格式:
|
||||||
{
|
{
|
||||||
|
@ -558,36 +908,9 @@ def comprehensive_analysis():
|
||||||
if not isinstance(limit, int) or limit <= 0:
|
if not isinstance(limit, int) or limit <= 0:
|
||||||
limit = 10
|
limit = 10
|
||||||
|
|
||||||
# 导入必要的聊天机器人模块
|
# 导入必要的模块
|
||||||
try:
|
try:
|
||||||
# 首先尝试导入聊天机器人模块
|
# 先导入基本面分析器
|
||||||
try:
|
|
||||||
from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
|
|
||||||
logger.info("成功从 src.fundamentals_llm.chat_bot 导入 ChatBot")
|
|
||||||
except ImportError as e1:
|
|
||||||
try:
|
|
||||||
from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
|
|
||||||
logger.info("成功从 fundamentals_llm.chat_bot 导入 ChatBot")
|
|
||||||
except ImportError as e2:
|
|
||||||
logger.error(f"无法导入在线聊天机器人模块: {str(e1)}, {str(e2)}")
|
|
||||||
return jsonify({
|
|
||||||
"status": "error",
|
|
||||||
"message": f"服务器配置错误: 聊天机器人模块不可用,错误详情: {str(e2)}"
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
# 然后尝试导入离线聊天机器人模块
|
|
||||||
try:
|
|
||||||
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
|
|
||||||
logger.info("成功从 src.fundamentals_llm.chat_bot_with_offline 导入 ChatBot")
|
|
||||||
except ImportError as e1:
|
|
||||||
try:
|
|
||||||
from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
|
|
||||||
logger.info("成功从 fundamentals_llm.chat_bot_with_offline 导入 ChatBot")
|
|
||||||
except ImportError as e2:
|
|
||||||
logger.warning(f"无法导入离线聊天机器人模块: {str(e1)}, {str(e2)}")
|
|
||||||
# 这里可以继续执行,因为某些功能可能不需要离线模型
|
|
||||||
|
|
||||||
# 最后导入基本面分析器
|
|
||||||
try:
|
try:
|
||||||
from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer
|
from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer
|
||||||
logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer")
|
logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer")
|
||||||
|
@ -601,6 +924,19 @@ def comprehensive_analysis():
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}"
|
"message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}"
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
# 再导入其他可能需要的模块
|
||||||
|
try:
|
||||||
|
from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
|
||||||
|
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
|
||||||
|
from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
|
||||||
|
except ImportError:
|
||||||
|
# 这些模块不是必须的,所以继续执行
|
||||||
|
logger.warning("无法导入聊天机器人模块,但这不会影响基本功能")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"导入必要模块时出错: {str(e)}")
|
logger.error(f"导入必要模块时出错: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
@ -611,64 +947,141 @@ def comprehensive_analysis():
|
||||||
# 创建基本面分析器实例
|
# 创建基本面分析器实例
|
||||||
analyzer = FundamentalAnalyzer()
|
analyzer = FundamentalAnalyzer()
|
||||||
|
|
||||||
# 为每个股票生成投资建议
|
# 准备结果容器
|
||||||
investment_advices = []
|
investment_advices = {} # 使用字典,股票代码作为键
|
||||||
for stock_code, stock_name in stocks:
|
processing_queue = list(stocks) # 初始处理队列
|
||||||
|
max_attempts = 5 # 最大重试次数
|
||||||
|
total_attempts = 0
|
||||||
|
|
||||||
|
# 导入数据库模块
|
||||||
|
from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db
|
||||||
|
|
||||||
|
# 开始处理队列
|
||||||
|
while processing_queue and total_attempts < max_attempts:
|
||||||
|
total_attempts += 1
|
||||||
|
logger.info(f"开始第 {total_attempts} 轮处理,队列中有 {len(processing_queue)} 只股票")
|
||||||
|
|
||||||
|
# 暂存下一轮需要处理的股票
|
||||||
|
next_round_queue = []
|
||||||
|
|
||||||
|
# 处理当前队列中的所有股票
|
||||||
|
for stock_code, stock_name in processing_queue:
|
||||||
try:
|
try:
|
||||||
# 生成投资建议
|
# 检查是否已有分析结果
|
||||||
|
db = next(get_db())
|
||||||
|
existing_result = get_analysis_result(db, stock_code, "investment_advice")
|
||||||
|
|
||||||
|
# 如果已有近期结果,直接使用
|
||||||
|
if existing_result and existing_result.update_time > datetime.now() - timedelta(hours=12):
|
||||||
|
investment_advices[stock_code] = {
|
||||||
|
"code": stock_code,
|
||||||
|
"name": stock_name,
|
||||||
|
"advice": existing_result.ai_response,
|
||||||
|
"reasoning": existing_result.reasoning_process,
|
||||||
|
"references": existing_result.references,
|
||||||
|
"status": "success",
|
||||||
|
"from_cache": True
|
||||||
|
}
|
||||||
|
logger.info(f"使用缓存的 {stock_name}({stock_code}) 分析结果")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否被锁定
|
||||||
|
if analyzer.is_stock_locked(stock_code, "investment_advice"):
|
||||||
|
# 已被锁定,放到下一轮队列
|
||||||
|
next_round_queue.append([stock_code, stock_name])
|
||||||
|
# 记录状态
|
||||||
|
if stock_code not in investment_advices:
|
||||||
|
investment_advices[stock_code] = {
|
||||||
|
"code": stock_code,
|
||||||
|
"name": stock_name,
|
||||||
|
"status": "pending",
|
||||||
|
"message": f"股票 {stock_code} 正在被其他请求分析中,已加入等待队列"
|
||||||
|
}
|
||||||
|
logger.info(f"股票 {stock_name}({stock_code}) 已被锁定,放入下一轮队列")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试锁定并分析
|
||||||
|
analyzer.lock_stock(stock_code, "investment_advice")
|
||||||
|
try:
|
||||||
|
# 执行分析
|
||||||
success, advice, reasoning, references = analyzer.query_analysis(
|
success, advice, reasoning, references = analyzer.query_analysis(
|
||||||
stock_code, stock_name, "investment_advice"
|
stock_code, stock_name, "investment_advice"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 记录结果
|
||||||
if success:
|
if success:
|
||||||
investment_advices.append({
|
investment_advices[stock_code] = {
|
||||||
"code": stock_code,
|
"code": stock_code,
|
||||||
"name": stock_name,
|
"name": stock_name,
|
||||||
"advice": advice,
|
"advice": advice,
|
||||||
"reasoning": reasoning,
|
"reasoning": reasoning,
|
||||||
"references": references,
|
"references": references,
|
||||||
"status": "success"
|
"status": "success"
|
||||||
})
|
}
|
||||||
logger.info(f"成功生成 {stock_name}({stock_code}) 的投资建议")
|
logger.info(f"成功分析 {stock_name}({stock_code})")
|
||||||
else:
|
else:
|
||||||
investment_advices.append({
|
investment_advices[stock_code] = {
|
||||||
"code": stock_code,
|
"code": stock_code,
|
||||||
"name": stock_name,
|
"name": stock_name,
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": advice
|
"error": advice or "分析失败,无详细信息"
|
||||||
})
|
}
|
||||||
logger.error(f"生成 {stock_name}({stock_code}) 的投资建议失败: {advice}")
|
logger.error(f"分析 {stock_name}({stock_code}) 失败: {advice}")
|
||||||
|
finally:
|
||||||
|
# 确保释放锁
|
||||||
|
analyzer.unlock_stock(stock_code, "investment_advice")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}")
|
# 处理异常
|
||||||
investment_advices.append({
|
investment_advices[stock_code] = {
|
||||||
"code": stock_code,
|
"code": stock_code,
|
||||||
"name": stock_name,
|
"name": stock_name,
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e)
|
"error": str(e)
|
||||||
|
}
|
||||||
|
logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}")
|
||||||
|
# 确保释放锁
|
||||||
|
try:
|
||||||
|
analyzer.unlock_stock(stock_code, "investment_advice")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 如果还有下一轮要处理的股票,等待一段时间后继续
|
||||||
|
if next_round_queue:
|
||||||
|
logger.info(f"本轮结束,还有 {len(next_round_queue)} 只股票等待下一轮处理")
|
||||||
|
# 等待30秒再处理下一轮
|
||||||
|
import time
|
||||||
|
time.sleep(30)
|
||||||
|
processing_queue = next_round_queue
|
||||||
|
else:
|
||||||
|
# 所有股票都已处理完,退出循环
|
||||||
|
logger.info("所有股票处理完毕")
|
||||||
|
processing_queue = []
|
||||||
|
|
||||||
|
# 处理仍在队列中的股票(达到最大重试次数但仍未处理的)
|
||||||
|
for stock_code, stock_name in processing_queue:
|
||||||
|
if stock_code in investment_advices and investment_advices[stock_code]["status"] == "pending":
|
||||||
|
investment_advices[stock_code].update({
|
||||||
|
"status": "timeout",
|
||||||
|
"message": f"等待超时,股票 {stock_code} 可能正在被长时间分析"
|
||||||
})
|
})
|
||||||
|
logger.warning(f"股票 {stock_name}({stock_code}) 分析超时")
|
||||||
|
|
||||||
|
# 将字典转换为列表
|
||||||
|
investment_advice_list = list(investment_advices.values())
|
||||||
|
|
||||||
# 生成PDF报告(如果需要)
|
# 生成PDF报告(如果需要)
|
||||||
pdf_results = []
|
pdf_results = []
|
||||||
if generate_pdf:
|
if generate_pdf:
|
||||||
try:
|
try:
|
||||||
# 导入PDF生成器模块
|
# 针对已成功分析的股票生成PDF
|
||||||
|
for stock_info in investment_advice_list:
|
||||||
|
if stock_info["status"] == "success":
|
||||||
|
stock_code = stock_info["code"]
|
||||||
|
stock_name = stock_info["name"]
|
||||||
try:
|
try:
|
||||||
from src.fundamentals_llm.pdf_generator import generate_investment_report
|
report_path = analyzer.generate_pdf_report(stock_code, stock_name)
|
||||||
except ImportError:
|
if report_path:
|
||||||
try:
|
|
||||||
from fundamentals_llm.pdf_generator import generate_investment_report
|
|
||||||
except ImportError as e:
|
|
||||||
logger.error(f"无法导入 PDF 生成器模块: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
"status": "error",
|
|
||||||
"message": f"服务器配置错误: PDF 生成器模块不可用, 错误详情: {str(e)}"
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
# 生成报告
|
|
||||||
for stock_code, stock_name in stocks:
|
|
||||||
try:
|
|
||||||
# 调用 PDF 生成器
|
|
||||||
report_path = generate_investment_report(stock_code, stock_name)
|
|
||||||
pdf_results.append({
|
pdf_results.append({
|
||||||
"code": stock_code,
|
"code": stock_code,
|
||||||
"name": stock_name,
|
"name": stock_name,
|
||||||
|
@ -676,6 +1089,14 @@ def comprehensive_analysis():
|
||||||
"status": "success"
|
"status": "success"
|
||||||
})
|
})
|
||||||
logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}")
|
logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}")
|
||||||
|
else:
|
||||||
|
pdf_results.append({
|
||||||
|
"code": stock_code,
|
||||||
|
"name": stock_name,
|
||||||
|
"status": "error",
|
||||||
|
"error": "生成报告失败"
|
||||||
|
})
|
||||||
|
logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}")
|
logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}")
|
||||||
pdf_results.append({
|
pdf_results.append({
|
||||||
|
@ -784,11 +1205,24 @@ def comprehensive_analysis():
|
||||||
logger.error(f"应用企业画像筛选失败: {str(e)}")
|
logger.error(f"应用企业画像筛选失败: {str(e)}")
|
||||||
filtered_stocks = []
|
filtered_stocks = []
|
||||||
|
|
||||||
|
# 统计各种状态的股票数量
|
||||||
|
success_count = sum(1 for item in investment_advice_list if item["status"] == "success")
|
||||||
|
pending_count = sum(1 for item in investment_advice_list if item["status"] == "pending")
|
||||||
|
timeout_count = sum(1 for item in investment_advice_list if item["status"] == "timeout")
|
||||||
|
error_count = sum(1 for item in investment_advice_list if item["status"] == "error")
|
||||||
|
|
||||||
# 返回结果
|
# 返回结果
|
||||||
response = {
|
response = {
|
||||||
"status": "success",
|
"status": "success" if success_count > 0 else "partial_success" if success_count + pending_count > 0 else "failed",
|
||||||
"total_input_stocks": len(stocks),
|
"total_input_stocks": len(stocks),
|
||||||
"investment_advices": investment_advices
|
"stats": {
|
||||||
|
"success": success_count,
|
||||||
|
"pending": pending_count,
|
||||||
|
"timeout": timeout_count,
|
||||||
|
"error": error_count
|
||||||
|
},
|
||||||
|
"rounds_attempted": total_attempts,
|
||||||
|
"investment_advices": investment_advice_list
|
||||||
}
|
}
|
||||||
|
|
||||||
if profile_filter:
|
if profile_filter:
|
||||||
|
|
|
@@ -150,11 +150,15 @@ class ChatBot:
             logger.error(f"格式化参考资料时出错: {str(e)}")
             return str(ref)

-    def chat(self, user_input: str) -> Dict[str, Any]:
+    def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]:
         """与AI进行对话

         Args:
             user_input: 用户输入的问题
+            temperature: 控制输出的随机性,范围0-2,默认1.0
+            top_p: 控制输出的多样性,范围0-1,默认0.7
+            max_tokens: 控制输出的最大长度,默认4096
+            frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0

         Returns:
             Dict[str, Any]: 包含以下字段的字典:

@@ -173,6 +177,10 @@ class ChatBot:
             stream = self.client.chat.completions.create(
                 model=self.model,
                 messages=self.conversation_history,
+                temperature=temperature,
+                top_p=top_p,
+                max_tokens=max_tokens,
+                frequency_penalty=frequency_penalty,
                 stream=True
             )
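The widened `chat` signature above exposes the sampling controls to callers while keeping the old single-argument call working. A minimal usage sketch (constructing the bot as `ChatBot(model_type="online_bot")`, the pattern used elsewhere in this commit; the prompt string is a placeholder):

```python
from src.fundamentals_llm.chat_bot import ChatBot  # import path as used in src/app.py

bot = ChatBot(model_type="online_bot")  # constructed this way in fundamental_analysis.py

# Low temperature for a short, deterministic answer; the other parameters keep their defaults.
result = bot.chat(
    "请简要分析该公司的行业竞争格局",
    temperature=0.0,
    max_tokens=1024,
)
print(result)
```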
@@ -0,0 +1,420 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Optional, Any, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from google import genai
|
||||||
|
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||||
|
|
||||||
|
# 设置日志记录
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 定义基础数据结构
|
||||||
|
class TextAnalysisResult(BaseModel):
|
||||||
|
"""文本分析结果,包含分析正文、推理过程和引用URL"""
|
||||||
|
analysis_text: str = Field(description="详细的分析文本")
|
||||||
|
reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None)
|
||||||
|
references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None)
|
||||||
|
|
||||||
|
class NumericalAnalysisResult(BaseModel):
|
||||||
|
"""数值分析结果,包含数值和分析描述"""
|
||||||
|
value: str = Field(description="评估值")
|
||||||
|
description: str = Field(description="评估描述")
|
||||||
|
|
||||||
|
# 分析类型到模型类的映射
|
||||||
|
ANALYSIS_TYPE_MODEL_MAP = {
|
||||||
|
# 文本分析类型
|
||||||
|
"text_analysis_result": TextAnalysisResult,
|
||||||
|
|
||||||
|
# 数值分析类型
|
||||||
|
"numerical_analysis_result": NumericalAnalysisResult
|
||||||
|
}
|
||||||
|
|
||||||
|
class JsonOutputParser:
|
||||||
|
"""解析JSON输出的工具类"""
|
||||||
|
def __init__(self, pydantic_object):
|
||||||
|
"""初始化解析器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pydantic_object: 需要解析成的Pydantic模型类
|
||||||
|
"""
|
||||||
|
self.pydantic_object = pydantic_object
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
"""获取格式化指令"""
|
||||||
|
schema = self.pydantic_object.schema()
|
||||||
|
schema_str = json.dumps(schema, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return f"""请将你的回答格式化为符合以下JSON结构的格式:
|
||||||
|
{schema_str}
|
||||||
|
|
||||||
|
你的回答应该只包含一个有效的JSON对象,而不包含其他内容。请不要包含任何解释或前缀,仅输出JSON本身。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse(self, text: str) -> Any:
|
||||||
|
"""解析文本为对象"""
|
||||||
|
try:
|
||||||
|
# 尝试解析JSON
|
||||||
|
text = text.strip()
|
||||||
|
# 如果文本以```json开头且以```结尾,则提取中间部分
|
||||||
|
if text.startswith("```json") and text.endswith("```"):
|
||||||
|
text = text[7:-3].strip()
|
||||||
|
# 如果文本以```开头且以```结尾,则提取中间部分
|
||||||
|
elif text.startswith("```") and text.endswith("```"):
|
||||||
|
text = text[3:-3].strip()
|
||||||
|
|
||||||
|
json_obj = json.loads(text)
|
||||||
|
return self.pydantic_object.parse_obj(json_obj)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"无法解析为JSON或无法匹配模式: {e}")
|
||||||
|
|
||||||
|
# 获取项目根目录的绝对路径
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
# 尝试导入配置
|
||||||
|
try:
|
||||||
|
# 添加项目根目录到 Python 路径
|
||||||
|
sys.path.append(os.path.dirname(ROOT_DIR))
|
||||||
|
sys.path.append(ROOT_DIR)
|
||||||
|
|
||||||
|
# 导入配置
|
||||||
|
try:
|
||||||
|
from scripts.config import get_model_config
|
||||||
|
|
||||||
|
logger.info("成功从scripts.config导入配置")
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
from src.scripts.config import get_model_config
|
||||||
|
|
||||||
|
logger.info("成功从src.scripts.config导入配置")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("无法导入配置模块,使用默认配置")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用默认配置的实现
|
||||||
|
def get_model_config(platform: str, model_type: str):
|
||||||
|
return {
|
||||||
|
"api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU",
|
||||||
|
"model": "gemini-2.0-flash"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导入配置时出错: {str(e)},使用默认配置")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用默认配置的实现
|
||||||
|
def get_model_config(platform: str, model_type: str):
|
||||||
|
return {
|
||||||
|
"api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU",
|
||||||
|
"model": "gemini-2.0-flash"
|
||||||
|
}
|
||||||
|
|
||||||
|
class ChatBot:
|
||||||
|
def __init__(self, platform: str = "Gemini", model_type: str = "gemini-2.0-flash"):
|
||||||
|
"""初始化聊天机器人
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台名称,默认为Gemini
|
||||||
|
model_type: 要使用的模型类型,默认为gemini-2.0-flash
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从配置获取API配置
|
||||||
|
config = get_model_config(platform, model_type)
|
||||||
|
self.api_key = config["api_key"]
|
||||||
|
self.model = config["model"] # 直接使用配置中的模型名称
|
||||||
|
|
||||||
|
logger.info(f"初始化ChatBot,使用平台: {platform}, 模型: {self.model}")
|
||||||
|
|
||||||
|
google_search_tool = Tool(
|
||||||
|
google_search=GoogleSearch()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client = genai.Client(api_key=self.api_key)
|
||||||
|
|
||||||
|
# 系统提示语
|
||||||
|
self.system_instruction = """你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下:
|
||||||
|
|
||||||
|
1. 分析风格:
|
||||||
|
- 专业、客观、理性
|
||||||
|
- 注重数据支撑
|
||||||
|
- 关注风险控制
|
||||||
|
- 重视投资性价比
|
||||||
|
|
||||||
|
2. 分析框架:
|
||||||
|
- 公司基本面分析
|
||||||
|
- 行业竞争格局
|
||||||
|
- 估值水平评估
|
||||||
|
- 风险因素识别
|
||||||
|
- 投资机会判断
|
||||||
|
|
||||||
|
3. 输出要求:
|
||||||
|
- 简明扼要,重点突出
|
||||||
|
- 逻辑清晰,层次分明
|
||||||
|
- 数据准确,论据充分
|
||||||
|
- 结论明确,建议具体
|
||||||
|
|
||||||
|
请用专业投资经理的视角,对股票进行深入分析和投资建议。如果信息不足,请明确指出。"""
|
||||||
|
|
||||||
|
# 对话历史
|
||||||
|
self.conversation_history = []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化ChatBot时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def chat(self, user_input: str, analysis_type: Optional[str] = None, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096) -> dict:
|
||||||
|
"""处理用户输入并返回AI回复
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_input: 用户输入的问题
|
||||||
|
analysis_type: 分析类型,如company_profile或business_scale等
|
||||||
|
temperature: 控制输出的随机性,范围0-2,默认0.7
|
||||||
|
top_p: 控制输出的多样性,范围0-1,默认0.7
|
||||||
|
max_tokens: 控制输出的最大长度,默认4096
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含AI回复的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from google.genai import types
|
||||||
|
google_search_tool = Tool(
|
||||||
|
google_search=GoogleSearch()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确定分析类型和模型
|
||||||
|
model_class = None
|
||||||
|
if analysis_type and analysis_type in ANALYSIS_TYPE_MODEL_MAP:
|
||||||
|
model_class = ANALYSIS_TYPE_MODEL_MAP[analysis_type]
|
||||||
|
parser = JsonOutputParser(pydantic_object=model_class)
|
||||||
|
format_instructions = parser.get_format_instructions()
|
||||||
|
|
||||||
|
# 在原有基础上添加格式指令
|
||||||
|
combined_instruction = f"{self.system_instruction}\n\n{format_instructions}"
|
||||||
|
else:
|
||||||
|
combined_instruction = self.system_instruction
|
||||||
|
parser = None
|
||||||
|
|
||||||
|
response = self.client.models.generate_content_stream(
|
||||||
|
model=self.model,
|
||||||
|
config=GenerateContentConfig(
|
||||||
|
system_instruction=combined_instruction,
|
||||||
|
tools=[google_search_tool],
|
||||||
|
response_modalities=["TEXT"],
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
max_output_tokens=max_tokens
|
||||||
|
),
|
||||||
|
contents=user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
full_response = ""
|
||||||
|
for chunk in response:
|
||||||
|
full_response += chunk.text
|
||||||
|
print(chunk.text, end="")
|
||||||
|
|
||||||
|
# 如果指定了分析类型,尝试解析输出
|
||||||
|
if parser:
|
||||||
|
try:
|
||||||
|
# 检查分析类型是文本分析还是数值分析
|
||||||
|
is_text_analysis = model_class == TextAnalysisResult
|
||||||
|
is_numerical_analysis = model_class == NumericalAnalysisResult
|
||||||
|
|
||||||
|
# 使用解析方法尝试处理JSON响应
|
||||||
|
json_result = self.parse_json_response(full_response, model_class)
|
||||||
|
if json_result:
|
||||||
|
return json_result
|
||||||
|
|
||||||
|
# 如果使用解析方法失败,尝试使用parser解析
|
||||||
|
parsed_result = parser.parse(full_response)
|
||||||
|
structured_result = parsed_result.dict()
|
||||||
|
|
||||||
|
if isinstance(parsed_result, TextAnalysisResult):
|
||||||
|
# 文本分析结果
|
||||||
|
return {
|
||||||
|
"response": structured_result["analysis_text"],
|
||||||
|
"reasoning_process": structured_result.get("reasoning_process"),
|
||||||
|
"references": structured_result.get("references")
|
||||||
|
}
|
||||||
|
elif isinstance(parsed_result, NumericalAnalysisResult):
|
||||||
|
# 数值分析结果
|
||||||
|
return {
|
||||||
|
"response": structured_result["value"],
|
||||||
|
"description": structured_result.get("description", "")
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"解析响应失败: {str(e)},返回原始响应")
|
||||||
|
|
||||||
|
# 最后尝试根据分析类型直接构建结果
|
||||||
|
if is_text_analysis:
|
||||||
|
return {
|
||||||
|
"response": full_response,
|
||||||
|
"reasoning_process": None,
|
||||||
|
"references": None
|
||||||
|
}
|
||||||
|
elif is_numerical_analysis:
|
||||||
|
try:
|
||||||
|
# 尝试从文本中提取数字
|
||||||
|
import re
|
||||||
|
number_match = re.search(r'\b(\d+)\b', full_response)
|
||||||
|
value = int(number_match.group(1)) if number_match else 0
|
||||||
|
return {
|
||||||
|
"response": value,
|
||||||
|
"description": full_response
|
||||||
|
}
|
||||||
|
except:
|
||||||
|
return {
|
||||||
|
"response": 0,
|
||||||
|
"description": full_response
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"response": full_response,
|
||||||
|
"reasoning_process": None,
|
||||||
|
"references": None
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 普通响应
|
||||||
|
return {
|
||||||
|
"response": full_response,
|
||||||
|
"reasoning_process": None,
|
||||||
|
"references": None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"聊天失败: {str(e)}")
|
||||||
|
error_msg = f"抱歉,发生错误: {str(e)}"
|
||||||
|
print(f"\n{error_msg}")
|
||||||
|
return {"response": error_msg, "reasoning_process": None, "references": None}
|
||||||
|
|
||||||
|
def clear_history(self):
|
||||||
|
"""清除对话历史"""
|
||||||
|
self.conversation_history = []
|
||||||
|
# 重置会话
|
||||||
|
if hasattr(self, 'model_instance'):
|
||||||
|
self.chat_session = self.model_instance.start_chat()
|
||||||
|
print("对话历史已清除")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""运行聊天机器人"""
|
||||||
|
print("欢迎使用Gemini AI助手!输入 'quit' 退出,输入 'clear' 清除对话历史。")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# 获取用户输入
|
||||||
|
user_input = input("\n你: ").strip()
|
||||||
|
|
||||||
|
# 检查退出命令
|
||||||
|
if user_input.lower() == 'quit':
|
||||||
|
print("感谢使用,再见!")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 检查清除历史命令
|
||||||
|
if user_input.lower() == 'clear':
|
||||||
|
self.clear_history()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果输入为空,继续下一轮
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取AI回复
|
||||||
|
self.chat(user_input)
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n感谢使用,再见!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"运行错误: {str(e)}")
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
|
||||||
|
def parse_json_response(self, full_response: str, model_class) -> dict:
|
||||||
|
"""解析JSON响应,处理不同的JSON结构
|
||||||
|
|
||||||
|
Args:
|
||||||
|
full_response: 完整的响应文本
|
||||||
|
model_class: 模型类(TextAnalysisResult或NumericalAnalysisResult)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含解析结果的字典
|
||||||
|
"""
|
||||||
|
# 检查分析类型
|
||||||
|
is_text_analysis = model_class == TextAnalysisResult
|
||||||
|
is_numerical_analysis = model_class == NumericalAnalysisResult
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试提取JSON内容
|
||||||
|
json_text = full_response
|
||||||
|
# 如果文本以```json开头且以```结尾,则提取中间部分
|
||||||
|
if "```json" in json_text and "```" in json_text.split("```json", 1)[1]:
|
||||||
|
json_text = json_text.split("```json", 1)[1].split("```", 1)[0].strip()
|
||||||
|
# 如果文本以```开头且以```结尾,则提取中间部分
|
||||||
|
elif "```" in json_text and "```" in json_text.split("```", 1)[1]:
|
||||||
|
json_text = json_text.split("```", 1)[1].split("```", 1)[0].strip()
|
||||||
|
|
||||||
|
# 尝试解析JSON
|
||||||
|
json_obj = json.loads(json_text)
|
||||||
|
|
||||||
|
# 情况1: 直接包含字段的格式 {"analysis_text": "...", "reasoning_process": null, "references": [...]}
|
||||||
|
if is_text_analysis and "analysis_text" in json_obj:
|
||||||
|
return {
|
||||||
|
"response": json_obj["analysis_text"],
|
||||||
|
"reasoning_process": json_obj.get("reasoning_process"),
|
||||||
|
"references": json_obj.get("references", [])
|
||||||
|
}
|
||||||
|
elif is_numerical_analysis and "value" in json_obj:
|
||||||
|
value = json_obj["value"]
|
||||||
|
if isinstance(value, str) and value.isdigit():
|
||||||
|
value = int(value)
|
||||||
|
elif isinstance(value, (int, float)):
|
||||||
|
value = int(value)
|
||||||
|
else:
|
||||||
|
value = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"response": value,
|
||||||
|
"description": json_obj.get("description", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 情况2: properties中包含字段的格式
|
||||||
|
if "properties" in json_obj and isinstance(json_obj["properties"], dict):
|
||||||
|
properties = json_obj["properties"]
|
||||||
|
|
||||||
|
if is_text_analysis and "analysis_text" in properties:
|
||||||
|
return {
|
||||||
|
"response": properties["analysis_text"],
|
||||||
|
"reasoning_process": properties.get("reasoning_process"),
|
||||||
|
"references": properties.get("references", [])
|
||||||
|
}
|
||||||
|
elif is_numerical_analysis and "value" in properties:
|
||||||
|
value = properties["value"]
|
||||||
|
if isinstance(value, str) and value.isdigit():
|
||||||
|
value = int(value)
|
||||||
|
elif isinstance(value, (int, float)):
|
||||||
|
value = int(value)
|
||||||
|
else:
|
||||||
|
value = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"response": value,
|
||||||
|
"description": properties.get("description", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果无法识别结构,使用parser解析
|
||||||
|
raise ValueError("无法识别的JSON结构,尝试使用parser解析")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 记录错误但不抛出,让调用方法继续尝试其他解析方式
|
||||||
|
logger.warning(f"JSON解析失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 设置日志级别
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建并运行聊天机器人
|
||||||
|
bot = ChatBot()
|
||||||
|
bot.run()
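The new Gemini-backed ChatBot above combines a Google-search tool with `JsonOutputParser` format instructions, so typed analyses come back as plain dictionaries. A minimal driving sketch (the import path is an assumption, since the commit does not show the new file's name; a valid Gemini API key must be configured):

```python
# Hypothetical module path for the new Gemini chat bot added in this commit.
from src.fundamentals_llm.chat_bot_gemini import ChatBot  # assumed file name

bot = ChatBot(platform="Gemini", model_type="gemini-2.0-flash")

# Free-form chat returns {"response": ..., "reasoning_process": ..., "references": ...}
answer = bot.chat("简要分析该公司的基本面")
print(answer["response"])

# Structured analysis: the reply is parsed against TextAnalysisResult
structured = bot.chat("分析该公司的行业竞争格局", analysis_type="text_analysis_result")
print(structured.get("references"))
```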
@@ -3,7 +3,7 @@ from openai import OpenAI
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
|
|
||||||
# 设置日志记录
|
# 设置日志记录
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -102,8 +102,19 @@ class ChatBot:
|
||||||
logger.error(f"初始化ChatBot时出错: {str(e)}")
|
logger.error(f"初始化ChatBot时出错: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def chat(self, user_input: str) -> str:
|
def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> str:
|
||||||
"""处理用户输入并返回AI回复"""
|
"""处理用户输入并返回AI回复
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_input: 用户输入的问题
|
||||||
|
temperature: 控制输出的随机性,范围0-2,默认1.0
|
||||||
|
top_p: 控制输出的多样性,范围0-1,默认0.7
|
||||||
|
max_tokens: 控制输出的最大长度,默认4096
|
||||||
|
frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: AI的回答内容
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# # 添加用户消息到对话历史-多轮
|
# # 添加用户消息到对话历史-多轮
|
||||||
self.conversation_history.append({
|
self.conversation_history.append({
|
||||||
|
@ -116,7 +127,10 @@ class ChatBot:
|
||||||
stream = self.client.chat.completions.create(
|
stream = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=self.conversation_history,
|
messages=self.conversation_history,
|
||||||
temperature=0,
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
stream=True,
|
stream=True,
|
||||||
timeout=600
|
timeout=600
|
||||||
)
|
)
|
||||||
|
|
|
@@ -130,9 +130,15 @@ class EnterpriseScreener:
            {
                'dimension': 'investment_advice',
                'field': 'investment_advice_type',
-                'operator': '!=',
-                'value': '不建议'
-            }
+                'operator': 'in',
+                'value': ["'短期'","'中期'","'长期'"]
+            },
+            {
+                'dimension': 'financial_report',
+                'field': 'financial_report_level',
+                'operator': '>=',
+                'value': -1
+            },
        ]
        return self._screen_stocks_by_conditions(conditions, limit)
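Each screening condition above pairs a dimension and field with an operator and a value. As a toy illustration of that format only (this is not the project's `_screen_stocks_by_conditions` implementation), such a list could be rendered into a SQL WHERE fragment like this:

```python
conditions = [
    {'dimension': 'investment_advice', 'field': 'investment_advice_type',
     'operator': 'in', 'value': ["'短期'", "'中期'", "'长期'"]},
    {'dimension': 'financial_report', 'field': 'financial_report_level',
     'operator': '>=', 'value': -1},
]

def to_where_clause(conds):
    """Render a condition list into a WHERE fragment (illustration only)."""
    parts = []
    for cond in conds:
        if cond['operator'] == 'in':
            parts.append(f"{cond['field']} IN ({', '.join(cond['value'])})")
        else:
            parts.append(f"{cond['field']} {cond['operator']} {cond['value']}")
    return " AND ".join(parts)

print(to_where_clause(conditions))
# investment_advice_type IN ('短期', '中期', '长期') AND financial_report_level >= -1
```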
@@ -1,7 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
import sys
|
||||||
from typing import Dict, List, Optional, Tuple, Callable
|
from datetime import datetime, timedelta
|
||||||
|
import time
|
||||||
|
import redis
|
||||||
|
from typing import Dict, List, Optional, Tuple, Callable, Any
|
||||||
|
|
||||||
# 修改导入路径,使用相对导入
|
# 修改导入路径,使用相对导入
|
||||||
try:
|
try:
|
||||||
# 尝试相对导入
|
# 尝试相对导入
|
||||||
|
@ -9,6 +13,7 @@ try:
|
||||||
from .chat_bot_with_offline import ChatBot as OfflineChatBot
|
from .chat_bot_with_offline import ChatBot as OfflineChatBot
|
||||||
from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
|
from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
|
||||||
from .pdf_generator import PDFGenerator
|
from .pdf_generator import PDFGenerator
|
||||||
|
from .text_processor import TextProcessor
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# 如果相对导入失败,尝试尝试绝对导入
|
# 如果相对导入失败,尝试尝试绝对导入
|
||||||
try:
|
try:
|
||||||
|
@ -16,18 +21,17 @@ except ImportError:
|
||||||
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
|
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
|
||||||
from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
|
from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
|
||||||
from src.fundamentals_llm.pdf_generator import PDFGenerator
|
from src.fundamentals_llm.pdf_generator import PDFGenerator
|
||||||
|
from src.fundamentals_llm.text_processor import TextProcessor
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况)
|
# 最后尝试直接导入(适用于当前目录已在PYTHONPATH中的情况)
|
||||||
from chat_bot import ChatBot
|
from chat_bot import ChatBot
|
||||||
from chat_bot_with_offline import ChatBot as OfflineChatBot
|
from chat_bot_with_offline import ChatBot as OfflineChatBot
|
||||||
from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
|
from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
|
||||||
from pdf_generator import PDFGenerator
|
from pdf_generator import PDFGenerator
|
||||||
|
from text_processor import TextProcessor
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 设置日志记录
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 获取项目根目录的绝对路径
|
# 获取项目根目录的绝对路径
|
||||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
@ -58,6 +62,34 @@ logging.basicConfig(
|
||||||
datefmt=date_format
|
datefmt=date_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info("测试日志输出 - 程序启动")
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Any, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# 定义基础数据结构
|
||||||
|
class TextAnalysisResult(BaseModel):
|
||||||
|
"""文本分析结果,包含分析正文、推理过程和引用URL"""
|
||||||
|
analysis_text: str = Field(description="详细的分析文本")
|
||||||
|
reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None)
|
||||||
|
references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None)
|
||||||
|
|
||||||
|
class NumericalAnalysisResult(BaseModel):
|
||||||
|
"""数值分析结果,包含数值和分析描述"""
|
||||||
|
value: str = Field(description="评估值")
|
||||||
|
description: str = Field(description="评估描述")
|
||||||
|
|
||||||
|
# 添加Redis客户端
|
||||||
|
redis_client = redis.Redis(
|
||||||
|
host='192.168.18.208', # Redis服务器地址,根据实际情况调整
|
||||||
|
port=6379,
|
||||||
|
password='wlkj2018',
|
||||||
|
db=14,
|
||||||
|
socket_timeout=5,
|
||||||
|
decode_responses=True
|
||||||
|
)
|
||||||
|
|
||||||
class FundamentalAnalyzer:
|
class FundamentalAnalyzer:
|
||||||
"""基本面分析器"""
|
"""基本面分析器"""
|
||||||
|
|
||||||
|
@ -66,9 +98,10 @@ class FundamentalAnalyzer:
|
||||||
# 使用联网模型进行基本面分析
|
# 使用联网模型进行基本面分析
|
||||||
self.chat_bot = ChatBot(model_type="online_bot")
|
self.chat_bot = ChatBot(model_type="online_bot")
|
||||||
# 使用离线模型进行其他分析
|
# 使用离线模型进行其他分析
|
||||||
self.offline_bot = OfflineChatBot(platform="tl_private", model_type="ds-v1")
|
self.offline_bot = OfflineChatBot(platform="volc", model_type="offline_model")
|
||||||
# 千问打杂
|
# 千问打杂
|
||||||
self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq")
|
# self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq")
|
||||||
|
self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="GLM")
|
||||||
self.db = next(get_db())
|
self.db = next(get_db())
|
||||||
|
|
||||||
# 定义维度映射
|
# 定义维度映射
|
||||||
|
@ -161,6 +194,8 @@ class FundamentalAnalyzer:
|
||||||
- 关键战略决策和转型
|
- 关键战略决策和转型
|
||||||
|
|
||||||
请提供专业、客观的分析,突出关键信息,避免冗长描述。"""
|
请提供专业、客观的分析,突出关键信息,避免冗长描述。"""
|
||||||
|
#开头清上下文缓存
|
||||||
|
self.chat_bot.clear_history()
|
||||||
# 获取AI分析结果
|
# 获取AI分析结果
|
||||||
result = self.chat_bot.chat(prompt)
|
result = self.chat_bot.chat(prompt)
|
||||||
|
|
||||||
|
@ -554,8 +589,8 @@ class FundamentalAnalyzer:
|
||||||
请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。"""
|
请仅返回一个数值:2、1、0或-1,不要包含任何解释或说明。"""
|
||||||
self.offline_bot_tl_qw.clear_history()
|
self.offline_bot_tl_qw.clear_history()
|
||||||
# 使用离线模型进行分析
|
# 使用离线模型进行分析
|
||||||
space_value_str = self.offline_bot_tl_qw.chat(prompt)
|
space_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
|
||||||
space_value_str = self._clean_model_output(space_value_str)
|
space_value_str = TextProcessor.clean_thought_process(space_value_str)
|
||||||
# 提取数值
|
# 提取数值
|
||||||
space_value = 0 # 默认值
|
space_value = 0 # 默认值
|
||||||
|
|
||||||
|
@ -658,9 +693,9 @@ class FundamentalAnalyzer:
|
||||||
请仅返回一个数值:1、0或-1,不要包含任何解释或说明。"""
|
请仅返回一个数值:1、0或-1,不要包含任何解释或说明。"""
|
||||||
self.offline_bot_tl_qw.clear_history()
|
self.offline_bot_tl_qw.clear_history()
|
||||||
# 使用离线模型进行分析
|
# 使用离线模型进行分析
|
||||||
events_value_str = self.offline_bot_tl_qw.chat(prompt)
|
events_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
|
||||||
# 数据清洗
|
# 数据清洗
|
||||||
events_value_str = self._clean_model_output(events_value_str)
|
events_value_str = TextProcessor.clean_thought_process(events_value_str)
|
||||||
# 提取数值
|
# 提取数值
|
||||||
events_value = 0 # 默认值
|
events_value = 0 # 默认值
|
||||||
|
|
||||||
|
@ -700,14 +735,14 @@ class FundamentalAnalyzer:
|
||||||
def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool:
|
def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool:
|
||||||
"""分析股吧讨论内容"""
|
"""分析股吧讨论内容"""
|
||||||
try:
|
try:
|
||||||
prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出:
|
prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在400字以内(主要讨论话题200字,重要信息汇总200字),请严格按照以下格式输出:
|
||||||
|
|
||||||
1. 主要讨论话题(150字左右):
|
1. 主要讨论话题:
|
||||||
- 近期热点事件
|
- 近期热点事件
|
||||||
- 投资者关注焦点
|
- 投资者关注焦点
|
||||||
- 市场情绪倾向
|
- 市场情绪倾向
|
||||||
|
|
||||||
2. 重要信息汇总(150字左右):
|
2. 重要信息汇总:
|
||||||
- 公司相关动态
|
- 公司相关动态
|
||||||
- 行业政策变化
|
- 行业政策变化
|
||||||
- 市场预期变化
|
- 市场预期变化
|
||||||
|
@@ -762,10 +797,10 @@ class FundamentalAnalyzer:
 请仅返回一个数值:1、0或-1,不要包含任何解释或说明。"""
 self.offline_bot_tl_qw.clear_history()
 # 使用离线模型进行分析
-emotion_value_str = self.offline_bot_tl_qw.chat(prompt)
+emotion_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)

 # 数据清洗
-emotion_value_str = self._clean_model_output(emotion_value_str)
+emotion_value_str = TextProcessor.clean_thought_process(emotion_value_str)
 # 提取数值
 emotion_value = 0 # 默认值

@@ -1015,10 +1050,10 @@ class FundamentalAnalyzer:

 只需要输出一个数值,不要输出任何说明或解释。只输出:2,1,0,-1或-2。"""

-response = self.chat_bot.chat(prompt)
+response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)

 # 提取数值
-rating_str = self._extract_numeric_value_from_response(response)
+rating_str = TextProcessor.extract_numeric_value_from_response(response)

 # 尝试将响应转换为整数
 try:

@@ -1057,10 +1092,10 @@ class FundamentalAnalyzer:

 只需要输出一个数值,不要输出任何说明或解释。只输出:1、0、-1,不要包含任何解释或说明。"""

-response = self.chat_bot.chat(prompt)
+response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)

 # 提取数值
-odds_str = self._extract_numeric_value_from_response(response)
+odds_str = TextProcessor.extract_numeric_value_from_response(response)

 # 尝试将响应转换为整数
 try:

@@ -1200,8 +1235,8 @@ class FundamentalAnalyzer:
 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。"""

 self.offline_bot_tl_qw.clear_history()
-response = self.offline_bot_tl_qw.chat(prompt)
+response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
-pe_hist_str = self._clean_model_output(response)
+pe_hist_str = TextProcessor.clean_thought_process(response)

 try:
 pe_hist = int(pe_hist_str)

@@ -1240,8 +1275,8 @@ class FundamentalAnalyzer:
 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。"""

 self.offline_bot_tl_qw.clear_history()
-response = self.offline_bot_tl_qw.chat(prompt)
+response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
-pb_hist_str = self._clean_model_output(response)
+pb_hist_str = TextProcessor.clean_thought_process(response)

 try:
 pb_hist = int(pb_hist_str)

@@ -1280,8 +1315,8 @@ class FundamentalAnalyzer:
 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。"""

 self.offline_bot_tl_qw.clear_history()
-response = self.offline_bot_tl_qw.chat(prompt)
+response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
-pe_ind_str = self._clean_model_output(response)
+pe_ind_str = TextProcessor.clean_thought_process(response)

 try:
 pe_ind = int(pe_ind_str)

@@ -1320,8 +1355,8 @@ class FundamentalAnalyzer:
 只需要输出一个数值,不要输出任何说明或解释。只输出:-1、0或1。"""

 self.offline_bot_tl_qw.clear_history()
-response = self.offline_bot_tl_qw.chat(prompt)
+response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
-pb_ind_str = self._clean_model_output(response)
+pb_ind_str = TextProcessor.clean_thought_process(response)

 try:
 pb_ind = int(pb_ind_str)
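These valuation hunks all route the -1/0/1 questions through the offline Qwen/GLM bot at temperature 0. A hedged sketch of one way a caller could fall back to the online `chat_bot` when the offline endpoint returns the error string checked later in this diff; the wrapper name and the fallback policy are assumptions, not part of the commit:

```python
def chat_with_fallback(offline_bot, online_bot, prompt: str) -> str:
    """Hypothetical wrapper: prefer the deterministic offline model, fall back to the online one."""
    offline_bot.clear_history()
    result = offline_bot.chat(prompt, temperature=0.0)
    # The analyzer treats this string as an error marker (see the _try_extract_advice_type hunk below).
    if isinstance(result, str) and "抱歉,发生错误" in result:
        result = online_bot.chat(prompt)
    return TextProcessor.clean_thought_process(result)
```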
@@ -1373,9 +1408,9 @@ class FundamentalAnalyzer:
 {json.dumps(all_results, ensure_ascii=False, indent=2)}"""
 self.offline_bot.clear_history()
 # 使用离线模型生成建议
-result = self.offline_bot.chat(prompt)
+result = self.offline_bot.chat(prompt,max_tokens=20000)
 # 清理模型输出
-result = self._clean_model_output(result)
+result = TextProcessor.clean_thought_process(result)
 # 保存到数据库
 success = save_analysis_result(
 self.db,

@@ -1441,93 +1476,6 @@ class FundamentalAnalyzer:
 logger.error(f"提取投资建议类型失败: {str(e)}")
 return None

-def _clean_model_output(self, output: str) -> str:
-"""清理模型输出,移除推理过程,只保留最终结果
-
-Args:
-output: 模型原始输出文本
-
-Returns:
-str: 清理后的输出文本
-"""
-try:
-# 找到</think>标签的位置
-think_end = output.find('</think>')
-if think_end != -1:
-# 移除</think>标签及其之前的所有内容
-output = output[think_end + len('</think>'):]
-
-# 处理可能存在的空行
-lines = output.split('\n')
-cleaned_lines = []
-for line in lines:
-line = line.strip()
-if line: # 只保留非空行
-cleaned_lines.append(line)
-
-# 重新组合文本
-output = '\n'.join(cleaned_lines)
-
-return output.strip()
-
-except Exception as e:
-logger.error(f"清理模型输出失败: {str(e)}")
-return output.strip()
-
-def _extract_numeric_value_from_response(self, response: str) -> str:
-"""从模型响应中提取数值,移除参考资料和推理过程
-
-Args:
-response: 模型原始响应文本或响应对象
-
-Returns:
-str: 提取的数值字符串
-"""
-try:
-# 处理响应对象(包含response字段的字典)
-if isinstance(response, dict) and "response" in response:
-response = response["response"]
-
-# 确保响应是字符串
-if not isinstance(response, str):
-logger.warning(f"响应不是字符串类型: {type(response)}")
-return "0"
-
-# 移除推理过程部分
-reasoning_start = response.find("推理过程:")
-if reasoning_start != -1:
-response = response[:reasoning_start].strip()
-
-# 移除参考资料部分(通常以 [数字] 开头的行)
-lines = response.split("\n")
-cleaned_lines = []
-
-for line in lines:
-# 跳过参考资料行(通常以 [数字] 开头)
-if re.match(r'\[\d+\]', line.strip()):
-continue
-cleaned_lines.append(line)
-
-response = "\n".join(cleaned_lines).strip()
-
-# 提取数值
-# 先尝试直接将整个响应转换为数值
-if response.strip() in ["-2", "-1", "0", "1", "2"]:
-return response.strip()
-
-# 如果整个响应不是数值,尝试匹配第一个数值
-match = re.search(r'([-]?[0-9])', response)
-if match:
-return match.group(1)
-
-# 如果没有找到数值,返回默认值
-logger.warning(f"未能从响应中提取数值: {response}")
-return "0"
-
-except Exception as e:
-logger.error(f"从响应中提取数值失败: {str(e)}")
-return "0"
-
 def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]:
 """尝试多次从投资建议中提取建议类型

@@ -1549,7 +1497,7 @@ class FundamentalAnalyzer:

 # 使用千问离线模型提取建议类型

-result = self.offline_bot_tl_qw.chat(prompt)
+result = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)

 # 检查是否是错误响应
 if isinstance(result, str) and "抱歉,发生错误" in result:

@@ -1557,7 +1505,7 @@ class FundamentalAnalyzer:
 continue

 # 清理模型输出
-cleaned_result = self._clean_model_output(result)
+cleaned_result = TextProcessor.clean_thought_process(result)

 # 检查结果是否为有效类型
 if cleaned_result in valid_types:

@@ -1644,6 +1592,24 @@ class FundamentalAnalyzer:
 Optional[str]: 生成的PDF文件路径,如果失败则返回None
 """
 try:
+# 检查是否已存在PDF文件
+reports_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'reports')
+os.makedirs(reports_dir, exist_ok=True)
+
+# 构建可能的文件名格式
+possible_filenames = [
+f"{stock_name}_{stock_code}_analysis.pdf",
+f"{stock_name}_{stock_code}.SZ_analysis.pdf",
+f"{stock_name}_{stock_code}.SH_analysis.pdf"
+]
+
+# 检查是否存在已生成的PDF文件
+for filename in possible_filenames:
+filepath = os.path.join(reports_dir, filename)
+if os.path.exists(filepath):
+logger.info(f"找到已存在的PDF报告: {filepath}")
+return filepath
+
 # 维度名称映射
 dimension_names = {
 "company_profile": "公司简介",

@@ -1669,27 +1635,156 @@ class FundamentalAnalyzer:
 logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果")
 return None

+# 确保字体目录存在
+fonts_dir = os.path.join(os.path.dirname(__file__), "fonts")
+os.makedirs(fonts_dir, exist_ok=True)
+
+# 检查是否存在字体文件,如果不存在则创建简单的默认字体标记文件
+font_path = os.path.join(fonts_dir, "simhei.ttf")
+if not os.path.exists(font_path):
+# 尝试从系统字体目录复制
+try:
+import shutil
+# 尝试常见的系统字体位置
+system_fonts = [
+"C:/Windows/Fonts/simhei.ttf", # Windows
+"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # Linux
+"/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", # 其他Linux
+"/System/Library/Fonts/PingFang.ttc" # macOS
+]
+
+for system_font in system_fonts:
+if os.path.exists(system_font):
+shutil.copy2(system_font, font_path)
+logger.info(f"已复制字体文件: {system_font} -> {font_path}")
+break
+except Exception as font_error:
+logger.warning(f"复制字体文件失败: {str(font_error)}")
+
 # 创建PDF生成器实例
 generator = PDFGenerator()

 # 生成PDF报告
+try:
+# 第一次尝试生成
 filepath = generator.generate_pdf(
 title=f"{stock_name}({stock_code}) 基本面分析报告",
 content_dict=content_dict,
+output_dir=reports_dir,
 filename=f"{stock_name}_{stock_code}_analysis.pdf"
 )

 if filepath:
 logger.info(f"PDF报告已生成: {filepath}")
-else:
-logger.error("PDF报告生成失败")

 return filepath
+except Exception as pdf_error:
+logger.error(f"生成PDF报告第一次尝试失败: {str(pdf_error)}")
+
+# 如果是字体问题,可能需要使用备选方案
+if "找不到中文字体文件" in str(pdf_error):
+# 导出为文本文件作为备选
+try:
+txt_filename = f"{stock_name}_{stock_code}_analysis.txt"
+txt_filepath = os.path.join(reports_dir, txt_filename)
+
+with open(txt_filepath, 'w', encoding='utf-8') as f:
+f.write(f"{stock_name}({stock_code}) 基本面分析报告\n")
+f.write(f"生成时间:{datetime.now().strftime('%Y年%m月%d日 %H:%M:%S')}\n\n")
+
+for section_title, content in content_dict.items():
+if content:
+f.write(f"## {section_title}\n\n")
+f.write(f"{content}\n\n")
+
+logger.info(f"由于PDF生成失败,已生成文本报告: {txt_filepath}")
+return txt_filepath
+except Exception as txt_error:
+logger.error(f"生成文本报告失败: {str(txt_error)}")
+
+logger.error("PDF报告生成失败")
+return None
+
 except Exception as e:
 logger.error(f"生成PDF报告失败: {str(e)}")
 return None

+def is_stock_locked(self, stock_code: str, dimension: str) -> bool:
+"""检查股票是否已被锁定
+
+Args:
+stock_code: 股票代码
+dimension: 分析维度
+
+Returns:
+bool: 是否被锁定
+"""
+try:
+lock_key = f"stock_analysis_lock:{stock_code}:{dimension}"
+
+# 检查是否存在锁
+existing_lock = redis_client.get(lock_key)
+if existing_lock:
+lock_time = int(existing_lock)
+current_time = int(time.time())
+
+# 锁超过30分钟(1800秒)视为过期
+if current_time - lock_time > 1800:
+# 锁已过期,可以释放
+redis_client.delete(lock_key)
+logger.info(f"股票 {stock_code} 维度 {dimension} 的过期锁已释放")
+return False
+else:
+# 锁未过期,被锁定
+return True
+
+# 不存在锁
+return False
+except Exception as e:
+logger.error(f"检查股票 {stock_code} 锁状态时出错: {str(e)}")
+# 出错时保守处理,返回未锁定
+return False
+
+def lock_stock(self, stock_code: str, dimension: str) -> bool:
+"""锁定股票
+
+Args:
+stock_code: 股票代码
+dimension: 分析维度
+
+Returns:
+bool: 是否成功锁定
+"""
+try:
+lock_key = f"stock_analysis_lock:{stock_code}:{dimension}"
+current_time = int(time.time())
+
+# 设置锁,过期时间1小时
+redis_client.set(lock_key, current_time, ex=3600)
+logger.info(f"股票 {stock_code} 维度 {dimension} 已锁定")
+return True
+except Exception as e:
+logger.error(f"锁定股票 {stock_code} 时出错: {str(e)}")
+return False
+
+def unlock_stock(self, stock_code: str, dimension: str) -> bool:
+"""解锁股票
+
+Args:
+stock_code: 股票代码
+dimension: 分析维度
+
+Returns:
+bool: 是否成功解锁
+"""
+try:
+lock_key = f"stock_analysis_lock:{stock_code}:{dimension}"
+redis_client.delete(lock_key)
+logger.info(f"股票 {stock_code} 维度 {dimension} 已解锁")
+return True
+except Exception as e:
+logger.error(f"解锁股票 {stock_code} 时出错: {str(e)}")
+return False
+
 def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool:
 """测试单个分析方法"""
 try:
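A short usage sketch for the new Redis-backed lock helpers added above, assuming an instantiated `FundamentalAnalyzer` named `analyzer`; the wrapper function and the `dimension` value are placeholders, not part of the commit:

```python
def analyze_if_free(analyzer, stock_code: str, stock_name: str, dimension: str = "discussion") -> bool:
    """Hypothetical caller: skip stocks another worker is already processing."""
    if analyzer.is_stock_locked(stock_code, dimension):  # expired locks (>30 min) are cleared here
        return False
    analyzer.lock_stock(stock_code, dimension)           # Redis key with a 1 h TTL
    try:
        return analyzer.analyze_stock_discussion(stock_code, stock_name)
    finally:
        analyzer.unlock_stock(stock_code, dimension)     # always release, even on failure
```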
@@ -1713,11 +1808,6 @@ def test_single_stock(analyzer: FundamentalAnalyzer, stock_code: str, stock_name

 def main():
 """主函数"""
-# 设置日志级别
-logging.basicConfig(
-level=logging.INFO,
-format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-
 # 测试股票列表
 test_stocks = [
(File diff suppressed because it is too large.)
@@ -11,25 +11,25 @@ import markdown2
 from bs4 import BeautifulSoup
 import os
 from datetime import datetime
-from fpdf import FPDF
+import shutil
-import matplotlib.pyplot as plt
 import matplotlib
 matplotlib.use('Agg')

 # 修改导入路径,使用相对导入
 try:
 # 尝试相对导入
-from .chat_bot_with_offline import ChatBot
+from .chat_bot import ChatBot
-from .fundamental_analysis_database import get_db, get_analysis_result
+from .chat_bot_with_offline import ChatBot as OfflineChatBot
 except ImportError:
 # 如果相对导入失败,尝试绝对导入
 try:
-from src.fundamentals_llm.chat_bot_with_offline import ChatBot
+from src.fundamentals_llm.chat_bot import ChatBot
-from src.fundamentals_llm.fundamental_analysis_database import get_db, get_analysis_result
+from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
 except ImportError:
 # 最后尝试直接导入
-from chat_bot_with_offline import ChatBot
+from chat_bot import ChatBot
-from fundamental_analysis_database import get_db, get_analysis_result
+from chat_bot_with_offline import ChatBot as OfflineChatBot

 # 设置日志记录
 logger = logging.getLogger(__name__)
@@ -39,30 +39,9 @@ class PDFGenerator:

 def __init__(self):
 """初始化PDF生成器"""
-# 注册中文字体
+# 尝试注册中文字体
-try:
+self.font_name = self._register_chinese_font()
-# 尝试使用系统自带的中文字体
+self.chat_bot = OfflineChatBot(platform="volc", model_type="offline_model") # 使用离线模式
-if os.name == 'nt': # Windows
-font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
-else: # Linux/Mac
-font_path = "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf"
-
-if os.path.exists(font_path):
-pdfmetrics.registerFont(TTFont('SimHei', font_path))
-self.font_name = 'SimHei'
-else:
-# 如果找不到系统字体,尝试使用当前目录下的字体
-font_path = os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf")
-if os.path.exists(font_path):
-pdfmetrics.registerFont(TTFont('SimHei', font_path))
-self.font_name = 'SimHei'
-else:
-raise FileNotFoundError("找不到中文字体文件")
-except Exception as e:
-logger.error(f"注册中文字体失败: {str(e)}")
-raise
-
-self.chat_bot = ChatBot()
 self.styles = getSampleStyleSheet()

 # 创建自定义样式
@@ -139,6 +118,64 @@ class PDFGenerator:
 textColor=colors.HexColor('#333333')
 ))

+def _register_chinese_font(self):
+"""查找并注册中文字体
+
+Returns:
+str: 注册的字体名称,如果失败则使用默认字体
+"""
+try:
+# 可能的字体文件位置列表
+font_locations = [
+# 当前目录/子目录字体位置
+os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf"),
+os.path.join(os.path.dirname(__file__), "fonts", "wqy-microhei.ttc"),
+os.path.join(os.path.dirname(__file__), "fonts", "wqy-zenhei.ttc"),
+# Docker中可能的位置
+"/app/src/fundamentals_llm/fonts/simhei.ttf",
+# Windows系统字体位置
+"C:/Windows/Fonts/simhei.ttf",
+"C:/Windows/Fonts/simfang.ttf",
+"C:/Windows/Fonts/simsun.ttc",
+# Linux/Mac系统字体位置
+"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
+"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
+"/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",
+"/usr/share/fonts/wqy-zenhei/wqy-zenhei.ttc",
+"/usr/share/fonts/truetype/droid/DroidSansFallback.ttf",
+"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
+"/System/Library/Fonts/PingFang.ttc" # macOS
+]
+
+# 尝试各个位置
+for font_path in font_locations:
+if os.path.exists(font_path):
+logger.info(f"使用字体文件: {font_path}")
+# 为防止字体文件名称不同,统一拷贝到字体目录并重命名
+font_dir = os.path.join(os.path.dirname(__file__), "fonts")
+os.makedirs(font_dir, exist_ok=True)
+target_path = os.path.join(font_dir, "simhei.ttf")
+
+# 如果字体文件不在目标位置,拷贝过去
+if font_path != target_path and not os.path.exists(target_path):
+try:
+shutil.copy2(font_path, target_path)
+logger.info(f"已将字体文件复制到: {target_path}")
+except Exception as e:
+logger.warning(f"复制字体文件失败: {str(e)}")
+
+# 注册字体
+pdfmetrics.registerFont(TTFont('SimHei', target_path))
+return 'SimHei'
+
+# 如果所有位置都找不到,使用默认字体
+logger.warning("找不到中文字体文件,将使用默认字体")
+return 'Helvetica'
+
+except Exception as e:
+logger.error(f"注册中文字体失败: {str(e)}")
+return 'Helvetica' # 使用默认字体
+
 def _convert_markdown_to_flowables(self, markdown_text: str) -> List:
 """将Markdown文本转换为PDF流对象

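A minimal sketch of how the font name returned by `_register_chinese_font` can feed a ReportLab paragraph style, mirroring the `pdfmetrics.registerFont(TTFont(...))` call above; the font path default is an assumption and the fallback matches the method's own Helvetica fallback:

```python
import os
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle

def pick_font(font_path: str = "src/fundamentals_llm/fonts/simhei.ttf") -> str:
    """Register a CJK TTF if present, otherwise keep ReportLab's built-in default font."""
    if os.path.exists(font_path):
        pdfmetrics.registerFont(TTFont("SimHei", font_path))
        return "SimHei"
    return "Helvetica"  # same fallback as _register_chinese_font

styles = getSampleStyleSheet()
body_style = ParagraphStyle("ChineseBody", parent=styles["Normal"], fontName=pick_font(), fontSize=10)
```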
@ -0,0 +1,190 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
中文字体安装工具
|
||||||
|
|
||||||
|
此脚本用于下载和安装中文字体,以支持PDF生成功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import platform
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def download_wqy_font():
|
||||||
|
"""
|
||||||
|
下载文泉驿微米黑字体
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 下载的字体文件路径,如果失败则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# 创建临时目录
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
font_file = os.path.join(temp_dir, "wqy-microhei.ttc")
|
||||||
|
|
||||||
|
# 文泉驿微米黑字体的下载链接
|
||||||
|
url = "https://mirrors.tuna.tsinghua.edu.cn/osdn/wqy/wqy-microhei-0.2.0-beta.tar.gz"
|
||||||
|
|
||||||
|
logger.info(f"开始下载文泉驿微米黑字体: {url}")
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 下载压缩包
|
||||||
|
tar_file = os.path.join(temp_dir, "wqy-microhei.tar.gz")
|
||||||
|
with open(tar_file, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
# 解压字体文件
|
||||||
|
import tarfile
|
||||||
|
with tarfile.open(tar_file) as tar:
|
||||||
|
# 提取字体文件
|
||||||
|
for member in tar.getmembers():
|
||||||
|
if member.name.endswith(('.ttc', '.ttf')):
|
||||||
|
member.name = os.path.basename(member.name)
|
||||||
|
tar.extract(member, temp_dir)
|
||||||
|
extracted_file = os.path.join(temp_dir, member.name)
|
||||||
|
if os.path.exists(extracted_file):
|
||||||
|
return extracted_file
|
||||||
|
|
||||||
|
logger.error("在压缩包中未找到字体文件")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载字体失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_system_fonts():
|
||||||
|
"""
|
||||||
|
在系统中查找可用的中文字体
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 找到的字体文件路径,如果找不到则返回None
|
||||||
|
"""
|
||||||
|
# 常见的中文字体位置
|
||||||
|
font_locations = []
|
||||||
|
|
||||||
|
system = platform.system()
|
||||||
|
if system == "Windows":
|
||||||
|
font_locations = [
|
||||||
|
"C:/Windows/Fonts/simhei.ttf",
|
||||||
|
"C:/Windows/Fonts/simsun.ttc",
|
||||||
|
"C:/Windows/Fonts/msyh.ttf",
|
||||||
|
"C:/Windows/Fonts/simfang.ttf",
|
||||||
|
]
|
||||||
|
elif system == "Darwin": # macOS
|
||||||
|
font_locations = [
|
||||||
|
"/System/Library/Fonts/PingFang.ttc",
|
||||||
|
"/Library/Fonts/Microsoft/SimHei.ttf",
|
||||||
|
"/Library/Fonts/Microsoft/SimSun.ttf",
|
||||||
|
]
|
||||||
|
else: # Linux
|
||||||
|
font_locations = [
|
||||||
|
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
|
||||||
|
"/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",
|
||||||
|
"/usr/share/fonts/wenquanyi/wqy-microhei/wqy-microhei.ttc",
|
||||||
|
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
|
||||||
|
"/usr/share/fonts/truetype/droid/DroidSansFallback.ttf",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 查找第一个存在的字体文件
|
||||||
|
for font_path in font_locations:
|
||||||
|
if os.path.exists(font_path):
|
||||||
|
logger.info(f"在系统中找到中文字体: {font_path}")
|
||||||
|
return font_path
|
||||||
|
|
||||||
|
logger.warning("在系统中未找到任何中文字体")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def install_font(target_dir=None):
|
||||||
|
"""
|
||||||
|
安装中文字体
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_dir: 目标目录,如果不提供则使用默认目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 安装后的字体文件路径,如果失败则返回None
|
||||||
|
"""
|
||||||
|
# 如果没有指定目标目录,使用默认位置
|
||||||
|
if target_dir is None:
|
||||||
|
# 获取当前脚本所在目录
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
target_dir = os.path.join(current_dir, "fonts")
|
||||||
|
|
||||||
|
# 确保目标目录存在
|
||||||
|
os.makedirs(target_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 目标字体文件路径
|
||||||
|
target_font = os.path.join(target_dir, "simhei.ttf")
|
||||||
|
|
||||||
|
# 如果目标字体已存在,直接返回
|
||||||
|
if os.path.exists(target_font):
|
||||||
|
logger.info(f"中文字体已存在: {target_font}")
|
||||||
|
return target_font
|
||||||
|
|
||||||
|
# 尝试从系统中复制字体
|
||||||
|
system_font = find_system_fonts()
|
||||||
|
if system_font:
|
||||||
|
try:
|
||||||
|
shutil.copy2(system_font, target_font)
|
||||||
|
logger.info(f"已从系统复制字体: {system_font} -> {target_font}")
|
||||||
|
return target_font
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"复制系统字体失败: {str(e)}")
|
||||||
|
|
||||||
|
# 如果系统中找不到字体,尝试下载
|
||||||
|
logger.info("尝试下载中文字体...")
|
||||||
|
downloaded_font = download_wqy_font()
|
||||||
|
if downloaded_font:
|
||||||
|
try:
|
||||||
|
shutil.copy2(downloaded_font, target_font)
|
||||||
|
logger.info(f"已安装下载的字体: {downloaded_font} -> {target_font}")
|
||||||
|
return target_font
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"复制下载的字体失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.error("安装中文字体失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
主函数
|
||||||
|
"""
|
||||||
|
print("中文字体安装工具")
|
||||||
|
print("----------------")
|
||||||
|
|
||||||
|
target_dir = None
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
target_dir = sys.argv[1]
|
||||||
|
print(f"使用指定的目标目录: {target_dir}")
|
||||||
|
|
||||||
|
result = install_font(target_dir)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
print(f"\n✓ 成功安装中文字体: {result}")
|
||||||
|
print("\n现在可以正常生成PDF报告了。")
|
||||||
|
else:
|
||||||
|
print("\n✗ 安装中文字体失败")
|
||||||
|
print("\n请手动将中文字体(.ttf或.ttc文件)复制到以下目录:")
|
||||||
|
if target_dir:
|
||||||
|
print(f" {target_dir}")
|
||||||
|
else:
|
||||||
|
print(f" {os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fonts')}")
|
||||||
|
print("\n并将其重命名为'simhei.ttf'")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@@ -0,0 +1,172 @@
+import re
+import logging
+from typing import Optional, List, Dict, Any
+
+logger = logging.getLogger(__name__)
+
+class TextProcessor:
+"""大模型文本处理工具类"""
+
+@staticmethod
+def clean_model_output(output: str) -> str:
+"""清理模型输出文本
+
+Args:
+output: 模型输出的原始文本
+
+Returns:
+str: 清理后的文本
+"""
+try:
+# 移除引用标记
+output = re.sub(r'\[[^\]]*\]', '', output)
+
+# 移除多余的空白字符
+output = re.sub(r'\s+', ' ', output)
+
+# 移除首尾的空白字符
+output = output.strip()
+
+# 移除可能存在的HTML标签
+output = re.sub(r'<[^>]+>', '', output)
+
+# 移除可能存在的Markdown格式
+output = re.sub(r'[*_`~]', '', output)
+
+return output
+
+except Exception as e:
+logger.error(f"清理模型输出失败: {str(e)}")
+return output
+
+@staticmethod
+def clean_thought_process(output: str) -> str:
+"""清理模型输出,移除推理过程,只保留最终结果
+
+Args:
+output: 模型原始输出文本
+
+Returns:
+str: 清理后的输出文本
+"""
+try:
+# 找到</think>标签的位置
+think_end = output.find('</think>')
+if think_end != -1:
+# 移除</think>标签及其之前的所有内容
+output = output[think_end + len('</think>'):]
+
+# 处理可能存在的空行
+lines = output.split('\n')
+cleaned_lines = []
+for line in lines:
+line = line.strip()
+if line: # 只保留非空行
+cleaned_lines.append(line)
+
+# 重新组合文本
+output = '\n'.join(cleaned_lines)
+
+return output.strip()
+
+except Exception as e:
+logger.error(f"清理模型输出失败: {str(e)}")
+return output.strip()
+
+@staticmethod
+def extract_numeric_value_from_response(response: str) -> str:
+"""从模型响应中提取数值
+
+Args:
+response: 模型响应的文本
+
+Returns:
+str: 提取出的数值
+"""
+try:
+# 清理响应文本
+cleaned_response = TextProcessor.clean_model_output(response)
+
+# 尝试提取数值
+numeric_patterns = [
+r'[-+]?\d*\.\d+', # 浮点数
+r'[-+]?\d+', # 整数
+r'[-+]?\d+\.\d*' # 可能带小数点的数
+]
+
+for pattern in numeric_patterns:
+matches = re.findall(pattern, cleaned_response)
+if matches:
+# 返回第一个匹配的数值
+return matches[0]
+
+# 如果没有找到数值,返回原始响应
+return cleaned_response
+
+except Exception as e:
+logger.error(f"提取数值失败: {str(e)}")
+return response
+
+@staticmethod
+def extract_list_from_response(response: str) -> List[str]:
+"""从模型响应中提取列表项
+
+Args:
+response: 模型响应的文本
+
+Returns:
+List[str]: 提取出的列表项
+"""
+try:
+# 清理响应文本
+cleaned_response = TextProcessor.clean_model_output(response)
+
+# 尝试匹配列表项
+# 匹配数字编号的列表项
+numbered_items = re.findall(r'\d+\.\s*([^\n]+)', cleaned_response)
+if numbered_items:
+return [item.strip() for item in numbered_items]
+
+# 匹配带符号的列表项
+bullet_items = re.findall(r'[-*•]\s*([^\n]+)', cleaned_response)
+if bullet_items:
+return [item.strip() for item in bullet_items]
+
+# 如果没有找到列表项,返回空列表
+return []
+
+except Exception as e:
+logger.error(f"提取列表项失败: {str(e)}")
+return []
+
+@staticmethod
+def extract_key_value_pairs(response: str) -> Dict[str, str]:
+"""从模型响应中提取键值对
+
+Args:
+response: 模型响应的文本
+
+Returns:
+Dict[str, str]: 提取出的键值对
+"""
+try:
+# 清理响应文本
+cleaned_response = TextProcessor.clean_model_output(response)
+
+# 尝试匹配键值对
+# 匹配冒号分隔的键值对
+pairs = re.findall(r'([^:]+):\s*([^\n]+)', cleaned_response)
+if pairs:
+return {key.strip(): value.strip() for key, value in pairs}
+
+# 匹配等号分隔的键值对
+pairs = re.findall(r'([^=]+)=\s*([^\n]+)', cleaned_response)
+if pairs:
+return {key.strip(): value.strip() for key, value in pairs}
+
+# 如果没有找到键值对,返回空字典
+return {}
+
+except Exception as e:
+logger.error(f"提取键值对失败: {str(e)}")
+return {}
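A small usage sketch for the utility class added above, matching how the analyzer hunks earlier in this commit consume it; it assumes the class is importable from this new module (the module path is not shown in this view), and the sample string is illustrative:

```python
# from src.fundamentals_llm.text_processor import TextProcessor  # assumed import path

raw = "<think>先比较历史PE分位……</think>\n1"

answer = TextProcessor.clean_thought_process(raw)                  # -> "1" (reasoning block removed)
score = TextProcessor.extract_numeric_value_from_response(answer)  # -> "1"

try:
    pe_hist = int(score)
except ValueError:
    pe_hist = 0  # same default the analyzer falls back to
```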
@@ -66,12 +66,22 @@ MODEL_CONFIGS = {
 "doubao": "doubao-1-5-pro-32k-250115"
 }
 },
+# 谷歌Gemini
+"Gemini": {
+"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
+"api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU",
+"models": {
+"offline_model": "gemini-2.0-flash"
+}
+},
 # 天链苹果
 "tl_private": {
-"base_url": "http://192.168.32.118:1234/v1/",
+"base_url": "http://192.168.16.174:1234/v1/",
 "api_key": "none",
 "models": {
-"ds-v1": "mlx-community/DeepSeek-R1-4bit"
+"glm-z1": "glm-z1-rumination-32b-0414",
+"glm-4": "glm-4-32b-0414-abliterated",
+"ds_v1": "mlx-community/DeepSeek-R1-4bit",
 }
 },
 # 天链-千问

@@ -79,7 +89,8 @@ MODEL_CONFIGS = {
 "base_url": "http://192.168.16.178:11434/v1",
 "api_key": "sk-WaVRJKkyhrFlH4ZV35B9Aa61759b400c9cA002D00f3f1019",
 "models": {
-"qwq": "qwq:32b"
+"qwq": "qwq:32b",
+"GLM": "hf-mirror.com/Cobra4687/GLM-4-32B-0414-abliterated-Q4_K_M-GGUF:Q4_K_M"
 }
 },
 # Deepseek配置
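Every `MODEL_CONFIGS` entry follows the same `base_url`/`api_key`/`models` shape, including the new Gemini and GLM entries above. A hedged sketch of how one entry could drive an OpenAI-compatible client; this is an assumption about how `ChatBot` consumes the config, not the project's actual implementation, and the platform/model names are taken from this diff only as an example:

```python
from openai import OpenAI

def make_client(configs: dict, platform: str, model_type: str):
    """Build an OpenAI-compatible client plus concrete model name from one config entry."""
    cfg = configs[platform]
    client = OpenAI(base_url=cfg["base_url"], api_key=cfg["api_key"])
    return client, cfg["models"][model_type]

# e.g. the new GLM entry used by the analyzer's tl_qw_private platform:
# client, model = make_client(MODEL_CONFIGS, "tl_qw_private", "GLM")
# reply = client.chat.completions.create(
#     model=model,
#     messages=[{"role": "user", "content": "只输出一个数值:-1、0或1"}],
#     temperature=0.0,
# ).choices[0].message.content
```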
@ -0,0 +1,986 @@
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from tqdm import tqdm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
# v2版本只做表格
|
||||||
|
|
||||||
|
# 添加中文字体支持
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||||
|
|
||||||
|
class StockAnalyzer:
|
||||||
|
def __init__(self, db_connection_string):
|
||||||
|
"""
|
||||||
|
Initialize the stock analyzer
|
||||||
|
:param db_connection_string: Database connection string
|
||||||
|
"""
|
||||||
|
self.engine = create_engine(db_connection_string)
|
||||||
|
self.initial_capital = 300000 # 30万本金
|
||||||
|
self.total_position = 1000000 # 100万建仓规模
|
||||||
|
self.borrowed_capital = 700000 # 70万借入资金
|
||||||
|
self.annual_interest_rate = 0.05 # 年化5%利息
|
||||||
|
self.daily_interest_rate = self.annual_interest_rate / 365
|
||||||
|
self.commission_rate = 0.05 # 5%手续费
|
||||||
|
self.min_holding_days = 7 # 最少持仓天数
|
||||||
|
|
||||||
|
def get_stock_list(self):
|
||||||
|
"""获取所有股票列表"""
|
||||||
|
# query = "SELECT DISTINCT gp_code as symbol FROM gp_code_all where mark1 = '1'"
|
||||||
|
# return pd.read_sql(query, self.engine)['symbol'].tolist()
|
||||||
|
return ['SH600522',
|
||||||
|
'SZ002340',
|
||||||
|
'SZ000733',
|
||||||
|
'SH601615',
|
||||||
|
'SH600157',
|
||||||
|
'SH688005',
|
||||||
|
'SH600903',
|
||||||
|
'SH600956',
|
||||||
|
'SH601187',
|
||||||
|
'SH603983'] # 你可以在这里添加更多股票代码
|
||||||
|
|
||||||
|
def get_stock_data(self, symbol, start_date, end_date, include_history=False):
|
||||||
|
"""获取指定股票在时间范围内的数据"""
|
||||||
|
try:
|
||||||
|
if include_history:
|
||||||
|
# 如果需要历史数据,获取start_date之前240天的数据
|
||||||
|
query = f"""
|
||||||
|
WITH history_data AS (
|
||||||
|
SELECT
|
||||||
|
symbol,
|
||||||
|
timestamp,
|
||||||
|
CAST(close as DECIMAL(10,2)) as close
|
||||||
|
FROM gp_day_data
|
||||||
|
WHERE symbol = '{symbol}'
|
||||||
|
AND timestamp < '{start_date}'
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT 240
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
symbol,
|
||||||
|
timestamp,
|
||||||
|
CAST(open as DECIMAL(10,2)) as open,
|
||||||
|
CAST(high as DECIMAL(10,2)) as high,
|
||||||
|
CAST(low as DECIMAL(10,2)) as low,
|
||||||
|
CAST(close as DECIMAL(10,2)) as close,
|
||||||
|
CAST(volume as DECIMAL(20,0)) as volume
|
||||||
|
FROM gp_day_data
|
||||||
|
WHERE symbol = '{symbol}'
|
||||||
|
AND timestamp BETWEEN '{start_date}' AND '{end_date}'
|
||||||
|
UNION ALL
|
||||||
|
SELECT
|
||||||
|
symbol,
|
||||||
|
timestamp,
|
||||||
|
NULL as open,
|
||||||
|
NULL as high,
|
||||||
|
NULL as low,
|
||||||
|
close,
|
||||||
|
NULL as volume
|
||||||
|
FROM history_data
|
||||||
|
ORDER BY timestamp
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
query = f"""
|
||||||
|
SELECT
|
||||||
|
symbol,
|
||||||
|
timestamp,
|
||||||
|
CAST(open as DECIMAL(10,2)) as open,
|
||||||
|
CAST(high as DECIMAL(10,2)) as high,
|
||||||
|
CAST(low as DECIMAL(10,2)) as low,
|
||||||
|
CAST(close as DECIMAL(10,2)) as close,
|
||||||
|
CAST(volume as DECIMAL(20,0)) as volume
|
||||||
|
FROM gp_day_data
|
||||||
|
WHERE symbol = '{symbol}'
|
||||||
|
AND timestamp BETWEEN '{start_date}' AND '{end_date}'
|
||||||
|
ORDER BY timestamp
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 使用chunksize分批读取数据
|
||||||
|
chunks = []
|
||||||
|
for chunk in pd.read_sql(query, self.engine, chunksize=1000):
|
||||||
|
chunks.append(chunk)
|
||||||
|
df = pd.concat(chunks, ignore_index=True)
|
||||||
|
|
||||||
|
# 确保数值列的类型正确
|
||||||
|
numeric_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||||
|
for col in numeric_columns:
|
||||||
|
if col in df.columns:
|
||||||
|
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||||
|
|
||||||
|
# 删除任何包含 NaN 的行,但只考虑分析期间需要的列
|
||||||
|
if include_history:
|
||||||
|
historical_mask = (df['timestamp'] < start_date) & df['close'].notna()
|
||||||
|
analysis_mask = (df['timestamp'] >= start_date) & df[['open', 'high', 'low', 'close', 'volume']].notna().all(axis=1)
|
||||||
|
df = df[historical_mask | analysis_mask]
|
||||||
|
else:
|
||||||
|
df = df.dropna()
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取股票 {symbol} 数据时出错: {str(e)}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def calculate_trade_signals(self, df, take_profit_pct, stop_loss_pct, pullback_entry_pct, start_date):
|
||||||
|
"""计算交易信号和收益"""
|
||||||
|
"""计算交易信号和收益"""
|
||||||
|
# 检查历史数据是否足够
|
||||||
|
historical_data = df[df['timestamp'] < start_date]
|
||||||
|
if len(historical_data) < 240:
|
||||||
|
# print(f"历史数据不足240天,跳过分析")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# 计算240日均线基准价格(使用历史数据)
|
||||||
|
ma240_base_price = historical_data['close'].tail(240).mean()
|
||||||
|
# print(f"240日均线基准价格: {ma240_base_price:.2f}")
|
||||||
|
|
||||||
|
# 只使用分析期间的数据进行交易信号计算
|
||||||
|
analysis_df = df[df['timestamp'] >= start_date].copy()
|
||||||
|
if len(analysis_df) < 10: # 确保分析期间有足够的数据
|
||||||
|
# print(f"分析期间数据不足10天,跳过分析")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
positions = [] # 记录所有持仓,每个元素是一个字典,包含入场价格和日期
|
||||||
|
max_positions = 5 # 最大建仓次数
|
||||||
|
cumulative_profit = 0 # 总累计收益
|
||||||
|
trades_count = 0
|
||||||
|
max_holding_period = timedelta(days=180) # 最长持仓6个月
|
||||||
|
max_loss_amount = -300000 # 最大亏损金额为30万
|
||||||
|
|
||||||
|
# 添加统计变量
|
||||||
|
total_holding_days = 0 # 总持仓天数
|
||||||
|
max_capital_usage = 0 # 最大资金占用
|
||||||
|
total_trades = 0 # 总交易次数
|
||||||
|
profitable_trades = 0 # 盈利交易次数
|
||||||
|
loss_trades = 0 # 亏损交易次数
|
||||||
|
|
||||||
|
# print(f"\n{'='*50}")
|
||||||
|
# print(f"开始分析股票: {df.iloc[0]['symbol']}")
|
||||||
|
# print(f"{'='*50}")
|
||||||
|
|
||||||
|
# 记录前一天的收盘价
|
||||||
|
prev_close = None
|
||||||
|
|
||||||
|
for index, row in analysis_df.iterrows():
|
||||||
|
current_price = (float(row['high']) + float(row['low'])) / 2
|
||||||
|
day_result = {
|
||||||
|
'timestamp': row['timestamp'],
|
||||||
|
'profit': cumulative_profit, # 记录当前的累计收益
|
||||||
|
'position': False,
|
||||||
|
'action': 'hold'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 更新所有持仓的收益
|
||||||
|
positions_to_remove = []
|
||||||
|
|
||||||
|
for pos in positions:
|
||||||
|
pos_days_held = (row['timestamp'] - pos['entry_date']).days
|
||||||
|
pos_size = self.total_position / pos['entry_price']
|
||||||
|
|
||||||
|
# 检查是否需要平仓
|
||||||
|
should_close = False
|
||||||
|
close_type = None
|
||||||
|
close_price = current_price
|
||||||
|
|
||||||
|
if pos_days_held >= self.min_holding_days:
|
||||||
|
# 检查止盈
|
||||||
|
if float(row['high']) >= pos['entry_price'] * (1 + take_profit_pct):
|
||||||
|
should_close = True
|
||||||
|
close_type = 'sell_profit'
|
||||||
|
close_price = pos['entry_price'] * (1 + take_profit_pct)
|
||||||
|
# 检查最大亏损
|
||||||
|
elif (close_price - pos['entry_price']) * pos_size <= max_loss_amount:
|
||||||
|
should_close = True
|
||||||
|
close_type = 'sell_loss'
|
||||||
|
# 检查持仓时间
|
||||||
|
elif (row['timestamp'] - pos['entry_date']) >= max_holding_period:
|
||||||
|
should_close = True
|
||||||
|
close_type = 'time_limit'
|
||||||
|
|
||||||
|
if should_close:
|
||||||
|
profit = (close_price - pos['entry_price']) * pos_size
|
||||||
|
commission = profit * self.commission_rate if profit > 0 else 0
|
||||||
|
interest_cost = self.borrowed_capital * self.daily_interest_rate * pos_days_held
|
||||||
|
net_profit = profit - commission - interest_cost
|
||||||
|
|
||||||
|
# 更新统计数据
|
||||||
|
total_holding_days += pos_days_held
|
||||||
|
total_trades += 1
|
||||||
|
if net_profit > 0:
|
||||||
|
profitable_trades += 1
|
||||||
|
else:
|
||||||
|
loss_trades += 1
|
||||||
|
|
||||||
|
cumulative_profit += net_profit
|
||||||
|
positions_to_remove.append(pos)
|
||||||
|
day_result['action'] = close_type
|
||||||
|
day_result['profit'] = cumulative_profit
|
||||||
|
|
||||||
|
# print(f"\n>>> {'止盈平仓' if close_type == 'sell_profit' else '止损平仓' if close_type == 'sell_loss' else '到期平仓'}")
|
||||||
|
# print(f"持仓天数: {pos_days_held}")
|
||||||
|
# print(f"持仓数量: {pos_size:.0f}股")
|
||||||
|
# print(f"平仓价格: {close_price:.2f}")
|
||||||
|
# print(f"{'盈利' if profit > 0 else '亏损'}: {profit:,.2f}")
|
||||||
|
# print(f"手续费: {commission:,.2f}")
|
||||||
|
# print(f"利息成本: {interest_cost:,.2f}")
|
||||||
|
# print(f"净{'盈利' if net_profit > 0 else '亏损'}: {net_profit:,.2f}")
|
||||||
|
|
||||||
|
# 移除已平仓的持仓
|
||||||
|
for pos in positions_to_remove:
|
||||||
|
positions.remove(pos)
|
||||||
|
|
||||||
|
# 检查是否可以建仓
|
||||||
|
if len(positions) < max_positions and prev_close is not None:
|
||||||
|
# 计算当日跌幅
|
||||||
|
daily_drop = (prev_close - float(row['low'])) / prev_close
|
||||||
|
|
||||||
|
# 检查价格是否在均线区间内
|
||||||
|
price_in_range = (
|
||||||
|
current_price >= ma240_base_price * 0.7 and
|
||||||
|
current_price <= ma240_base_price * 1.3
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果跌幅超过5%且价格在均线区间内
|
||||||
|
if daily_drop >= 0.05 and price_in_range:
|
||||||
|
positions.append({
|
||||||
|
'entry_price': current_price,
|
||||||
|
'entry_date': row['timestamp']
|
||||||
|
})
|
||||||
|
trades_count += 1
|
||||||
|
# print(f"\n>>> 建仓信号 #{trades_count}")
|
||||||
|
# print(f"日期: {row['timestamp'].strftime('%Y-%m-%d')}")
|
||||||
|
# print(f"建仓价格: {current_price:.2f}")
|
||||||
|
# print(f"建仓数量: {self.total_position/current_price:.0f}股")
|
||||||
|
# print(f"当日跌幅: {daily_drop*100:.2f}%")
|
||||||
|
# print(f"距离均线: {((current_price/ma240_base_price)-1)*100:.2f}%")
|
||||||
|
|
||||||
|
# 更新前一天收盘价
|
||||||
|
prev_close = float(row['close'])
|
||||||
|
|
||||||
|
# 更新日结果
|
||||||
|
day_result['position'] = len(positions) > 0
|
||||||
|
results.append(day_result)
|
||||||
|
|
||||||
|
# 计算当前资金占用
|
||||||
|
current_capital_usage = sum(self.total_position for pos in positions)
|
||||||
|
max_capital_usage = max(max_capital_usage, current_capital_usage)
|
||||||
|
|
||||||
|
# 在最后一天强制平仓所有持仓
|
||||||
|
if positions:
|
||||||
|
final_price = (float(analysis_df.iloc[-1]['high']) + float(analysis_df.iloc[-1]['low'])) / 2
|
||||||
|
final_total_profit = 0
|
||||||
|
|
||||||
|
for pos in positions:
|
||||||
|
position_size = self.total_position / pos['entry_price']
|
||||||
|
days_held = (analysis_df.iloc[-1]['timestamp'] - pos['entry_date']).days
|
||||||
|
final_profit = (final_price - pos['entry_price']) * position_size
|
||||||
|
commission = final_profit * self.commission_rate if final_profit > 0 else 0
|
||||||
|
interest_cost = self.borrowed_capital * self.daily_interest_rate * days_held
|
||||||
|
net_profit = final_profit - commission - interest_cost
|
||||||
|
|
||||||
|
# 更新统计数据
|
||||||
|
total_holding_days += days_held
|
||||||
|
total_trades += 1
|
||||||
|
if net_profit > 0:
|
||||||
|
profitable_trades += 1
|
||||||
|
else:
|
||||||
|
loss_trades += 1
|
||||||
|
|
||||||
|
final_total_profit += net_profit
|
||||||
|
|
||||||
|
# print(f"\n>>> 到期强制平仓")
|
||||||
|
# print(f"持仓天数: {days_held}")
|
||||||
|
# print(f"持仓数量: {position_size:.0f}股")
|
||||||
|
# print(f"平仓价格: {final_price:.2f}")
|
||||||
|
# print(f"毛利润: {final_profit:,.2f}")
|
||||||
|
# print(f"手续费: {commission:,.2f}")
|
||||||
|
# print(f"利息成本: {interest_cost:,.2f}")
|
||||||
|
# print(f"净利润: {net_profit:,.2f}")
|
||||||
|
|
||||||
|
results[-1]['action'] = 'final_sell'
|
||||||
|
cumulative_profit += final_total_profit # 更新最终的累计收益
|
||||||
|
results[-1]['profit'] = cumulative_profit # 更新最后一天的累计收益
|
||||||
|
|
||||||
|
# 计算统计数据
|
||||||
|
avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0
|
||||||
|
win_rate = profitable_trades / total_trades * 100 if total_trades > 0 else 0
|
||||||
|
|
||||||
|
# print(f"\n{'='*50}")
|
||||||
|
# print(f"交易统计")
|
||||||
|
# print(f"总交易次数: {trades_count}")
|
||||||
|
# print(f"累计收益: {cumulative_profit:,.2f}")
|
||||||
|
# print(f"最大资金占用: {max_capital_usage:,.2f}")
|
||||||
|
# print(f"平均持仓天数: {avg_holding_days:.1f}")
|
||||||
|
# print(f"胜率: {win_rate:.1f}%")
|
||||||
|
# print(f"盈利交易: {profitable_trades}次")
|
||||||
|
# print(f"亏损交易: {loss_trades}次")
|
||||||
|
# print(f"{'='*50}\n")
|
||||||
|
|
||||||
|
# 将统计数据添加到结果中
|
||||||
|
if len(results) > 0:
|
||||||
|
results[-1]['stats'] = {
|
||||||
|
'total_trades': total_trades,
|
||||||
|
'profitable_trades': profitable_trades,
|
||||||
|
'loss_trades': loss_trades,
|
||||||
|
'win_rate': win_rate,
|
||||||
|
'avg_holding_days': avg_holding_days,
|
||||||
|
'max_capital_usage': max_capital_usage,
|
||||||
|
'final_profit': cumulative_profit
|
||||||
|
}
|
||||||
|
|
||||||
|
return pd.DataFrame(results)
|
||||||
|
|
||||||
|
def analyze_stock(self, symbol, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
|
||||||
|
"""分析单个股票"""
|
||||||
|
df = self.get_stock_data(symbol, start_date, end_date, include_history=True)
|
||||||
|
|
||||||
|
if len(df) < 240: # 确保总数据量足够
|
||||||
|
print(f"股票 {symbol} 总数据不足240天,跳过分析")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.calculate_trade_signals(
|
||||||
|
df,
|
||||||
|
take_profit_pct=take_profit_pct,
|
||||||
|
stop_loss_pct=stop_loss_pct,
|
||||||
|
pullback_entry_pct=pullback_entry_pct,
|
||||||
|
start_date=start_date
|
||||||
|
)
|
||||||
|
|
||||||
|
def analyze_all_stocks(self, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
|
||||||
|
"""分析所有股票"""
|
||||||
|
stocks = self.get_stock_list()
|
||||||
|
all_results = {}
|
||||||
|
analysis_summary = {
|
||||||
|
'processed': [],
|
||||||
|
'skipped': [],
|
||||||
|
'error': []
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\n开始分析,共 {len(stocks)} 支股票")
|
||||||
|
|
||||||
|
for symbol in tqdm(stocks, desc="Analyzing stocks"):
|
||||||
|
try:
|
||||||
|
# 获取股票数据
|
||||||
|
df = self.get_stock_data(symbol, start_date, end_date, include_history=True)
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
print(f"\n股票 {symbol} 数据为空,跳过分析")
|
||||||
|
analysis_summary['skipped'].append({
|
||||||
|
'symbol': symbol,
|
||||||
|
'reason': 'empty_data',
|
||||||
|
'take_profit_pct': take_profit_pct
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查历史数据
|
||||||
|
historical_data = df[df['timestamp'] < start_date]
|
||||||
|
if len(historical_data) < 240:
|
||||||
|
# print(f"\n股票 {symbol} 历史数据不足240天({len(historical_data)}天),跳过分析")
|
||||||
|
analysis_summary['skipped'].append({
|
||||||
|
'symbol': symbol,
|
||||||
|
'reason': f'insufficient_history_{len(historical_data)}days',
|
||||||
|
'take_profit_pct': take_profit_pct
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查分析期数据
|
||||||
|
analysis_data = df[df['timestamp'] >= start_date]
|
||||||
|
if len(analysis_data) < 10:
|
||||||
|
print(f"\n股票 {symbol} 分析期数据不足10天({len(analysis_data)}天),跳过分析")
|
||||||
|
analysis_summary['skipped'].append({
|
||||||
|
'symbol': symbol,
|
||||||
|
'reason': f'insufficient_analysis_data_{len(analysis_data)}days',
|
||||||
|
'take_profit_pct': take_profit_pct
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算交易信号
|
||||||
|
result = self.calculate_trade_signals(
|
||||||
|
df,
|
||||||
|
take_profit_pct,
|
||||||
|
stop_loss_pct,
|
||||||
|
pullback_entry_pct,
|
||||||
|
start_date
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is not None and not result.empty:
|
||||||
|
# 检查是否有交易记录
|
||||||
|
has_trades = any(action in ['sell_profit', 'sell_loss', 'time_limit', 'final_sell']
|
||||||
|
for action in result['action'] if pd.notna(action))
|
||||||
|
|
||||||
|
if has_trades:
|
||||||
|
all_results[symbol] = result
|
||||||
|
analysis_summary['processed'].append({
|
||||||
|
'symbol': symbol,
|
||||||
|
'take_profit_pct': take_profit_pct,
|
||||||
|
'trades': len([x for x in result['action'] if pd.notna(x) and x != 'hold']),
|
||||||
|
'profit': result['profit'].iloc[-1]
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# print(f"\n股票 {symbol} 没有产生有效的交易信号")
|
||||||
|
analysis_summary['skipped'].append({
|
||||||
|
'symbol': symbol,
|
||||||
|
'reason': 'no_valid_signals',
|
||||||
|
'take_profit_pct': take_profit_pct
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n处理股票 {symbol} 时出错: {str(e)}")
|
||||||
|
analysis_summary['error'].append({
|
||||||
|
'symbol': symbol,
|
||||||
|
'error': str(e),
|
||||||
|
'take_profit_pct': take_profit_pct
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 打印分析总结
|
||||||
|
print(f"\n分析完成:")
|
||||||
|
print(f"成功分析: {len(analysis_summary['processed'])} 支股票")
|
||||||
|
print(f"跳过分析: {len(analysis_summary['skipped'])} 支股票")
|
||||||
|
print(f"出错股票: {len(analysis_summary['error'])} 支股票")
|
||||||
|
|
||||||
|
# 保存分析总结
|
||||||
|
if not os.path.exists('results/analysis_summary'):
|
||||||
|
os.makedirs('results/analysis_summary')
|
||||||
|
|
||||||
|
# 保存处理成功的股票信息
|
||||||
|
if analysis_summary['processed']:
|
||||||
|
pd.DataFrame(analysis_summary['processed']).to_csv(
|
||||||
|
f'results/analysis_summary/processed_stocks_{take_profit_pct*100:.1f}.csv',
|
||||||
|
index=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存跳过的股票信息
|
||||||
|
if analysis_summary['skipped']:
|
||||||
|
pd.DataFrame(analysis_summary['skipped']).to_csv(
|
||||||
|
f'results/analysis_summary/skipped_stocks_{take_profit_pct*100:.1f}.csv',
|
||||||
|
index=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存出错的股票信息
|
||||||
|
if analysis_summary['error']:
|
||||||
|
pd.DataFrame(analysis_summary['error']).to_csv(
|
||||||
|
f'results/analysis_summary/error_stocks_{take_profit_pct*100:.1f}.csv',
|
||||||
|
index=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
def plot_results(self, results, symbol):
|
||||||
|
"""绘制单个股票的收益走势图"""
|
||||||
|
plt.figure(figsize=(12, 6))
|
||||||
|
plt.plot(results['timestamp'], results['profit'].cumsum(), label='Cumulative Profit')
|
||||||
|
plt.title(f'Stock {symbol} Trading Results')
|
||||||
|
plt.xlabel('Date')
|
||||||
|
plt.ylabel('Profit (CNY)')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True)
|
||||||
|
plt.xticks(rotation=45)
|
||||||
|
plt.tight_layout()
|
||||||
|
return plt
|
||||||
|
|
||||||
|
def plot_all_stocks(self, all_results):
|
||||||
|
"""绘制所有股票的收益走势图"""
|
||||||
|
plt.figure(figsize=(15, 8))
|
||||||
|
|
||||||
|
# 记录所有股票的累计收益
|
||||||
|
total_profits = {}
|
||||||
|
max_profit = float('-inf')
|
||||||
|
min_profit = float('inf')
|
||||||
|
|
||||||
|
# 绘制每支股票的曲线
|
||||||
|
for symbol, results in all_results.items():
|
||||||
|
# 直接使用已经计算好的累计收益
|
||||||
|
profits = results['profit']
|
||||||
|
plt.plot(results['timestamp'], profits, label=symbol, alpha=0.5, linewidth=1)
|
||||||
|
|
||||||
|
# 更新最大最小收益
|
||||||
|
max_profit = max(max_profit, profits.max())
|
||||||
|
min_profit = min(min_profit, profits.min())
|
||||||
|
|
||||||
|
# 记录总收益(使用最后一个累计收益值)
|
||||||
|
total_profits[symbol] = profits.iloc[-1]
|
||||||
|
|
||||||
|
# 计算所有股票的中位数收益曲线
|
||||||
|
# 首先找到共同的日期范围
|
||||||
|
all_dates = set()
|
||||||
|
for results in all_results.values():
|
||||||
|
all_dates.update(results['timestamp'])
|
||||||
|
all_dates = sorted(list(all_dates))
|
||||||
|
|
||||||
|
# 创建一个包含所有日期的DataFrame
|
||||||
|
profits_df = pd.DataFrame(index=all_dates)
|
||||||
|
|
||||||
|
# 对每支股票,填充所有日期的收益
|
||||||
|
for symbol, results in all_results.items():
|
||||||
|
daily_profits = pd.Series(index=all_dates, data=np.nan)
|
||||||
|
# 直接使用已经计算好的累计收益
|
||||||
|
for date, profit in zip(results['timestamp'], results['profit']):
|
||||||
|
daily_profits[date] = profit
|
||||||
|
# 对于没有交易的日期,使用最后一个累计收益填充
|
||||||
|
daily_profits.fillna(method='ffill', inplace=True)
|
||||||
|
daily_profits.fillna(0, inplace=True) # 对开始之前的日期填充0
|
||||||
|
profits_df[symbol] = daily_profits
|
||||||
|
|
||||||
|
# 计算并绘制中位数收益曲线
|
||||||
|
median_line = profits_df.median(axis=1)
|
||||||
|
plt.plot(all_dates, median_line, 'r-', label='Median', linewidth=2)
|
||||||
|
|
||||||
|
plt.title('All Stocks Trading Results')
|
||||||
|
plt.xlabel('Date')
|
||||||
|
plt.ylabel('Profit (CNY)')
|
||||||
|
plt.grid(True)
|
||||||
|
plt.xticks(rotation=45)
|
||||||
|
|
||||||
|
# 添加利润区间标注
|
||||||
|
plt.text(0.02, 0.98,
|
||||||
|
f'Profit Range:\nMax: {max_profit:,.0f}\nMin: {min_profit:,.0f}\nMedian: {median_line.iloc[-1]:,.0f}',
|
||||||
|
transform=plt.gca().transAxes,
|
||||||
|
bbox=dict(facecolor='white', alpha=0.8))
|
||||||
|
|
||||||
|
# 计算所有股票的统计数据
|
||||||
|
total_stats = {
|
||||||
|
'total_trades': 0,
|
||||||
|
'profitable_trades': 0,
|
||||||
|
'loss_trades': 0,
|
||||||
|
'total_holding_days': 0,
|
||||||
|
'max_capital_usage': 0,
|
||||||
|
'total_profit': 0,
|
||||||
|
'profitable_stocks': 0,
|
||||||
|
'loss_stocks': 0,
|
||||||
|
'total_capital_usage': 0, # 添加总资金占用
|
||||||
|
'total_win_rate': 0, # 添加总胜率
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取前5名和后5名股票
|
||||||
|
top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||||
|
bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5]
|
||||||
|
|
||||||
|
# 添加排名信息
|
||||||
|
rank_text = "Top 5 Stocks:\n"
|
||||||
|
for symbol, profit in top_5:
|
||||||
|
rank_text += f"{symbol}: {profit:,.0f}\n"
|
||||||
|
rank_text += "\nBottom 5 Stocks:\n"
|
||||||
|
for symbol, profit in bottom_5:
|
||||||
|
rank_text += f"{symbol}: {profit:,.0f}\n"
|
||||||
|
|
||||||
|
plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes,
|
||||||
|
bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top')
|
||||||
|
|
||||||
|
# 记录每支股票的统计数据用于计算平均值
|
||||||
|
stock_stats = []
|
||||||
|
|
||||||
|
for symbol, results in all_results.items():
|
||||||
|
if 'stats' in results.iloc[-1]:
|
||||||
|
stats = results.iloc[-1]['stats']
|
||||||
|
stock_stats.append(stats) # 记录每支股票的统计数据
|
||||||
|
|
||||||
|
total_stats['total_trades'] += stats['total_trades']
|
||||||
|
total_stats['profitable_trades'] += stats['profitable_trades']
|
||||||
|
total_stats['loss_trades'] += stats['loss_trades']
|
||||||
|
total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades']
|
||||||
|
total_stats['max_capital_usage'] = max(total_stats['max_capital_usage'], stats['max_capital_usage'])
|
||||||
|
total_stats['total_capital_usage'] += stats['max_capital_usage']
|
||||||
|
total_stats['total_profit'] += stats['final_profit']
|
||||||
|
total_stats['total_win_rate'] += stats['win_rate']
|
||||||
|
if stats['final_profit'] > 0:
|
||||||
|
total_stats['profitable_stocks'] += 1
|
||||||
|
else:
|
||||||
|
total_stats['loss_stocks'] += 1
|
||||||
|
|
||||||
|
# 计算总体统计
|
||||||
|
total_stocks = total_stats['profitable_stocks'] + total_stats['loss_stocks']
|
||||||
|
avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0
|
||||||
|
win_rate = total_stats['profitable_trades'] / total_stats['total_trades'] * 100 if total_stats['total_trades'] > 0 else 0
|
||||||
|
stock_win_rate = total_stats['profitable_stocks'] / total_stocks * 100 if total_stocks > 0 else 0
|
||||||
|
|
||||||
|
# 计算平均值统计
|
||||||
|
avg_trades_per_stock = total_stats['total_trades'] / total_stocks if total_stocks > 0 else 0
|
||||||
|
avg_profit_per_stock = total_stats['total_profit'] / total_stocks if total_stocks > 0 else 0
|
||||||
|
avg_capital_usage = total_stats['total_capital_usage'] / total_stocks if total_stocks > 0 else 0
|
||||||
|
avg_win_rate = total_stats['total_win_rate'] / total_stocks if total_stocks > 0 else 0
|
||||||
|
|
||||||
|
        # Overall statistics annotation
        stats_text = (
            f"总体统计:\n"
            f"总交易次数: {total_stats['total_trades']}\n"
            # f"最大资金占用: {total_stats['max_capital_usage']:,.0f}\n"
            f"平均持仓天数: {avg_holding_days:.1f}\n"
            f"交易胜率: {win_rate:.1f}%\n"
            f"股票胜率: {stock_win_rate:.1f}%\n"
            f"盈利股票数: {total_stats['profitable_stocks']}\n"
            f"亏损股票数: {total_stats['loss_stocks']}\n"
            f"\n平均值统计:\n"
            f"平均交易次数: {avg_trades_per_stock:.1f}\n"
            f"平均收益: {avg_profit_per_stock:,.0f}\n"
            f"平均资金占用: {avg_capital_usage:,.0f}\n"
            f"平均持仓天数: {avg_holding_days:.1f}\n"
            f"平均胜率: {avg_win_rate:.1f}%"
        )

        plt.text(0.02, 0.6, stats_text,
                 transform=plt.gca().transAxes,
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
                 verticalalignment='top',
                 fontsize=10)

        plt.tight_layout()
        return plt, total_profits

    def plot_profit_distribution(self, total_profits):
        """Draw a violin plot of the profit distribution across all stocks."""
        plt.figure(figsize=(12, 8))

        # Convert the profit values to an array, using each stock's final cumulative profit
        profits = np.array([profit for profit in total_profits.values()])

        # Summary statistics
        mean_profit = np.mean(profits)
        median_profit = np.median(profits)
        max_profit = np.max(profits)
        min_profit = np.min(profits)
        std_profit = np.std(profits)

        # Draw the violin plot
        violin_parts = plt.violinplot(profits, positions=[0], showmeans=True, showmedians=True)

        # Style the violin body, mean line and median line
        violin_parts['bodies'][0].set_facecolor('lightblue')
        violin_parts['bodies'][0].set_alpha(0.7)
        violin_parts['cmeans'].set_color('red')
        violin_parts['cmedians'].set_color('blue')

        # Statistics annotation
        stats_text = (
            f"统计信息:\n"
            f"平均收益: {mean_profit:,.0f}\n"
            f"中位数收益: {median_profit:,.0f}\n"
            f"最大收益: {max_profit:,.0f}\n"
            f"最小收益: {min_profit:,.0f}\n"
            f"标准差: {std_profit:,.0f}"
        )

        plt.text(0.65, 0.95, stats_text,
                 transform=plt.gca().transAxes,
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
                 verticalalignment='top',
                 fontsize=10)

        # Chart styling
        plt.title('股票收益分布图', fontsize=14, pad=20)
        plt.ylabel('收益 (元)', fontsize=12)
        plt.xticks([0], ['所有股票'], fontsize=10)
        plt.grid(True, axis='y', alpha=0.3)

        # Zero-profit reference line
        plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)

        # Count profitable and losing stocks
        profit_count = np.sum(profits > 0)
        loss_count = np.sum(profits < 0)
        total_count = len(profits)

        # Profit/loss ratio annotation
        ratio_text = (
            f"盈亏比例:\n"
            f"盈利: {profit_count}支 ({profit_count/total_count*100:.1f}%)\n"
            f"亏损: {loss_count}支 ({loss_count/total_count*100:.1f}%)\n"
            f"总计: {total_count}支"
        )

        plt.text(0.65, 0.6, ratio_text,
                 transform=plt.gca().transAxes,
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
                 verticalalignment='top',
                 fontsize=10)

        plt.tight_layout()
        return plt

    def plot_profit_matrix(self, df):
        """Plot the results of the take-profit ratio analysis matrix."""
        # 2x2 grid of subplots (plt.subplots creates the figure)
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        # 1. Total profit curve
        ax1.plot(df.index, df['total_profit'], 'b-', linewidth=2)
        ax1.set_title('总收益 vs 止盈比例')
        ax1.set_xlabel('止盈比例 (%)')
        ax1.set_ylabel('总收益 (元)')
        ax1.grid(True)

        # Mark the optimal take-profit ratio
        best_profit_pct = df['total_profit'].idxmax()
        best_profit = df['total_profit'].max()
        ax1.plot(best_profit_pct, best_profit, 'ro')
        ax1.annotate(f'最优: {best_profit_pct:.1f}%\n{best_profit:,.0f}元',
                     (best_profit_pct, best_profit),
                     xytext=(10, 10), textcoords='offset points')

        # 2. Average and median profit per trade
        ax2.plot(df.index, df['avg_profit_per_trade'], 'g-', linewidth=2, label='平均收益')
        ax2.plot(df.index, df['median_profit_per_trade'], 'r--', linewidth=2, label='中位数收益')
        ax2.set_title('每笔交易收益 vs 止盈比例')
        ax2.set_xlabel('止盈比例 (%)')
        ax2.set_ylabel('收益 (元)')
        ax2.legend()
        ax2.grid(True)

        # 3. Average holding days
        ax3.plot(df.index, df['avg_holding_days'], 'r-', linewidth=2)
        ax3.set_title('平均持仓天数 vs 止盈比例')
        ax3.set_xlabel('止盈比例 (%)')
        ax3.set_ylabel('持仓天数')
        ax3.grid(True)

        # 4. Win rates
        ax4.plot(df.index, df['trade_win_rate'], 'c-', linewidth=2, label='交易胜率')
        ax4.plot(df.index, df['stock_win_rate'], 'm--', linewidth=2, label='股票胜率')
        ax4.set_title('胜率 vs 止盈比例')
        ax4.set_xlabel('止盈比例 (%)')
        ax4.set_ylabel('胜率 (%)')
        ax4.legend()
        ax4.grid(True)

        # Adjust the layout
        plt.tight_layout()
        return plt

    def analyze_profit_matrix(self, start_date, end_date, stop_loss_pct, pullback_entry_pct):
        """Analyze a matrix of performance figures across different take-profit ratios."""
        # Create the results directory and a temp directory for per-ratio caches
        if not os.path.exists('results/temp'):
            os.makedirs('results/temp')

        # Sweep from 30% down to 5% in 1% steps
        profit_percentages = list(np.arange(0.30, 0.04, -0.01))

        for take_profit_pct in profit_percentages:
            # Skip take-profit ratios that have already been analyzed (cached as CSV)
            temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv'
            if os.path.exists(temp_file):
                print(f"\n止盈比例 {take_profit_pct*100:.1f}% 已分析,跳过")
                continue

            print(f"\n分析止盈比例: {take_profit_pct*100:.1f}%")

            try:
                # Run the analysis for this take-profit ratio
                results = self.analyze_all_stocks(
                    start_date,
                    end_date,
                    take_profit_pct,
                    stop_loss_pct,
                    pullback_entry_pct
                )

                # Drop empty results
                valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty}

                if valid_results:
                    # Aggregate statistics for this ratio
                    total_stats = {
                        'total_trades': 0,
                        'profitable_trades': 0,
                        'loss_trades': 0,
                        'total_holding_days': 0,
                        'total_profit': 0,
                        'total_capital_usage': 0,
                        'total_win_rate': 0,
                        'stock_count': len(valid_results),
                        'profitable_stocks': 0,
                        'all_trade_profits': []
                    }

                    # Collect per-stock statistics
                    for symbol, result in valid_results.items():
                        if 'stats' in result.iloc[-1]:
                            stats = result.iloc[-1]['stats']
                            total_stats['total_trades'] += stats['total_trades']
                            total_stats['profitable_trades'] += stats['profitable_trades']
                            total_stats['loss_trades'] += stats['loss_trades']
                            total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades']
                            total_stats['total_profit'] += stats['final_profit']
                            total_stats['total_capital_usage'] += stats['max_capital_usage']
                            total_stats['total_win_rate'] += stats['win_rate']

                            if stats['final_profit'] > 0:
                                total_stats['profitable_stocks'] += 1

                            # stats is a dict, so test key membership rather than using hasattr
                            if 'trade_profits' in stats:
                                total_stats['all_trade_profits'].extend(stats['trade_profits'])

                    # Derived statistics for this take-profit ratio
                    avg_trades = total_stats['total_trades'] / total_stats['stock_count']
                    avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0
                    avg_profit_per_trade = total_stats['total_profit'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0
                    median_profit_per_trade = np.median(total_stats['all_trade_profits']) if total_stats['all_trade_profits'] else 0
                    total_profit = total_stats['total_profit']
                    trade_win_rate = (total_stats['profitable_trades'] / total_stats['total_trades'] * 100) if total_stats['total_trades'] > 0 else 0
                    stock_win_rate = (total_stats['profitable_stocks'] / total_stats['stock_count'] * 100) if total_stats['stock_count'] > 0 else 0

                    # One-row DataFrame with the results for this take-profit ratio
                    result_df = pd.DataFrame([{
                        'take_profit_pct': take_profit_pct * 100,
                        'total_profit': total_profit,
                        'avg_profit_per_trade': avg_profit_per_trade,
                        'median_profit_per_trade': median_profit_per_trade,
                        'avg_holding_days': avg_holding_days,
                        'trade_win_rate': trade_win_rate,
                        'stock_win_rate': stock_win_rate,
                        'total_trades': total_stats['total_trades'],
                        'stock_count': total_stats['stock_count']
                    }])

                    # Cache this ratio's result in its temp CSV
                    result_df.to_csv(temp_file, index=False)
                    print(f"保存止盈点 {take_profit_pct*100:.1f}% 的分析结果")

                # Free memory before the next ratio
                del valid_results
                del results
                if 'total_stats' in locals():
                    del total_stats

            except Exception as e:
                print(f"处理止盈点 {take_profit_pct*100:.1f}% 时出错: {str(e)}")
                continue

        # Merge the cached per-ratio results
        print("\n合并所有分析结果...")
        all_results = []
        for take_profit_pct in profit_percentages:
            temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv'
            if os.path.exists(temp_file):
                try:
                    df = pd.read_csv(temp_file)
                    all_results.append(df)
                except Exception as e:
                    print(f"读取文件 {temp_file} 时出错: {str(e)}")

        if all_results:
            # Concatenate all per-ratio rows into one table
            df = pd.concat(all_results, ignore_index=True)
            df = df.round({
                'take_profit_pct': 1,
                'total_profit': 0,
                'avg_profit_per_trade': 0,
                'median_profit_per_trade': 0,
                'avg_holding_days': 1,
                'trade_win_rate': 1,
                'stock_win_rate': 1,
                'total_trades': 0,
                'stock_count': 0
            })

            # Index by take-profit ratio, sorted from high to low
            df.set_index('take_profit_pct', inplace=True)
            df.sort_index(ascending=False, inplace=True)

            # Save the final matrix
            df.to_csv('results/profit_matrix_analysis.csv')

            # Print the matrix
            print("\n止盈比例分析矩阵:")
            print("=" * 120)
            print(df.to_string())
            print("=" * 120)

            try:
                # Plot and save the matrix analysis chart
                plt = self.plot_profit_matrix(df)
                plt.savefig('results/profit_matrix_analysis.png', bbox_inches='tight', dpi=300)
                plt.close()
            except Exception as e:
                print(f"绘制图表时出错: {str(e)}")

            return df
        else:
            print("没有找到任何分析结果")
            return pd.DataFrame()

def main():
    # Database connection settings
    db_config = {
        'host': '192.168.18.199',
        'port': 3306,
        'user': 'root',
        'password': 'Chlry#$.8',
        'database': 'db_gp_cj'
    }

    try:
        connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}"

        # Create the results directory
        if not os.path.exists('results'):
            os.makedirs('results')

        # Create the analyzer instance
        analyzer = StockAnalyzer(connection_string)

        # Analysis parameters
        start_date = '2022-05-05'
        end_date = '2023-08-28'
        stop_loss_pct = -0.99       # -99% stop loss
        pullback_entry_pct = -0.05  # enter on a -5% pullback

        # Run the take-profit ratio analysis
        profit_matrix = analyzer.analyze_profit_matrix(
            start_date,
            end_date,
            stop_loss_pct,
            pullback_entry_pct
        )

        if not profit_matrix.empty:
            # Re-run the full analysis with the optimal take-profit ratio
            take_profit_pct = profit_matrix['total_profit'].idxmax() / 100
            print(f"\n使用最优止盈比例 {take_profit_pct*100:.1f}% 运行详细分析")

            # Run the analysis and generate only the charts we need
            results = analyzer.analyze_all_stocks(
                start_date,
                end_date,
                take_profit_pct,
                stop_loss_pct,
                pullback_entry_pct
            )

            # Drop empty results
            valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty}

            if valid_results:
                # Generate only the required chart
                plt, _ = analyzer.plot_all_stocks(valid_results)
                plt.savefig('results/all_stocks_analysis.png', bbox_inches='tight', dpi=300)
                plt.close()

    except Exception as e:
        print(f"程序运行出错: {str(e)}")
        raise


if __name__ == "__main__":
    main()
@@ -0,0 +1,81 @@
# Before you start

## xtdata provides the interface for interacting with MiniQmt. In essence it establishes a connection to MiniQmt, which handles the market-data requests and passes the results back to the Python layer. The quote server used and the market data available are the same as in MiniQmt, so to check data or switch connections, operate MiniQmt directly.

## For the data-fetching interfaces, first make sure MiniQmt already has the data you need; if it is incomplete, supplement it through the data-download interfaces and then call the data-fetching interfaces to read it.

## For the subscription interfaces, simply set a data callback and the data will be delivered through it as it arrives. Data received via subscription is generally saved locally, so the same kind of data does not need to be downloaded again separately. (These three patterns are condensed into the short sketch right below; the full walkthrough follows after it.)

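# A condensed preview of the three usage patterns described above. This is only a sketch and is
# guarded with `if 0:` so it does not interfere with the step-by-step walkthrough below; flip it
# to `if 1:` if you want to run just this part on its own.
if 0:
    from xtquant import xtdata
    demo_code = "000001.SZ"                                                    # illustrative symbol, same as the walkthrough below
    xtdata.download_history_data(demo_code, period="1d", incrementally=True)   # 1. supplement local data first
    print(xtdata.get_market_data_ex([], [demo_code], period="1d", count=-1))   # 2. then read it from the local store
    xtdata.subscribe_quote(demo_code, period="1d", count=-1, callback=print)   # 3. or subscribe; updates arrive via the callback
    xtdata.run()                                                               # block the process so callbacks keep firing
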
# Code walkthrough

# Import the xtquant library from the local Python environment; an import error here means the installation failed
from xtquant import xtdata
import time

# Define a list of instrument codes
code_list = ["000001.SZ"]
# Set the bar period for the data
period = "1d"

# Download market data for the instruments
if 1:
    ## To make data management easier for users, most of xtquant's historical data is stored locally in compressed form.
    ## Market data, for example, must be downloaded with download_history_data, and financial data with download_financial_data.
    ## So before reading historical data, we need to call the download interfaces to bring the data down to the local store.
    for i in code_list:
        xtdata.download_history_data(i, period=period, incrementally=True)  # incrementally download quote data (open/high/low/close, etc.) to the local store

    xtdata.download_financial_data(code_list)  # download financial data locally
    xtdata.download_sector_data()  # download sector data locally
    # More download interfaces can be looked up in the data dictionary

# Read local historical market data
history_data = xtdata.get_market_data_ex([], code_list, period=period, count=-1)
print(history_data)
print("=" * 20)

# For intraday real-time quotes you must first subscribe to the server
# After subscribing, get_market_data and get_market_data_ex automatically splice local history together with the server's real-time quotes

# Subscribe to quotes from the server
for i in code_list:
    xtdata.subscribe_quote(i, period=period, count=-1)  # count=-1 fetches all of today's real-time quotes

# Wait for the subscription to complete
time.sleep(1)

# Fetch quotes after subscribing
kline_data = xtdata.get_market_data_ex([], code_list, period=period)
print(kline_data)

# Fetch the subscribed quotes on a fixed interval; this is expected to print 10 times
for i in range(10):
    # A for loop is used here for the demo; in practice you could use while True (see the sketch right after this loop)
    kline_data = xtdata.get_market_data_ex([], code_list, period=period)
    print(kline_data)
    time.sleep(3)  # fetch the quotes again after three seconds

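# The same polling written with while True, as mentioned in the loop above. It is only a sketch,
# guarded with `if 0:` so the callback demo below still runs; flip it to `if 1:` to try it
# (it will poll indefinitely until you interrupt the process).
if 0:
    while True:
        kline_data = xtdata.get_market_data_ex([], code_list, period=period)
        print(kline_data)
        time.sleep(3)  # poll again after three seconds
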
# If you do not want to poll on a fixed interval, you can use the subscription callback instead
# In this mode the callback registered with the subscription runs asynchronously: every time a subscribed instrument's tick updates, the callback is invoked once
# Data that already exists locally does not trigger the callback

# Define the callback function
## Inside the callback, data is the single record that triggered this invocation
def f(data):
    # print(data)

    code_list = list(data.keys())  # the instrument codes that triggered this callback

    kline_in_callback = xtdata.get_market_data_ex([], code_list, period=period)  # fetch kline data inside the callback
    print(kline_in_callback)

for i in code_list:
    xtdata.subscribe_quote(i, period=period, count=-1, callback=f)  # register the callback when subscribing

# When using callbacks you must also call xtdata.run() to block the program; otherwise the script reaches its last line and simply exits.
xtdata.run()