liao 2025-05-06 15:13:15 +08:00
parent ae9b2f2fcf
commit aadef5f0fa
17 changed files with 6243 additions and 1696 deletions

README.md

@ -211,6 +211,127 @@ pip install -r requirements.txt
Screening logic: select companies in a stable industry environment, with high-quality new business partnerships, sound financials, positive broker coverage, and large upside potential.
## Docker Deployment
The system can be deployed with Docker, either as a single instance or as multiple instances.
### 1. Single-Instance Deployment
```bash
# Build the image
docker-compose build
# Start the service
docker-compose up -d
```
### 2. Multi-Instance Deployment
Three scripts are provided for managing multi-instance deployments. Each instance runs on its own port, starting at 5088 and incrementing by one.
#### Deploying multiple instances with docker run (recommended)
```bash
# Make the script executable
chmod +x deploy-multiple.sh
# Deploy 3 instances (ports 5088, 5089, 5090)
./deploy-multiple.sh 3
```
Each instance gets its own data directory, `./instances/instance-N/`, with its own logs and report files.
#### Deploying multiple instances with docker-compose
```bash
# Make the script executable
chmod +x deploy-compose.sh
# Deploy 3 instances (ports 5088, 5089, 5090)
./deploy-compose.sh 3
```
Each instance uses its own compose project name, `stock-app-N`.
#### Managing instances
The `manage-instances.sh` script makes it easy to manage deployed instances:
```bash
# Make the script executable
chmod +x manage-instances.sh
# Show the status of all instances
./manage-instances.sh status
# List running instances
./manage-instances.sh list
# Start a specific instance
./manage-instances.sh start 2
# Start all instances
./manage-instances.sh start all
# Stop a specific instance
./manage-instances.sh stop 1
# Stop all instances
./manage-instances.sh stop all
# Restart a specific instance
./manage-instances.sh restart 3
# Show the logs of a specific instance
./manage-instances.sh logs 1
# Remove a specific instance
./manage-instances.sh remove 2
# Remove all instances
./manage-instances.sh remove all
```
### 3. Installing Chinese Fonts
The system needs a Chinese font to generate PDF reports. If deployment fails with a `找不到中文字体文件` (Chinese font file not found) error, fix it in one of the following ways:
#### Automatic font installation (recommended)
```bash
# Open a shell inside the container
docker exec -it [container ID or name] bash
# Run the font setup script inside the container
cd /app
python src/fundamentals_llm/setup_fonts.py
```
#### Manual font installation
1. Download a Chinese font file on the host (e.g. simhei.ttf or wqy-microhei.ttc).
2. Create the font directory and copy the font file into it:
```bash
mkdir -p src/fundamentals_llm/fonts
cp <downloaded-font>.ttf src/fundamentals_llm/fonts/simhei.ttf
```
3. Rebuild the image:
```bash
docker-compose build
```
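Before rebuilding, you can optionally confirm on the host that ReportLab (the library the PDF generator uses to register the font) accepts the copied file. This is only a minimal sketch: it assumes ReportLab is installed locally and that the font was copied to `src/fundamentals_llm/fonts/simhei.ttf` as in the steps above.
```python
# Sanity check: can ReportLab register the copied font file?
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont

font_path = "src/fundamentals_llm/fonts/simhei.ttf"  # path used by the manual steps above
try:
    pdfmetrics.registerFont(TTFont("SimHei", font_path))
    print(f"Font registered OK: {font_path}")
except Exception as exc:  # missing file or not a valid TrueType font
    print(f"Font registration failed: {exc}")
```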
### 4. Accessing the Service
Once deployed, the service is available at the following URLs:
- Single instance: `http://<server-ip>:5000`
- Multiple instances: `http://<server-ip>:508X`, where X counts up from 8 (instance 1 uses 5088)
For example:
- Instance 1: `http://<server-ip>:5088`
- Instance 2: `http://<server-ip>:5089`
- Instance 3: `http://<server-ip>:5090`
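Every instance exposes the same HTTP API, including the `/api/health` endpoint defined in the Flask app, so a short script can sweep the deployed ports and confirm each instance is reachable. The snippet below is only a sketch: it assumes three instances starting at port 5088 and treats any HTTP 200 response as healthy.
```python
# Health sweep over deployed instances (assumes 3 instances from port 5088).
import requests

BASE = "http://<server-ip>"  # replace with the actual server address
for i, port in enumerate(range(5088, 5091), start=1):
    url = f"{BASE}:{port}/api/health"
    try:
        resp = requests.get(url, timeout=5)
        print(f"instance {i} ({url}): HTTP {resp.status_code}")
    except requests.RequestException as exc:
        print(f"instance {i} ({url}): unreachable ({exc})")
```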
## Database Query Examples
The following SQL examples show how to filter specific types of companies from the database:
@ -251,3 +372,11 @@ AND stock_code IN (
```
Note: in practice the SQL may need to be adjusted to the actual table structure; combine different filter conditions to match your needs.
To pick up code changes, rebuild the image and restart all instances:
```bash
# Rebuild the image
docker-compose build
# Restart all instances
./manage-instances.sh restart all
```


@ -13,5 +13,5 @@ volcengine-python-sdk[ark]
openai>=1.0
reportlab>=4.3.1
markdown2>=2.5.3
# 1. Press Win+R and run regedit to open the Registry Editor.
# 2. Set LongPathsEnabled to 1 under HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem.
google-genai
redis==5.2.1


@ -1,27 +1,47 @@
import sys
import os
from datetime import datetime, timedelta
import pandas as pd
import uuid
import json
from threading import Thread
from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from flask import Flask, jsonify, request
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
import logging
# 导入企业筛选器
from src.fundamentals_llm.enterprise_screener import EnterpriseScreener
# 导入股票回测器
from src.stock_analysis_v2 import run_backtest, StockBacktester
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(), # 输出到控制台
logging.FileHandler(f'logs/app_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8') # 输出到文件
]
)
# 确保logs和results目录存在
os.makedirs('logs', exist_ok=True)
os.makedirs('results', exist_ok=True)
os.makedirs('results/tasks', exist_ok=True)
os.makedirs('static/results', exist_ok=True)
logger = logging.getLogger(__name__)
logger.info("Flask应用启动")
# 创建 Flask 应用
app = Flask(__name__)
app = Flask(__name__, static_folder='static')
CORS(app) # 启用跨域请求支持
# 创建企业筛选器实例
@ -30,6 +50,329 @@ screener = EnterpriseScreener()
# 获取数据库连接
db = next(get_db())
# 获取项目根目录
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports')
# 确保reports目录存在
os.makedirs(REPORTS_DIR, exist_ok=True)
logger.info(f"报告目录路径: {REPORTS_DIR}")
# 存储回测任务状态的字典
backtest_tasks = {}
def run_backtest_task(task_id, stocks_buy_dates, end_date):
"""
在后台运行回测任务
"""
try:
logger.info(f"开始执行回测任务 {task_id}")
# 更新任务状态为进行中
backtest_tasks[task_id]['status'] = 'running'
# 运行回测
results, stats_list = run_backtest(stocks_buy_dates, end_date)
# 如果回测成功
if results and stats_list:
# 计算总体统计
stats_df = pd.DataFrame(stats_list)
total_profit = stats_df['final_profit'].sum()
avg_win_rate = stats_df['win_rate'].mean()
avg_holding_days = stats_df['avg_holding_days'].mean()
# 找出最佳止盈比例(假设为0.15,实际应从回测结果中分析得出)
best_take_profit_pct = 0.15
# 获取图表URL
chart_urls = {
"all_stocks": f"/static/results/{task_id}/all_stocks_analysis.png",
"profit_matrix": f"/static/results/{task_id}/profit_matrix_analysis.png"
}
# 保存股票详细统计
stock_stats = []
for stat in stats_list:
stock_stats.append({
"symbol": stat['symbol'],
"total_trades": int(stat['total_trades']),
"profitable_trades": int(stat['profitable_trades']),
"loss_trades": int(stat['loss_trades']),
"win_rate": float(stat['win_rate']),
"avg_holding_days": float(stat['avg_holding_days']),
"final_profit": float(stat['final_profit']),
"entry_count": int(stat['entry_count'])
})
# 构建结果数据
result_data = {
"task_id": task_id,
"status": "completed",
"results": {
"total_profit": float(total_profit),
"win_rate": float(avg_win_rate),
"avg_holding_days": float(avg_holding_days),
"best_take_profit_pct": best_take_profit_pct,
"stock_stats": stock_stats,
"chart_urls": chart_urls
}
}
# 保存结果到文件
task_result_path = os.path.join('results', 'tasks', f"{task_id}.json")
with open(task_result_path, 'w', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=2)
# 更新任务状态为已完成
backtest_tasks[task_id].update({
'status': 'completed',
'results': result_data['results']
})
logger.info(f"回测任务 {task_id} 已完成")
else:
# 更新任务状态为失败
backtest_tasks[task_id]['status'] = 'failed'
backtest_tasks[task_id]['error'] = "回测未产生有效结果"
logger.error(f"回测任务 {task_id} 失败:回测未产生有效结果")
except Exception as e:
# 更新任务状态为失败
backtest_tasks[task_id]['status'] = 'failed'
backtest_tasks[task_id]['error'] = str(e)
logger.error(f"回测任务 {task_id} 失败:{str(e)}")
@app.route('/api/backtest/run', methods=['POST'])
def start_backtest():
"""启动回测任务
请求体格式:
{
"stocks_buy_dates": {
"SH600522": ["2022-05-10", "2022-06-10"], // 股票代码: [买入日期列表]
"SZ002340": ["2022-06-15"],
"SH601615": ["2022-07-20", "2022-08-01"]
},
"end_date": "2022-10-20" // 所有股票共同的结束日期
}
返回内容:
{
"task_id": "backtask-khkerhy4u237y489237489truiy8432"
}
"""
try:
# 从请求体获取参数
data = request.get_json()
if not data:
return jsonify({
"status": "error",
"message": "请求格式错误: 需要提供JSON数据"
}), 400
stocks_buy_dates = data.get('stocks_buy_dates')
end_date = data.get('end_date')
if not stocks_buy_dates or not isinstance(stocks_buy_dates, dict):
return jsonify({
"status": "error",
"message": "请求格式错误: 需要提供stocks_buy_dates字典"
}), 400
if not end_date or not isinstance(end_date, str):
return jsonify({
"status": "error",
"message": "请求格式错误: 需要提供有效的end_date"
}), 400
# 验证日期格式
try:
datetime.strptime(end_date, '%Y-%m-%d')
for stock_code, buy_dates in stocks_buy_dates.items():
if not isinstance(buy_dates, list):
return jsonify({
"status": "error",
"message": f"请求格式错误: 股票 {stock_code} 的买入日期必须是列表"
}), 400
for buy_date in buy_dates:
datetime.strptime(buy_date, '%Y-%m-%d')
except ValueError as e:
return jsonify({
"status": "error",
"message": f"日期格式错误: {str(e)}"
}), 400
# 生成任务ID
task_id = f"backtask-{uuid.uuid4().hex[:16]}"
# 创建任务目录
task_dir = os.path.join('static', 'results', task_id)
os.makedirs(task_dir, exist_ok=True)
# 记录任务信息
backtest_tasks[task_id] = {
'status': 'pending',
'created_at': datetime.now().isoformat(),
'stocks_buy_dates': stocks_buy_dates,
'end_date': end_date
}
# 创建线程运行回测
thread = Thread(target=run_backtest_task, args=(task_id, stocks_buy_dates, end_date))
thread.daemon = True
thread.start()
logger.info(f"已创建回测任务: {task_id}")
return jsonify({
"task_id": task_id
})
except Exception as e:
logger.error(f"创建回测任务失败: {str(e)}")
return jsonify({
"status": "error",
"message": f"创建回测任务失败: {str(e)}"
}), 500
@app.route('/api/backtest/status', methods=['GET'])
def check_backtest_status():
"""查询回测任务状态
参数:
- task_id: 回测任务ID
返回内容:
{
"task_id": "backtask-khkerhy4u237y489237489truiy8432",
"status": "running" | "completed" | "failed",
"created_at": "2023-12-01T10:30:45",
"error": "错误信息(如有)"
}
"""
try:
task_id = request.args.get('task_id')
if not task_id:
return jsonify({
"status": "error",
"message": "请求格式错误: 需要提供task_id参数"
}), 400
# 检查任务是否存在
if task_id not in backtest_tasks:
return jsonify({
"status": "error",
"message": f"任务不存在: {task_id}"
}), 404
# 获取任务信息
task_info = backtest_tasks[task_id]
# 构建响应
response = {
"task_id": task_id,
"status": task_info['status'],
"created_at": task_info['created_at']
}
# 如果任务失败,添加错误信息
if task_info['status'] == 'failed' and 'error' in task_info:
response['error'] = task_info['error']
return jsonify(response)
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
return jsonify({
"status": "error",
"message": f"查询任务状态失败: {str(e)}"
}), 500
@app.route('/api/backtest/result', methods=['GET'])
def get_backtest_result():
"""获取回测任务结果
参数:
- task_id: 回测任务ID
返回内容:
{
"task_id": "backtask-2023121-001",
"status": "completed",
"results": {
"total_profit": 123456.78,
"win_rate": 75.5,
"avg_holding_days": 12.3,
"best_take_profit_pct": 0.15,
"stock_stats": [...],
"chart_urls": {
"all_stocks": "/static/results/backtask-2023121-001/all_stocks_analysis.png",
"profit_matrix": "/static/results/backtask-2023121-001/profit_matrix_analysis.png"
}
}
}
"""
try:
task_id = request.args.get('task_id')
if not task_id:
return jsonify({
"status": "error",
"message": "请求格式错误: 需要提供task_id参数"
}), 400
# 检查任务是否存在
if task_id not in backtest_tasks:
# 尝试从文件加载
task_result_path = os.path.join('results', 'tasks', f"{task_id}.json")
if os.path.exists(task_result_path):
with open(task_result_path, 'r', encoding='utf-8') as f:
result_data = json.load(f)
return jsonify(result_data)
else:
return jsonify({
"status": "error",
"message": f"任务不存在: {task_id}"
}), 404
# 获取任务信息
task_info = backtest_tasks[task_id]
# 检查任务是否完成
if task_info['status'] != 'completed':
return jsonify({
"status": "error",
"message": f"任务尚未完成或已失败: {task_info['status']}"
}), 400
# 构建响应
response = {
"task_id": task_id,
"status": "completed",
"results": task_info['results']
}
return jsonify(response)
except Exception as e:
logger.error(f"获取任务结果失败: {str(e)}")
return jsonify({
"status": "error",
"message": f"获取任务结果失败: {str(e)}"
}), 500
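# ---------------------------------------------------------------------------
# Illustrative client sketch (not part of this module): the three endpoints
# above form an asynchronous workflow. POST /api/backtest/run returns a
# task_id, GET /api/backtest/status is polled until the task leaves the
# 'pending'/'running' states, and GET /api/backtest/result fetches the report
# (a failed task returns an error from /result). A minimal client, assuming
# the service listens on localhost:5000 and using the docstring example payload:
#
#     import time, requests
#     base = "http://localhost:5000"
#     payload = {"stocks_buy_dates": {"SH600522": ["2022-05-10"]},
#                "end_date": "2022-10-20"}
#     task_id = requests.post(f"{base}/api/backtest/run", json=payload).json()["task_id"]
#     while requests.get(f"{base}/api/backtest/status",
#                        params={"task_id": task_id}).json()["status"] in ("pending", "running"):
#         time.sleep(5)
#     print(requests.get(f"{base}/api/backtest/result", params={"task_id": task_id}).json())
# ---------------------------------------------------------------------------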
@app.route('/api/reports/<path:filename>')
def serve_report(filename):
"""提供PDF报告的访问"""
try:
logger.info(f"请求文件: {filename}")
logger.info(f"从目录: {REPORTS_DIR} 提供文件")
return send_from_directory(REPORTS_DIR, filename, as_attachment=True)
except Exception as e:
logger.error(f"提供报告文件失败: {str(e)}")
return jsonify({"error": "文件不存在"}), 404
@app.route('/api/health', methods=['GET'])
def health_check():
"""健康检查接口"""
@ -185,10 +528,10 @@ def generate_reports():
# 导入 PDF 生成器模块
try:
from src.fundamentals_llm.pdf_generator import generate_investment_report
from src.fundamentals_llm.pdf_generator import PDFGenerator
except ImportError:
try:
from fundamentals_llm.pdf_generator import generate_investment_report
from fundamentals_llm.pdf_generator import PDFGenerator
except ImportError as e:
logger.error(f"无法导入 PDF 生成器模块: {str(e)}")
return jsonify({
@ -200,8 +543,14 @@ def generate_reports():
generated_reports = []
for stock_code, stock_name in stocks:
try:
# 创建 PDF 生成器实例
generator = PDFGenerator()
# 调用 PDF 生成器
report_path = generate_investment_report(stock_code, stock_name)
report_path = generator.generate_pdf(
title=f"{stock_name}({stock_code}) 基本面分析报告",
content_dict={}, # 这里需要传入实际的内容字典
filename=f"{stock_name}_{stock_code}_analysis.pdf"
)
generated_reports.append({
"code": stock_code,
"name": stock_name,
@ -505,9 +854,10 @@ def analyze_and_recommend():
"message": f"分析和推荐股票失败: {str(e)}"
}), 500
@app.route('/api/comprehensive_analysis', methods=['POST'])
def comprehensive_analysis():
"""综合分析接口 - 组合多种功能和参数
"""综合分析接口 - 使用队列方式处理被锁定的股票
请求体格式:
{
@ -558,36 +908,9 @@ def comprehensive_analysis():
if not isinstance(limit, int) or limit <= 0:
limit = 10
# 导入必要的聊天机器人模块
# 导入必要的模块
try:
# 首先尝试导入聊天机器人模块
try:
from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
logger.info("成功从 src.fundamentals_llm.chat_bot 导入 ChatBot")
except ImportError as e1:
try:
from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
logger.info("成功从 fundamentals_llm.chat_bot 导入 ChatBot")
except ImportError as e2:
logger.error(f"无法导入在线聊天机器人模块: {str(e1)}, {str(e2)}")
return jsonify({
"status": "error",
"message": f"服务器配置错误: 聊天机器人模块不可用,错误详情: {str(e2)}"
}), 500
# 然后尝试导入离线聊天机器人模块
try:
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
logger.info("成功从 src.fundamentals_llm.chat_bot_with_offline 导入 ChatBot")
except ImportError as e1:
try:
from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
logger.info("成功从 fundamentals_llm.chat_bot_with_offline 导入 ChatBot")
except ImportError as e2:
logger.warning(f"无法导入离线聊天机器人模块: {str(e1)}, {str(e2)}")
# 这里可以继续执行,因为某些功能可能不需要离线模型
# 最后导入基本面分析器
# 先导入基本面分析器
try:
from src.fundamentals_llm.fundamental_analysis import FundamentalAnalyzer
logger.info("成功从 src.fundamentals_llm.fundamental_analysis 导入 FundamentalAnalyzer")
@ -601,6 +924,19 @@ def comprehensive_analysis():
"status": "error",
"message": f"服务器配置错误: 基本面分析模块不可用,错误详情: {str(e2)}"
}), 500
# 再导入其他可能需要的模块
try:
from src.fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
except ImportError:
try:
from fundamentals_llm.chat_bot import ChatBot as OnlineChatBot
from fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
except ImportError:
# 这些模块不是必须的,所以继续执行
logger.warning("无法导入聊天机器人模块,但这不会影响基本功能")
except Exception as e:
logger.error(f"导入必要模块时出错: {str(e)}")
return jsonify({
@ -611,64 +947,141 @@ def comprehensive_analysis():
# 创建基本面分析器实例
analyzer = FundamentalAnalyzer()
# 为每个股票生成投资建议
investment_advices = []
for stock_code, stock_name in stocks:
# 准备结果容器
investment_advices = {} # 使用字典,股票代码作为键
processing_queue = list(stocks) # 初始处理队列
max_attempts = 5 # 最大重试次数
total_attempts = 0
# 导入数据库模块
from src.fundamentals_llm.fundamental_analysis_database import get_analysis_result, get_db
# 开始处理队列
while processing_queue and total_attempts < max_attempts:
total_attempts += 1
logger.info(f"开始第 {total_attempts} 轮处理,队列中有 {len(processing_queue)} 只股票")
# 暂存下一轮需要处理的股票
next_round_queue = []
# 处理当前队列中的所有股票
for stock_code, stock_name in processing_queue:
try:
# 生成投资建议
# 检查是否已有分析结果
db = next(get_db())
existing_result = get_analysis_result(db, stock_code, "investment_advice")
# 如果已有近期结果,直接使用
if existing_result and existing_result.update_time > datetime.now() - timedelta(hours=12):
investment_advices[stock_code] = {
"code": stock_code,
"name": stock_name,
"advice": existing_result.ai_response,
"reasoning": existing_result.reasoning_process,
"references": existing_result.references,
"status": "success",
"from_cache": True
}
logger.info(f"使用缓存的 {stock_name}({stock_code}) 分析结果")
continue
# 检查是否被锁定
if analyzer.is_stock_locked(stock_code, "investment_advice"):
# 已被锁定,放到下一轮队列
next_round_queue.append([stock_code, stock_name])
# 记录状态
if stock_code not in investment_advices:
investment_advices[stock_code] = {
"code": stock_code,
"name": stock_name,
"status": "pending",
"message": f"股票 {stock_code} 正在被其他请求分析中,已加入等待队列"
}
logger.info(f"股票 {stock_name}({stock_code}) 已被锁定,放入下一轮队列")
continue
# 尝试锁定并分析
analyzer.lock_stock(stock_code, "investment_advice")
try:
# 执行分析
success, advice, reasoning, references = analyzer.query_analysis(
stock_code, stock_name, "investment_advice"
)
# 记录结果
if success:
investment_advices.append({
investment_advices[stock_code] = {
"code": stock_code,
"name": stock_name,
"advice": advice,
"reasoning": reasoning,
"references": references,
"status": "success"
})
logger.info(f"成功生成 {stock_name}({stock_code}) 的投资建议")
}
logger.info(f"成功分析 {stock_name}({stock_code})")
else:
investment_advices.append({
investment_advices[stock_code] = {
"code": stock_code,
"name": stock_name,
"status": "error",
"error": advice
})
logger.error(f"生成 {stock_name}({stock_code}) 的投资建议失败: {advice}")
"error": advice or "分析失败,无详细信息"
}
logger.error(f"分析 {stock_name}({stock_code}) 失败: {advice}")
finally:
# 确保释放锁
analyzer.unlock_stock(stock_code, "investment_advice")
except Exception as e:
logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}")
investment_advices.append({
# 处理异常
investment_advices[stock_code] = {
"code": stock_code,
"name": stock_name,
"status": "error",
"error": str(e)
}
logger.error(f"处理 {stock_name}({stock_code}) 时出错: {str(e)}")
# 确保释放锁
try:
analyzer.unlock_stock(stock_code, "investment_advice")
except:
pass
# 如果还有下一轮要处理的股票,等待一段时间后继续
if next_round_queue:
logger.info(f"本轮结束,还有 {len(next_round_queue)} 只股票等待下一轮处理")
# 等待30秒再处理下一轮
import time
time.sleep(30)
processing_queue = next_round_queue
else:
# 所有股票都已处理完,退出循环
logger.info("所有股票处理完毕")
processing_queue = []
# 处理仍在队列中的股票(达到最大重试次数但仍未处理的)
for stock_code, stock_name in processing_queue:
if stock_code in investment_advices and investment_advices[stock_code]["status"] == "pending":
investment_advices[stock_code].update({
"status": "timeout",
"message": f"等待超时,股票 {stock_code} 可能正在被长时间分析"
})
logger.warning(f"股票 {stock_name}({stock_code}) 分析超时")
# 将字典转换为列表
investment_advice_list = list(investment_advices.values())
# 生成PDF报告如果需要
pdf_results = []
if generate_pdf:
try:
# 导入PDF生成器模块
# 针对已成功分析的股票生成PDF
for stock_info in investment_advice_list:
if stock_info["status"] == "success":
stock_code = stock_info["code"]
stock_name = stock_info["name"]
try:
from src.fundamentals_llm.pdf_generator import generate_investment_report
except ImportError:
try:
from fundamentals_llm.pdf_generator import generate_investment_report
except ImportError as e:
logger.error(f"无法导入 PDF 生成器模块: {str(e)}")
return jsonify({
"status": "error",
"message": f"服务器配置错误: PDF 生成器模块不可用, 错误详情: {str(e)}"
}), 500
# 生成报告
for stock_code, stock_name in stocks:
try:
# 调用 PDF 生成器
report_path = generate_investment_report(stock_code, stock_name)
report_path = analyzer.generate_pdf_report(stock_code, stock_name)
if report_path:
pdf_results.append({
"code": stock_code,
"name": stock_name,
@ -676,6 +1089,14 @@ def comprehensive_analysis():
"status": "success"
})
logger.info(f"成功生成 {stock_name}({stock_code}) 的投资报告: {report_path}")
else:
pdf_results.append({
"code": stock_code,
"name": stock_name,
"status": "error",
"error": "生成报告失败"
})
logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败")
except Exception as e:
logger.error(f"生成 {stock_name}({stock_code}) 的投资报告失败: {str(e)}")
pdf_results.append({
@ -784,11 +1205,24 @@ def comprehensive_analysis():
logger.error(f"应用企业画像筛选失败: {str(e)}")
filtered_stocks = []
# 统计各种状态的股票数量
success_count = sum(1 for item in investment_advice_list if item["status"] == "success")
pending_count = sum(1 for item in investment_advice_list if item["status"] == "pending")
timeout_count = sum(1 for item in investment_advice_list if item["status"] == "timeout")
error_count = sum(1 for item in investment_advice_list if item["status"] == "error")
# 返回结果
response = {
"status": "success",
"status": "success" if success_count > 0 else "partial_success" if success_count + pending_count > 0 else "failed",
"total_input_stocks": len(stocks),
"investment_advices": investment_advices
"stats": {
"success": success_count,
"pending": pending_count,
"timeout": timeout_count,
"error": error_count
},
"rounds_attempted": total_attempts,
"investment_advices": investment_advice_list
}
if profile_filter:


@ -150,11 +150,15 @@ class ChatBot:
logger.error(f"格式化参考资料时出错: {str(e)}")
return str(ref)
def chat(self, user_input: str) -> Dict[str, Any]:
def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]:
"""与AI进行对话
Args:
user_input: 用户输入的问题
temperature: 控制输出的随机性范围0-2默认1.0
top_p: 控制输出的多样性范围0-1默认0.7
max_tokens: 控制输出的最大长度默认4096
frequency_penalty: 控制重复惩罚范围-2到2默认0.0
Returns:
Dict[str, Any]: 包含以下字段的字典
@ -173,6 +177,10 @@ class ChatBot:
stream = self.client.chat.completions.create(
model=self.model,
messages=self.conversation_history,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
stream=True
)


@ -0,0 +1,420 @@
import logging
import os
import sys
import json
from typing import Dict, List, Optional, Any, Union
from pydantic import BaseModel, Field
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
# 设置日志记录
logger = logging.getLogger(__name__)
# 定义基础数据结构
class TextAnalysisResult(BaseModel):
"""文本分析结果包含分析正文、推理过程和引用URL"""
analysis_text: str = Field(description="详细的分析文本")
reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None)
references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None)
class NumericalAnalysisResult(BaseModel):
"""数值分析结果,包含数值和分析描述"""
value: str = Field(description="评估值")
description: str = Field(description="评估描述")
# 分析类型到模型类的映射
ANALYSIS_TYPE_MODEL_MAP = {
# 文本分析类型
"text_analysis_result": TextAnalysisResult,
# 数值分析类型
"numerical_analysis_result": NumericalAnalysisResult
}
class JsonOutputParser:
"""解析JSON输出的工具类"""
def __init__(self, pydantic_object):
"""初始化解析器
Args:
pydantic_object: 需要解析成的Pydantic模型类
"""
self.pydantic_object = pydantic_object
def get_format_instructions(self) -> str:
"""获取格式化指令"""
schema = self.pydantic_object.schema()
schema_str = json.dumps(schema, ensure_ascii=False, indent=2)
return f"""请将你的回答格式化为符合以下JSON结构的格式:
{schema_str}
你的回答应该只包含一个有效的JSON对象而不包含其他内容请不要包含任何解释或前缀仅输出JSON本身
"""
def parse(self, text: str) -> Any:
"""解析文本为对象"""
try:
# 尝试解析JSON
text = text.strip()
# 如果文本以```json开头且以```结尾,则提取中间部分
if text.startswith("```json") and text.endswith("```"):
text = text[7:-3].strip()
# 如果文本以```开头且以```结尾,则提取中间部分
elif text.startswith("```") and text.endswith("```"):
text = text[3:-3].strip()
json_obj = json.loads(text)
return self.pydantic_object.parse_obj(json_obj)
except Exception as e:
raise ValueError(f"无法解析为JSON或无法匹配模式: {e}")
# 获取项目根目录的绝对路径
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 尝试导入配置
try:
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(ROOT_DIR))
sys.path.append(ROOT_DIR)
# 导入配置
try:
from scripts.config import get_model_config
logger.info("成功从scripts.config导入配置")
except ImportError:
try:
from src.scripts.config import get_model_config
logger.info("成功从src.scripts.config导入配置")
except ImportError:
logger.warning("无法导入配置模块,使用默认配置")
# 使用默认配置的实现
def get_model_config(platform: str, model_type: str):
return {
"api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU",
"model": "gemini-2.0-flash"
}
except Exception as e:
logger.error(f"导入配置时出错: {str(e)},使用默认配置")
# 使用默认配置的实现
def get_model_config(platform: str, model_type: str):
return {
"api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU",
"model": "gemini-2.0-flash"
}
class ChatBot:
def __init__(self, platform: str = "Gemini", model_type: str = "gemini-2.0-flash"):
"""初始化聊天机器人
Args:
platform: 平台名称默认为Gemini
model_type: 要使用的模型类型默认为gemini-2.0-flash
"""
try:
# 从配置获取API配置
config = get_model_config(platform, model_type)
self.api_key = config["api_key"]
self.model = config["model"] # 直接使用配置中的模型名称
logger.info(f"初始化ChatBot使用平台: {platform}, 模型: {self.model}")
google_search_tool = Tool(
google_search=GoogleSearch()
)
self.client = genai.Client(api_key=self.api_key)
# 系统提示语
self.system_instruction = """你是一位经验丰富的专业投资经理,擅长基本面分析和投资决策。你的分析特点如下:
1. 分析风格
- 专业客观理性
- 注重数据支撑
- 关注风险控制
- 重视投资性价比
2. 分析框架
- 公司基本面分析
- 行业竞争格局
- 估值水平评估
- 风险因素识别
- 投资机会判断
3. 输出要求
- 简明扼要重点突出
- 逻辑清晰层次分明
- 数据准确论据充分
- 结论明确建议具体
请用专业投资经理的视角对股票进行深入分析和投资建议如果信息不足请明确指出"""
# 对话历史
self.conversation_history = []
except Exception as e:
logger.error(f"初始化ChatBot时出错: {str(e)}")
raise
def chat(self, user_input: str, analysis_type: Optional[str] = None, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096) -> dict:
"""处理用户输入并返回AI回复
Args:
user_input: 用户输入的问题
analysis_type: 分析类型如company_profile或business_scale等
temperature: 控制输出的随机性范围0-2默认0.7
top_p: 控制输出的多样性范围0-1默认0.7
max_tokens: 控制输出的最大长度默认4096
Returns:
dict: 包含AI回复的字典
"""
try:
from google.genai import types
google_search_tool = Tool(
google_search=GoogleSearch()
)
# 确定分析类型和模型
model_class = None
if analysis_type and analysis_type in ANALYSIS_TYPE_MODEL_MAP:
model_class = ANALYSIS_TYPE_MODEL_MAP[analysis_type]
parser = JsonOutputParser(pydantic_object=model_class)
format_instructions = parser.get_format_instructions()
# 在原有基础上添加格式指令
combined_instruction = f"{self.system_instruction}\n\n{format_instructions}"
else:
combined_instruction = self.system_instruction
parser = None
response = self.client.models.generate_content_stream(
model=self.model,
config=GenerateContentConfig(
system_instruction=combined_instruction,
tools=[google_search_tool],
response_modalities=["TEXT"],
temperature=temperature,
top_p=top_p,
max_output_tokens=max_tokens
),
contents=user_input
)
full_response = ""
for chunk in response:
full_response += chunk.text
print(chunk.text, end="")
# 如果指定了分析类型,尝试解析输出
if parser:
try:
# 检查分析类型是文本分析还是数值分析
is_text_analysis = model_class == TextAnalysisResult
is_numerical_analysis = model_class == NumericalAnalysisResult
# 使用解析方法尝试处理JSON响应
json_result = self.parse_json_response(full_response, model_class)
if json_result:
return json_result
# 如果使用解析方法失败尝试使用parser解析
parsed_result = parser.parse(full_response)
structured_result = parsed_result.dict()
if isinstance(parsed_result, TextAnalysisResult):
# 文本分析结果
return {
"response": structured_result["analysis_text"],
"reasoning_process": structured_result.get("reasoning_process"),
"references": structured_result.get("references")
}
elif isinstance(parsed_result, NumericalAnalysisResult):
# 数值分析结果
return {
"response": structured_result["value"],
"description": structured_result.get("description", "")
}
except Exception as e:
logger.warning(f"解析响应失败: {str(e)},返回原始响应")
# 最后尝试根据分析类型直接构建结果
if is_text_analysis:
return {
"response": full_response,
"reasoning_process": None,
"references": None
}
elif is_numerical_analysis:
try:
# 尝试从文本中提取数字
import re
number_match = re.search(r'\b(\d+)\b', full_response)
value = int(number_match.group(1)) if number_match else 0
return {
"response": value,
"description": full_response
}
except:
return {
"response": 0,
"description": full_response
}
else:
return {
"response": full_response,
"reasoning_process": None,
"references": None
}
else:
# 普通响应
return {
"response": full_response,
"reasoning_process": None,
"references": None
}
except Exception as e:
logger.error(f"聊天失败: {str(e)}")
error_msg = f"抱歉,发生错误: {str(e)}"
print(f"\n{error_msg}")
return {"response": error_msg, "reasoning_process": None, "references": None}
def clear_history(self):
"""清除对话历史"""
self.conversation_history = []
# 重置会话
if hasattr(self, 'model_instance'):
self.chat_session = self.model_instance.start_chat()
print("对话历史已清除")
def run(self):
"""运行聊天机器人"""
print("欢迎使用Gemini AI助手输入 'quit' 退出,输入 'clear' 清除对话历史。")
print("-" * 50)
while True:
try:
# 获取用户输入
user_input = input("\n你: ").strip()
# 检查退出命令
if user_input.lower() == 'quit':
print("感谢使用,再见!")
break
# 检查清除历史命令
if user_input.lower() == 'clear':
self.clear_history()
continue
# 如果输入为空,继续下一轮
if not user_input:
continue
# 获取AI回复
self.chat(user_input)
print("-" * 50)
except KeyboardInterrupt:
print("\n感谢使用,再见!")
break
except Exception as e:
logger.error(f"运行错误: {str(e)}")
print(f"发生错误: {str(e)}")
def parse_json_response(self, full_response: str, model_class) -> dict:
"""解析JSON响应处理不同的JSON结构
Args:
full_response: 完整的响应文本
model_class: 模型类TextAnalysisResult或NumericalAnalysisResult
Returns:
dict: 包含解析结果的字典
"""
# 检查分析类型
is_text_analysis = model_class == TextAnalysisResult
is_numerical_analysis = model_class == NumericalAnalysisResult
try:
# 尝试提取JSON内容
json_text = full_response
# 如果文本以```json开头且以```结尾,则提取中间部分
if "```json" in json_text and "```" in json_text.split("```json", 1)[1]:
json_text = json_text.split("```json", 1)[1].split("```", 1)[0].strip()
# 如果文本以```开头且以```结尾,则提取中间部分
elif "```" in json_text and "```" in json_text.split("```", 1)[1]:
json_text = json_text.split("```", 1)[1].split("```", 1)[0].strip()
# 尝试解析JSON
json_obj = json.loads(json_text)
# 情况1: 直接包含字段的格式 {"analysis_text": "...", "reasoning_process": null, "references": [...]}
if is_text_analysis and "analysis_text" in json_obj:
return {
"response": json_obj["analysis_text"],
"reasoning_process": json_obj.get("reasoning_process"),
"references": json_obj.get("references", [])
}
elif is_numerical_analysis and "value" in json_obj:
value = json_obj["value"]
if isinstance(value, str) and value.isdigit():
value = int(value)
elif isinstance(value, (int, float)):
value = int(value)
else:
value = 0
return {
"response": value,
"description": json_obj.get("description", "")
}
# 情况2: properties中包含字段的格式
if "properties" in json_obj and isinstance(json_obj["properties"], dict):
properties = json_obj["properties"]
if is_text_analysis and "analysis_text" in properties:
return {
"response": properties["analysis_text"],
"reasoning_process": properties.get("reasoning_process"),
"references": properties.get("references", [])
}
elif is_numerical_analysis and "value" in properties:
value = properties["value"]
if isinstance(value, str) and value.isdigit():
value = int(value)
elif isinstance(value, (int, float)):
value = int(value)
else:
value = 0
return {
"response": value,
"description": properties.get("description", "")
}
# 如果无法识别结构使用parser解析
raise ValueError("无法识别的JSON结构尝试使用parser解析")
except Exception as e:
# 记录错误但不抛出,让调用方法继续尝试其他解析方式
logger.warning(f"JSON解析失败: {str(e)}")
return None
if __name__ == "__main__":
# 设置日志级别
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 创建并运行聊天机器人
bot = ChatBot()
bot.run()


@ -3,7 +3,7 @@ from openai import OpenAI
import os
import sys
import time
import random
# 设置日志记录
logger = logging.getLogger(__name__)
@ -102,8 +102,19 @@ class ChatBot:
logger.error(f"初始化ChatBot时出错: {str(e)}")
raise
def chat(self, user_input: str) -> str:
"""处理用户输入并返回AI回复"""
def chat(self, user_input: str, temperature: float = 1.0, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> str:
"""处理用户输入并返回AI回复
Args:
user_input: 用户输入的问题
temperature: 控制输出的随机性范围0-2默认1.0
top_p: 控制输出的多样性范围0-1默认0.7
max_tokens: 控制输出的最大长度默认4096
frequency_penalty: 控制重复惩罚范围-2到2默认0.0
Returns:
str: AI的回答内容
"""
try:
# # 添加用户消息到对话历史-多轮
self.conversation_history.append({
@ -116,7 +127,10 @@ class ChatBot:
stream = self.client.chat.completions.create(
model=self.model,
messages=self.conversation_history,
temperature=0,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
stream=True,
timeout=600
)


@ -130,9 +130,15 @@ class EnterpriseScreener:
{
'dimension': 'investment_advice',
'field': 'investment_advice_type',
'operator': '!=',
'value': '不建议'
}
'operator': 'in',
'value': ["'短期'","'中期'","'长期'"]
},
{
'dimension': 'financial_report',
'field': 'financial_report_level',
'operator': '>=',
'value': -1
},
]
return self._screen_stocks_by_conditions(conditions, limit)


@ -1,7 +1,11 @@
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Callable
import sys
from datetime import datetime, timedelta
import time
import redis
from typing import Dict, List, Optional, Tuple, Callable, Any
# 修改导入路径,使用相对导入
try:
# 尝试相对导入
@ -9,6 +13,7 @@ try:
from .chat_bot_with_offline import ChatBot as OfflineChatBot
from .fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
from .pdf_generator import PDFGenerator
from .text_processor import TextProcessor
except ImportError:
# 如果相对导入失败,尝试绝对导入
try:
@ -16,18 +21,17 @@ except ImportError:
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
from src.fundamentals_llm.fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
from src.fundamentals_llm.pdf_generator import PDFGenerator
from src.fundamentals_llm.text_processor import TextProcessor
except ImportError:
# 最后尝试直接导入适用于当前目录已在PYTHONPATH中的情况
from chat_bot import ChatBot
from chat_bot_with_offline import ChatBot as OfflineChatBot
from fundamental_analysis_database import get_db, save_analysis_result, update_analysis_result, get_analysis_result
from pdf_generator import PDFGenerator
from text_processor import TextProcessor
import json
import re
# 设置日志记录
logger = logging.getLogger(__name__)
# 获取项目根目录的绝对路径
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@ -58,6 +62,34 @@ logging.basicConfig(
datefmt=date_format
)
logger = logging.getLogger(__name__)
logger.info("测试日志输出 - 程序启动")
from typing import Dict, List, Optional, Any, Union
from pydantic import BaseModel, Field
# 定义基础数据结构
class TextAnalysisResult(BaseModel):
"""文本分析结果包含分析正文、推理过程和引用URL"""
analysis_text: str = Field(description="详细的分析文本")
reasoning_process: Optional[str] = Field(description="模型的推理过程", default=None)
references: Optional[List[str]] = Field(description="参考资料和引用URL列表", default=None)
class NumericalAnalysisResult(BaseModel):
"""数值分析结果,包含数值和分析描述"""
value: str = Field(description="评估值")
description: str = Field(description="评估描述")
# 添加Redis客户端
redis_client = redis.Redis(
host='192.168.18.208', # Redis服务器地址根据实际情况调整
port=6379,
password='wlkj2018',
db=14,
socket_timeout=5,
decode_responses=True
)
class FundamentalAnalyzer:
"""基本面分析器"""
@ -66,9 +98,10 @@ class FundamentalAnalyzer:
# 使用联网模型进行基本面分析
self.chat_bot = ChatBot(model_type="online_bot")
# 使用离线模型进行其他分析
self.offline_bot = OfflineChatBot(platform="tl_private", model_type="ds-v1")
self.offline_bot = OfflineChatBot(platform="volc", model_type="offline_model")
# 千问打杂
self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq")
# self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="qwq")
self.offline_bot_tl_qw = OfflineChatBot(platform="tl_qw_private", model_type="GLM")
self.db = next(get_db())
# 定义维度映射
@ -161,6 +194,8 @@ class FundamentalAnalyzer:
- 关键战略决策和转型
请提供专业客观的分析突出关键信息避免冗长描述"""
#开头清上下文缓存
self.chat_bot.clear_history()
# 获取AI分析结果
result = self.chat_bot.chat(prompt)
@ -554,8 +589,8 @@ class FundamentalAnalyzer:
请仅返回一个数值2、1、0、-1不要包含任何解释或说明"""
self.offline_bot_tl_qw.clear_history()
# 使用离线模型进行分析
space_value_str = self.offline_bot_tl_qw.chat(prompt)
space_value_str = self._clean_model_output(space_value_str)
space_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
space_value_str = TextProcessor.clean_thought_process(space_value_str)
# 提取数值
space_value = 0 # 默认值
@ -658,9 +693,9 @@ class FundamentalAnalyzer:
请仅返回一个数值1、0、-1不要包含任何解释或说明"""
self.offline_bot_tl_qw.clear_history()
# 使用离线模型进行分析
events_value_str = self.offline_bot_tl_qw.chat(prompt)
events_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
# 数据清洗
events_value_str = self._clean_model_output(events_value_str)
events_value_str = TextProcessor.clean_thought_process(events_value_str)
# 提取数值
events_value = 0 # 默认值
@ -700,14 +735,14 @@ class FundamentalAnalyzer:
def analyze_stock_discussion(self, stock_code: str, stock_name: str) -> bool:
"""分析股吧讨论内容"""
try:
prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在300字以内,请严格按照以下格式输出:
prompt = f"""请对{stock_name}({stock_code})的股吧讨论内容进行简要分析,要求输出控制在400字以内(主要讨论话题200字重要信息汇总200字),请严格按照以下格式输出:
1. 主要讨论话题150字左右
1. 主要讨论话题
- 近期热点事件
- 投资者关注焦点
- 市场情绪倾向
2. 重要信息汇总150字左右
2. 重要信息汇总
- 公司相关动态
- 行业政策变化
- 市场预期变化
@ -762,10 +797,10 @@ class FundamentalAnalyzer:
请仅返回一个数值1、0、-1不要包含任何解释或说明"""
self.offline_bot_tl_qw.clear_history()
# 使用离线模型进行分析
emotion_value_str = self.offline_bot_tl_qw.chat(prompt)
emotion_value_str = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
# 数据清洗
emotion_value_str = self._clean_model_output(emotion_value_str)
emotion_value_str = TextProcessor.clean_thought_process(emotion_value_str)
# 提取数值
emotion_value = 0 # 默认值
@ -1015,10 +1050,10 @@ class FundamentalAnalyzer:
只需要输出一个数值不要输出任何说明或解释只输出2,1,0,-1,-2"""
response = self.chat_bot.chat(prompt)
response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
# 提取数值
rating_str = self._extract_numeric_value_from_response(response)
rating_str = TextProcessor.extract_numeric_value_from_response(response)
# 尝试将响应转换为整数
try:
@ -1057,10 +1092,10 @@ class FundamentalAnalyzer:
只需要输出一个数值不要输出任何说明或解释只输出1、0、-1不要包含任何解释或说明"""
response = self.chat_bot.chat(prompt)
response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
# 提取数值
odds_str = self._extract_numeric_value_from_response(response)
odds_str = TextProcessor.extract_numeric_value_from_response(response)
# 尝试将响应转换为整数
try:
@ -1200,8 +1235,8 @@ class FundamentalAnalyzer:
只需要输出一个数值不要输出任何说明或解释只输出-1、0或1"""
self.offline_bot_tl_qw.clear_history()
response = self.offline_bot_tl_qw.chat(prompt)
pe_hist_str = self._clean_model_output(response)
response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
pe_hist_str = TextProcessor.clean_thought_process(response)
try:
pe_hist = int(pe_hist_str)
@ -1240,8 +1275,8 @@ class FundamentalAnalyzer:
只需要输出一个数值不要输出任何说明或解释只输出-1、0或1"""
self.offline_bot_tl_qw.clear_history()
response = self.offline_bot_tl_qw.chat(prompt)
pb_hist_str = self._clean_model_output(response)
response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
pb_hist_str = TextProcessor.clean_thought_process(response)
try:
pb_hist = int(pb_hist_str)
@ -1280,8 +1315,8 @@ class FundamentalAnalyzer:
只需要输出一个数值不要输出任何说明或解释只输出-1、0或1"""
self.offline_bot_tl_qw.clear_history()
response = self.offline_bot_tl_qw.chat(prompt)
pe_ind_str = self._clean_model_output(response)
response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
pe_ind_str = TextProcessor.clean_thought_process(response)
try:
pe_ind = int(pe_ind_str)
@ -1320,8 +1355,8 @@ class FundamentalAnalyzer:
只需要输出一个数值不要输出任何说明或解释只输出-1、0或1"""
self.offline_bot_tl_qw.clear_history()
response = self.offline_bot_tl_qw.chat(prompt)
pb_ind_str = self._clean_model_output(response)
response = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
pb_ind_str = TextProcessor.clean_thought_process(response)
try:
pb_ind = int(pb_ind_str)
@ -1373,9 +1408,9 @@ class FundamentalAnalyzer:
{json.dumps(all_results, ensure_ascii=False, indent=2)}"""
self.offline_bot.clear_history()
# 使用离线模型生成建议
result = self.offline_bot.chat(prompt)
result = self.offline_bot.chat(prompt,max_tokens=20000)
# 清理模型输出
result = self._clean_model_output(result)
result = TextProcessor.clean_thought_process(result)
# 保存到数据库
success = save_analysis_result(
self.db,
@ -1441,93 +1476,6 @@ class FundamentalAnalyzer:
logger.error(f"提取投资建议类型失败: {str(e)}")
return None
def _clean_model_output(self, output: str) -> str:
"""清理模型输出,移除推理过程,只保留最终结果
Args:
output: 模型原始输出文本
Returns:
str: 清理后的输出文本
"""
try:
# 找到</think>标签的位置
think_end = output.find('</think>')
if think_end != -1:
# 移除</think>标签及其之前的所有内容
output = output[think_end + len('</think>'):]
# 处理可能存在的空行
lines = output.split('\n')
cleaned_lines = []
for line in lines:
line = line.strip()
if line: # 只保留非空行
cleaned_lines.append(line)
# 重新组合文本
output = '\n'.join(cleaned_lines)
return output.strip()
except Exception as e:
logger.error(f"清理模型输出失败: {str(e)}")
return output.strip()
def _extract_numeric_value_from_response(self, response: str) -> str:
"""从模型响应中提取数值,移除参考资料和推理过程
Args:
response: 模型原始响应文本或响应对象
Returns:
str: 提取的数值字符串
"""
try:
# 处理响应对象包含response字段的字典
if isinstance(response, dict) and "response" in response:
response = response["response"]
# 确保响应是字符串
if not isinstance(response, str):
logger.warning(f"响应不是字符串类型: {type(response)}")
return "0"
# 移除推理过程部分
reasoning_start = response.find("推理过程:")
if reasoning_start != -1:
response = response[:reasoning_start].strip()
# 移除参考资料部分(通常以 [数字] 开头的行)
lines = response.split("\n")
cleaned_lines = []
for line in lines:
# 跳过参考资料行(通常以 [数字] 开头)
if re.match(r'\[\d+\]', line.strip()):
continue
cleaned_lines.append(line)
response = "\n".join(cleaned_lines).strip()
# 提取数值
# 先尝试直接将整个响应转换为数值
if response.strip() in ["-2", "-1", "0", "1", "2"]:
return response.strip()
# 如果整个响应不是数值,尝试匹配第一个数值
match = re.search(r'([-]?[0-9])', response)
if match:
return match.group(1)
# 如果没有找到数值,返回默认值
logger.warning(f"未能从响应中提取数值: {response}")
return "0"
except Exception as e:
logger.error(f"从响应中提取数值失败: {str(e)}")
return "0"
def _try_extract_advice_type(self, advice_text: str, max_attempts: int = 3) -> Optional[str]:
"""尝试多次从投资建议中提取建议类型
@ -1549,7 +1497,7 @@ class FundamentalAnalyzer:
# 使用千问离线模型提取建议类型
result = self.offline_bot_tl_qw.chat(prompt)
result = self.offline_bot_tl_qw.chat(prompt,temperature=0.0)
# 检查是否是错误响应
if isinstance(result, str) and "抱歉,发生错误" in result:
@ -1557,7 +1505,7 @@ class FundamentalAnalyzer:
continue
# 清理模型输出
cleaned_result = self._clean_model_output(result)
cleaned_result = TextProcessor.clean_thought_process(result)
# 检查结果是否为有效类型
if cleaned_result in valid_types:
@ -1644,6 +1592,24 @@ class FundamentalAnalyzer:
Optional[str]: 生成的PDF文件路径如果失败则返回None
"""
try:
# 检查是否已存在PDF文件
reports_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'reports')
os.makedirs(reports_dir, exist_ok=True)
# 构建可能的文件名格式
possible_filenames = [
f"{stock_name}_{stock_code}_analysis.pdf",
f"{stock_name}_{stock_code}.SZ_analysis.pdf",
f"{stock_name}_{stock_code}.SH_analysis.pdf"
]
# 检查是否存在已生成的PDF文件
for filename in possible_filenames:
filepath = os.path.join(reports_dir, filename)
if os.path.exists(filepath):
logger.info(f"找到已存在的PDF报告: {filepath}")
return filepath
# 维度名称映射
dimension_names = {
"company_profile": "公司简介",
@ -1669,27 +1635,156 @@ class FundamentalAnalyzer:
logger.warning(f"未找到 {stock_name}({stock_code}) 的任何分析结果")
return None
# 确保字体目录存在
fonts_dir = os.path.join(os.path.dirname(__file__), "fonts")
os.makedirs(fonts_dir, exist_ok=True)
# 检查是否存在字体文件,如果不存在则创建简单的默认字体标记文件
font_path = os.path.join(fonts_dir, "simhei.ttf")
if not os.path.exists(font_path):
# 尝试从系统字体目录复制
try:
import shutil
# 尝试常见的系统字体位置
system_fonts = [
"C:/Windows/Fonts/simhei.ttf", # Windows
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # Linux
"/usr/share/fonts/wqy-microhei/wqy-microhei.ttc", # 其他Linux
"/System/Library/Fonts/PingFang.ttc" # macOS
]
for system_font in system_fonts:
if os.path.exists(system_font):
shutil.copy2(system_font, font_path)
logger.info(f"已复制字体文件: {system_font} -> {font_path}")
break
except Exception as font_error:
logger.warning(f"复制字体文件失败: {str(font_error)}")
# 创建PDF生成器实例
generator = PDFGenerator()
# 生成PDF报告
try:
# 第一次尝试生成
filepath = generator.generate_pdf(
title=f"{stock_name}({stock_code}) 基本面分析报告",
content_dict=content_dict,
output_dir=reports_dir,
filename=f"{stock_name}_{stock_code}_analysis.pdf"
)
if filepath:
logger.info(f"PDF报告已生成: {filepath}")
else:
logger.error("PDF报告生成失败")
return filepath
except Exception as pdf_error:
logger.error(f"生成PDF报告第一次尝试失败: {str(pdf_error)}")
# 如果是字体问题,可能需要使用备选方案
if "找不到中文字体文件" in str(pdf_error):
# 导出为文本文件作为备选
try:
txt_filename = f"{stock_name}_{stock_code}_analysis.txt"
txt_filepath = os.path.join(reports_dir, txt_filename)
with open(txt_filepath, 'w', encoding='utf-8') as f:
f.write(f"{stock_name}({stock_code}) 基本面分析报告\n")
f.write(f"生成时间:{datetime.now().strftime('%Y年%m月%d%H:%M:%S')}\n\n")
for section_title, content in content_dict.items():
if content:
f.write(f"## {section_title}\n\n")
f.write(f"{content}\n\n")
logger.info(f"由于PDF生成失败已生成文本报告: {txt_filepath}")
return txt_filepath
except Exception as txt_error:
logger.error(f"生成文本报告失败: {str(txt_error)}")
logger.error("PDF报告生成失败")
return None
except Exception as e:
logger.error(f"生成PDF报告失败: {str(e)}")
return None
def is_stock_locked(self, stock_code: str, dimension: str) -> bool:
"""检查股票是否已被锁定
Args:
stock_code: 股票代码
dimension: 分析维度
Returns:
bool: 是否被锁定
"""
try:
lock_key = f"stock_analysis_lock:{stock_code}:{dimension}"
# 检查是否存在锁
existing_lock = redis_client.get(lock_key)
if existing_lock:
lock_time = int(existing_lock)
current_time = int(time.time())
# 锁超过30分钟(1800秒)视为过期
if current_time - lock_time > 1800:
# 锁已过期,可以释放
redis_client.delete(lock_key)
logger.info(f"股票 {stock_code} 维度 {dimension} 的过期锁已释放")
return False
else:
# 锁未过期,被锁定
return True
# 不存在锁
return False
except Exception as e:
logger.error(f"检查股票 {stock_code} 锁状态时出错: {str(e)}")
# 出错时保守处理,返回未锁定
return False
def lock_stock(self, stock_code: str, dimension: str) -> bool:
"""锁定股票
Args:
stock_code: 股票代码
dimension: 分析维度
Returns:
bool: 是否成功锁定
"""
try:
lock_key = f"stock_analysis_lock:{stock_code}:{dimension}"
current_time = int(time.time())
# 设置锁过期时间1小时
redis_client.set(lock_key, current_time, ex=3600)
logger.info(f"股票 {stock_code} 维度 {dimension} 已锁定")
return True
except Exception as e:
logger.error(f"锁定股票 {stock_code} 时出错: {str(e)}")
return False
def unlock_stock(self, stock_code: str, dimension: str) -> bool:
"""解锁股票
Args:
stock_code: 股票代码
dimension: 分析维度
Returns:
bool: 是否成功解锁
"""
try:
lock_key = f"stock_analysis_lock:{stock_code}:{dimension}"
redis_client.delete(lock_key)
logger.info(f"股票 {stock_code} 维度 {dimension} 已解锁")
return True
except Exception as e:
logger.error(f"解锁股票 {stock_code} 时出错: {str(e)}")
return False
def test_single_method(method: Callable, stock_code: str, stock_name: str) -> bool:
"""测试单个分析方法"""
try:
@ -1713,11 +1808,6 @@ def test_single_stock(analyzer: FundamentalAnalyzer, stock_code: str, stock_name
def main():
"""主函数"""
# 设置日志级别
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 测试股票列表
test_stocks = [

File diff suppressed because it is too large.

@ -11,25 +11,25 @@ import markdown2
from bs4 import BeautifulSoup
import os
from datetime import datetime
from fpdf import FPDF
import matplotlib.pyplot as plt
import shutil
import matplotlib
matplotlib.use('Agg')
# 修改导入路径,使用相对导入
try:
# 尝试相对导入
from .chat_bot_with_offline import ChatBot
from .fundamental_analysis_database import get_db, get_analysis_result
from .chat_bot import ChatBot
from .chat_bot_with_offline import ChatBot as OfflineChatBot
except ImportError:
# 如果相对导入失败,尝试绝对导入
try:
from src.fundamentals_llm.chat_bot_with_offline import ChatBot
from src.fundamentals_llm.fundamental_analysis_database import get_db, get_analysis_result
from src.fundamentals_llm.chat_bot import ChatBot
from src.fundamentals_llm.chat_bot_with_offline import ChatBot as OfflineChatBot
except ImportError:
# 最后尝试直接导入
from chat_bot_with_offline import ChatBot
from fundamental_analysis_database import get_db, get_analysis_result
from chat_bot import ChatBot
from chat_bot_with_offline import ChatBot as OfflineChatBot
# 设置日志记录
logger = logging.getLogger(__name__)
@ -39,30 +39,9 @@ class PDFGenerator:
def __init__(self):
"""初始化PDF生成器"""
# 注册中文字体
try:
# 尝试使用系统自带的中文字体
if os.name == 'nt': # Windows
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
else: # Linux/Mac
font_path = "/usr/share/fonts/truetype/droid/DroidSansFallback.ttf"
if os.path.exists(font_path):
pdfmetrics.registerFont(TTFont('SimHei', font_path))
self.font_name = 'SimHei'
else:
# 如果找不到系统字体,尝试使用当前目录下的字体
font_path = os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf")
if os.path.exists(font_path):
pdfmetrics.registerFont(TTFont('SimHei', font_path))
self.font_name = 'SimHei'
else:
raise FileNotFoundError("找不到中文字体文件")
except Exception as e:
logger.error(f"注册中文字体失败: {str(e)}")
raise
self.chat_bot = ChatBot()
# 尝试注册中文字体
self.font_name = self._register_chinese_font()
self.chat_bot = OfflineChatBot(platform="volc", model_type="offline_model") # 使用离线模式
self.styles = getSampleStyleSheet()
# 创建自定义样式
@ -139,6 +118,64 @@ class PDFGenerator:
textColor=colors.HexColor('#333333')
))
def _register_chinese_font(self):
"""查找并注册中文字体
Returns:
str: 注册的字体名称如果失败则使用默认字体
"""
try:
# 可能的字体文件位置列表
font_locations = [
# 当前目录/子目录字体位置
os.path.join(os.path.dirname(__file__), "fonts", "simhei.ttf"),
os.path.join(os.path.dirname(__file__), "fonts", "wqy-microhei.ttc"),
os.path.join(os.path.dirname(__file__), "fonts", "wqy-zenhei.ttc"),
# Docker中可能的位置
"/app/src/fundamentals_llm/fonts/simhei.ttf",
# Windows系统字体位置
"C:/Windows/Fonts/simhei.ttf",
"C:/Windows/Fonts/simfang.ttf",
"C:/Windows/Fonts/simsun.ttc",
# Linux/Mac系统字体位置
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
"/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",
"/usr/share/fonts/wqy-zenhei/wqy-zenhei.ttc",
"/usr/share/fonts/truetype/droid/DroidSansFallback.ttf",
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
"/System/Library/Fonts/PingFang.ttc" # macOS
]
# 尝试各个位置
for font_path in font_locations:
if os.path.exists(font_path):
logger.info(f"使用字体文件: {font_path}")
# 为防止字体文件名称不同,统一拷贝到字体目录并重命名
font_dir = os.path.join(os.path.dirname(__file__), "fonts")
os.makedirs(font_dir, exist_ok=True)
target_path = os.path.join(font_dir, "simhei.ttf")
# 如果字体文件不在目标位置,拷贝过去
if font_path != target_path and not os.path.exists(target_path):
try:
shutil.copy2(font_path, target_path)
logger.info(f"已将字体文件复制到: {target_path}")
except Exception as e:
logger.warning(f"复制字体文件失败: {str(e)}")
# 注册字体
pdfmetrics.registerFont(TTFont('SimHei', target_path))
return 'SimHei'
# 如果所有位置都找不到,使用默认字体
logger.warning("找不到中文字体文件,将使用默认字体")
return 'Helvetica'
except Exception as e:
logger.error(f"注册中文字体失败: {str(e)}")
return 'Helvetica' # 使用默认字体
def _convert_markdown_to_flowables(self, markdown_text: str) -> List:
"""将Markdown文本转换为PDF流对象


@ -0,0 +1,190 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
中文字体安装工具
此脚本用于下载和安装中文字体以支持PDF生成功能
"""
import os
import sys
import logging
import shutil
import tempfile
import platform
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def download_wqy_font():
"""
下载文泉驿微米黑字体
Returns:
str: 下载的字体文件路径如果失败则返回None
"""
try:
import requests
# 创建临时目录
temp_dir = tempfile.mkdtemp()
font_file = os.path.join(temp_dir, "wqy-microhei.ttc")
# 文泉驿微米黑字体的下载链接
url = "https://mirrors.tuna.tsinghua.edu.cn/osdn/wqy/wqy-microhei-0.2.0-beta.tar.gz"
logger.info(f"开始下载文泉驿微米黑字体: {url}")
response = requests.get(url, stream=True)
response.raise_for_status()
# 下载压缩包
tar_file = os.path.join(temp_dir, "wqy-microhei.tar.gz")
with open(tar_file, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
# 解压字体文件
import tarfile
with tarfile.open(tar_file) as tar:
# 提取字体文件
for member in tar.getmembers():
if member.name.endswith(('.ttc', '.ttf')):
member.name = os.path.basename(member.name)
tar.extract(member, temp_dir)
extracted_file = os.path.join(temp_dir, member.name)
if os.path.exists(extracted_file):
return extracted_file
logger.error("在压缩包中未找到字体文件")
return None
except Exception as e:
logger.error(f"下载字体失败: {str(e)}")
return None
def find_system_fonts():
"""
在系统中查找可用的中文字体
Returns:
str: 找到的字体文件路径如果找不到则返回None
"""
# 常见的中文字体位置
font_locations = []
system = platform.system()
if system == "Windows":
font_locations = [
"C:/Windows/Fonts/simhei.ttf",
"C:/Windows/Fonts/simsun.ttc",
"C:/Windows/Fonts/msyh.ttf",
"C:/Windows/Fonts/simfang.ttf",
]
elif system == "Darwin": # macOS
font_locations = [
"/System/Library/Fonts/PingFang.ttc",
"/Library/Fonts/Microsoft/SimHei.ttf",
"/Library/Fonts/Microsoft/SimSun.ttf",
]
else: # Linux
font_locations = [
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
"/usr/share/fonts/wqy-microhei/wqy-microhei.ttc",
"/usr/share/fonts/wenquanyi/wqy-microhei/wqy-microhei.ttc",
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
"/usr/share/fonts/truetype/droid/DroidSansFallback.ttf",
]
# 查找第一个存在的字体文件
for font_path in font_locations:
if os.path.exists(font_path):
logger.info(f"在系统中找到中文字体: {font_path}")
return font_path
logger.warning("在系统中未找到任何中文字体")
return None
def install_font(target_dir=None):
"""
安装中文字体
Args:
target_dir: 目标目录如果不提供则使用默认目录
Returns:
str: 安装后的字体文件路径如果失败则返回None
"""
# 如果没有指定目标目录,使用默认位置
if target_dir is None:
# 获取当前脚本所在目录
current_dir = os.path.dirname(os.path.abspath(__file__))
target_dir = os.path.join(current_dir, "fonts")
# 确保目标目录存在
os.makedirs(target_dir, exist_ok=True)
# 目标字体文件路径
target_font = os.path.join(target_dir, "simhei.ttf")
# 如果目标字体已存在,直接返回
if os.path.exists(target_font):
logger.info(f"中文字体已存在: {target_font}")
return target_font
# 尝试从系统中复制字体
system_font = find_system_fonts()
if system_font:
try:
shutil.copy2(system_font, target_font)
logger.info(f"已从系统复制字体: {system_font} -> {target_font}")
return target_font
except Exception as e:
logger.error(f"复制系统字体失败: {str(e)}")
# 如果系统中找不到字体,尝试下载
logger.info("尝试下载中文字体...")
downloaded_font = download_wqy_font()
if downloaded_font:
try:
shutil.copy2(downloaded_font, target_font)
logger.info(f"已安装下载的字体: {downloaded_font} -> {target_font}")
return target_font
except Exception as e:
logger.error(f"复制下载的字体失败: {str(e)}")
logger.error("安装中文字体失败")
return None
def main():
"""
主函数
"""
print("中文字体安装工具")
print("----------------")
target_dir = None
if len(sys.argv) > 1:
target_dir = sys.argv[1]
print(f"使用指定的目标目录: {target_dir}")
result = install_font(target_dir)
if result:
print(f"\n✓ 成功安装中文字体: {result}")
print("\n现在可以正常生成PDF报告了。")
else:
print("\n✗ 安装中文字体失败")
print("\n请手动将中文字体(.ttf或.ttc文件)复制到以下目录:")
if target_dir:
print(f" {target_dir}")
else:
print(f" {os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fonts')}")
print("\n并将其重命名为'simhei.ttf'")
if __name__ == "__main__":
main()


@ -0,0 +1,172 @@
import re
import logging
from typing import Optional, List, Dict, Any
logger = logging.getLogger(__name__)
class TextProcessor:
"""大模型文本处理工具类"""
@staticmethod
def clean_model_output(output: str) -> str:
"""清理模型输出文本
Args:
output: 模型输出的原始文本
Returns:
str: 清理后的文本
"""
try:
# 移除引用标记
output = re.sub(r'\[[^\]]*\]', '', output)
# 移除多余的空白字符
output = re.sub(r'\s+', ' ', output)
# 移除首尾的空白字符
output = output.strip()
# 移除可能存在的HTML标签
output = re.sub(r'<[^>]+>', '', output)
# 移除可能存在的Markdown格式
output = re.sub(r'[*_`~]', '', output)
return output
except Exception as e:
logger.error(f"清理模型输出失败: {str(e)}")
return output
@staticmethod
def clean_thought_process(output: str) -> str:
"""清理模型输出,移除推理过程,只保留最终结果
Args:
output: 模型原始输出文本
Returns:
str: 清理后的输出文本
"""
try:
# 找到</think>标签的位置
think_end = output.find('</think>')
if think_end != -1:
# 移除</think>标签及其之前的所有内容
output = output[think_end + len('</think>'):]
# 处理可能存在的空行
lines = output.split('\n')
cleaned_lines = []
for line in lines:
line = line.strip()
if line: # 只保留非空行
cleaned_lines.append(line)
# 重新组合文本
output = '\n'.join(cleaned_lines)
return output.strip()
except Exception as e:
logger.error(f"清理模型输出失败: {str(e)}")
return output.strip()
@staticmethod
def extract_numeric_value_from_response(response: str) -> str:
"""从模型响应中提取数值
Args:
response: 模型响应的文本
Returns:
str: 提取出的数值
"""
try:
# 清理响应文本
cleaned_response = TextProcessor.clean_model_output(response)
# 尝试提取数值
numeric_patterns = [
r'[-+]?\d*\.\d+', # 浮点数
r'[-+]?\d+', # 整数
r'[-+]?\d+\.\d*' # 可能带小数点的数
]
for pattern in numeric_patterns:
matches = re.findall(pattern, cleaned_response)
if matches:
# 返回第一个匹配的数值
return matches[0]
# 如果没有找到数值,返回原始响应
return cleaned_response
except Exception as e:
logger.error(f"提取数值失败: {str(e)}")
return response
@staticmethod
def extract_list_from_response(response: str) -> List[str]:
"""从模型响应中提取列表项
Args:
response: 模型响应的文本
Returns:
List[str]: 提取出的列表项
"""
try:
# 清理响应文本
cleaned_response = TextProcessor.clean_model_output(response)
# 尝试匹配列表项
# 匹配数字编号的列表项
numbered_items = re.findall(r'\d+\.\s*([^\n]+)', cleaned_response)
if numbered_items:
return [item.strip() for item in numbered_items]
# 匹配带符号的列表项
bullet_items = re.findall(r'[-*•]\s*([^\n]+)', cleaned_response)
if bullet_items:
return [item.strip() for item in bullet_items]
# 如果没有找到列表项,返回空列表
return []
except Exception as e:
logger.error(f"提取列表项失败: {str(e)}")
return []
@staticmethod
def extract_key_value_pairs(response: str) -> Dict[str, str]:
"""从模型响应中提取键值对
Args:
response: 模型响应的文本
Returns:
Dict[str, str]: 提取出的键值对
"""
try:
# 清理响应文本
cleaned_response = TextProcessor.clean_model_output(response)
# 尝试匹配键值对
# 匹配冒号分隔的键值对
pairs = re.findall(r'([^:]+):\s*([^\n]+)', cleaned_response)
if pairs:
return {key.strip(): value.strip() for key, value in pairs}
# 匹配等号分隔的键值对
pairs = re.findall(r'([^=]+)=\s*([^\n]+)', cleaned_response)
if pairs:
return {key.strip(): value.strip() for key, value in pairs}
# 如果没有找到键值对,返回空字典
return {}
except Exception as e:
logger.error(f"提取键值对失败: {str(e)}")
return {}


@ -66,12 +66,22 @@ MODEL_CONFIGS = {
"doubao": "doubao-1-5-pro-32k-250115"
}
},
# 谷歌Gemini
"Gemini": {
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
"api_key": "AIzaSyAVE8yTaPtN-TxCCHTc9Jb-aCV-Xo1EFuU",
"models": {
"offline_model": "gemini-2.0-flash"
}
},
# 天链苹果
"tl_private": {
"base_url": "http://192.168.32.118:1234/v1/",
"base_url": "http://192.168.16.174:1234/v1/",
"api_key": "none",
"models": {
"ds-v1": "mlx-community/DeepSeek-R1-4bit"
"glm-z1": "glm-z1-rumination-32b-0414",
"glm-4": "glm-4-32b-0414-abliterated",
"ds_v1": "mlx-community/DeepSeek-R1-4bit",
}
},
# 天链-千问
@ -79,7 +89,8 @@ MODEL_CONFIGS = {
"base_url": "http://192.168.16.178:11434/v1",
"api_key": "sk-WaVRJKkyhrFlH4ZV35B9Aa61759b400c9cA002D00f3f1019",
"models": {
"qwq": "qwq:32b"
"qwq": "qwq:32b",
"GLM": "hf-mirror.com/Cobra4687/GLM-4-32B-0414-abliterated-Q4_K_M-GGUF:Q4_K_M"
}
},
# Deepseek配置

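Every entry in MODEL_CONFIGS follows the same OpenAI-compatible shape (`base_url`, `api_key` and a `models` map), so a provider can be swapped by changing one key. Below is a minimal sketch of driving one of these endpoints through the `openai` client already listed in requirements.txt; the import path for MODEL_CONFIGS is hypothetical:

```python
# Minimal sketch: build an OpenAI-compatible client from a MODEL_CONFIGS entry.
# Assumption: MODEL_CONFIGS is importable from the project's config module (path is hypothetical).
from openai import OpenAI
from src.fundamentals_llm.config import MODEL_CONFIGS  # hypothetical import path

provider = MODEL_CONFIGS["Gemini"]
client = OpenAI(base_url=provider["base_url"], api_key=provider["api_key"])

resp = client.chat.completions.create(
    model=provider["models"]["offline_model"],  # e.g. "gemini-2.0-flash"
    messages=[{"role": "user", "content": "用一句话概括这家公司的基本面。"}],
)
print(resp.choices[0].message.content)
```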
View File

@ -0,0 +1,986 @@
import pandas as pd
import numpy as np
from sqlalchemy import create_engine
from datetime import datetime, timedelta
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
# v2版本只做表格
# 添加中文字体支持
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
class StockAnalyzer:
def __init__(self, db_connection_string):
"""
Initialize the stock analyzer
:param db_connection_string: Database connection string
"""
self.engine = create_engine(db_connection_string)
self.initial_capital = 300000 # 30万本金
self.total_position = 1000000 # 100万建仓规模
self.borrowed_capital = 700000 # 70万借入资金
self.annual_interest_rate = 0.05 # 年化5%利息
self.daily_interest_rate = self.annual_interest_rate / 365
self.commission_rate = 0.05 # 5%手续费
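# Note: as used in calculate_trade_signals below, this rate is applied only to the gross
# profit of winning trades; interest accrues daily on the 700k of borrowed capital for
# the whole holding period of each position.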
self.min_holding_days = 7 # 最少持仓天数
def get_stock_list(self):
"""获取所有股票列表"""
# query = "SELECT DISTINCT gp_code as symbol FROM gp_code_all where mark1 = '1'"
# return pd.read_sql(query, self.engine)['symbol'].tolist()
return ['SH600522',
'SZ002340',
'SZ000733',
'SH601615',
'SH600157',
'SH688005',
'SH600903',
'SH600956',
'SH601187',
'SH603983'] # 你可以在这里添加更多股票代码
def get_stock_data(self, symbol, start_date, end_date, include_history=False):
"""获取指定股票在时间范围内的数据"""
try:
if include_history:
# 如果需要历史数据获取start_date之前240天的数据
query = f"""
WITH history_data AS (
SELECT
symbol,
timestamp,
CAST(close as DECIMAL(10,2)) as close
FROM gp_day_data
WHERE symbol = '{symbol}'
AND timestamp < '{start_date}'
ORDER BY timestamp DESC
LIMIT 240
)
SELECT
symbol,
timestamp,
CAST(open as DECIMAL(10,2)) as open,
CAST(high as DECIMAL(10,2)) as high,
CAST(low as DECIMAL(10,2)) as low,
CAST(close as DECIMAL(10,2)) as close,
CAST(volume as DECIMAL(20,0)) as volume
FROM gp_day_data
WHERE symbol = '{symbol}'
AND timestamp BETWEEN '{start_date}' AND '{end_date}'
UNION ALL
SELECT
symbol,
timestamp,
NULL as open,
NULL as high,
NULL as low,
close,
NULL as volume
FROM history_data
ORDER BY timestamp
"""
else:
query = f"""
SELECT
symbol,
timestamp,
CAST(open as DECIMAL(10,2)) as open,
CAST(high as DECIMAL(10,2)) as high,
CAST(low as DECIMAL(10,2)) as low,
CAST(close as DECIMAL(10,2)) as close,
CAST(volume as DECIMAL(20,0)) as volume
FROM gp_day_data
WHERE symbol = '{symbol}'
AND timestamp BETWEEN '{start_date}' AND '{end_date}'
ORDER BY timestamp
"""
# 使用chunksize分批读取数据
chunks = []
for chunk in pd.read_sql(query, self.engine, chunksize=1000):
chunks.append(chunk)
df = pd.concat(chunks, ignore_index=True)
# 确保数值列的类型正确
numeric_columns = ['open', 'high', 'low', 'close', 'volume']
for col in numeric_columns:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors='coerce')
# 删除任何包含 NaN 的行,但只考虑分析期间需要的列
if include_history:
historical_mask = (df['timestamp'] < start_date) & df['close'].notna()
analysis_mask = (df['timestamp'] >= start_date) & df[['open', 'high', 'low', 'close', 'volume']].notna().all(axis=1)
df = df[historical_mask | analysis_mask]
else:
df = df.dropna()
return df
except Exception as e:
print(f"获取股票 {symbol} 数据时出错: {str(e)}")
return pd.DataFrame()
def calculate_trade_signals(self, df, take_profit_pct, stop_loss_pct, pullback_entry_pct, start_date):
"""计算交易信号和收益"""
"""计算交易信号和收益"""
# 检查历史数据是否足够
historical_data = df[df['timestamp'] < start_date]
if len(historical_data) < 240:
# print(f"历史数据不足240天跳过分析")
return pd.DataFrame()
# 计算240日均线基准价格使用历史数据
ma240_base_price = historical_data['close'].tail(240).mean()
# print(f"240日均线基准价格: {ma240_base_price:.2f}")
# 只使用分析期间的数据进行交易信号计算
analysis_df = df[df['timestamp'] >= start_date].copy()
if len(analysis_df) < 10: # 确保分析期间有足够的数据
# print(f"分析期间数据不足10天跳过分析")
return pd.DataFrame()
results = []
positions = [] # 记录所有持仓,每个元素是一个字典,包含入场价格和日期
max_positions = 5 # 最大建仓次数
cumulative_profit = 0 # 总累计收益
trades_count = 0
max_holding_period = timedelta(days=180) # 最长持仓6个月
max_loss_amount = -300000 # 最大亏损金额为30万
# 添加统计变量
total_holding_days = 0 # 总持仓天数
max_capital_usage = 0 # 最大资金占用
total_trades = 0 # 总交易次数
profitable_trades = 0 # 盈利交易次数
loss_trades = 0 # 亏损交易次数
# print(f"\n{'='*50}")
# print(f"开始分析股票: {df.iloc[0]['symbol']}")
# print(f"{'='*50}")
# 记录前一天的收盘价
prev_close = None
for index, row in analysis_df.iterrows():
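# Use the midpoint of the day's high and low as an approximation of the fill price
# for both entries and exits on this bar.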
current_price = (float(row['high']) + float(row['low'])) / 2
day_result = {
'timestamp': row['timestamp'],
'profit': cumulative_profit, # 记录当前的累计收益
'position': False,
'action': 'hold'
}
# 更新所有持仓的收益
positions_to_remove = []
for pos in positions:
pos_days_held = (row['timestamp'] - pos['entry_date']).days
pos_size = self.total_position / pos['entry_price']
# 检查是否需要平仓
should_close = False
close_type = None
close_price = current_price
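# Exit conditions (take-profit, max loss, max holding period) are only evaluated
# once the position has been held for at least min_holding_days.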
if pos_days_held >= self.min_holding_days:
# 检查止盈
if float(row['high']) >= pos['entry_price'] * (1 + take_profit_pct):
should_close = True
close_type = 'sell_profit'
close_price = pos['entry_price'] * (1 + take_profit_pct)
# 检查最大亏损
elif (close_price - pos['entry_price']) * pos_size <= max_loss_amount:
should_close = True
close_type = 'sell_loss'
# 检查持仓时间
elif (row['timestamp'] - pos['entry_date']) >= max_holding_period:
should_close = True
close_type = 'time_limit'
if should_close:
profit = (close_price - pos['entry_price']) * pos_size
commission = profit * self.commission_rate if profit > 0 else 0
interest_cost = self.borrowed_capital * self.daily_interest_rate * pos_days_held
net_profit = profit - commission - interest_cost
# 更新统计数据
total_holding_days += pos_days_held
total_trades += 1
if net_profit > 0:
profitable_trades += 1
else:
loss_trades += 1
cumulative_profit += net_profit
positions_to_remove.append(pos)
day_result['action'] = close_type
day_result['profit'] = cumulative_profit
# print(f"\n>>> {'止盈平仓' if close_type == 'sell_profit' else '止损平仓' if close_type == 'sell_loss' else '到期平仓'}")
# print(f"持仓天数: {pos_days_held}")
# print(f"持仓数量: {pos_size:.0f}股")
# print(f"平仓价格: {close_price:.2f}")
# print(f"{'盈利' if profit > 0 else '亏损'}: {profit:,.2f}")
# print(f"手续费: {commission:,.2f}")
# print(f"利息成本: {interest_cost:,.2f}")
# print(f"净{'盈利' if net_profit > 0 else '亏损'}: {net_profit:,.2f}")
# 移除已平仓的持仓
for pos in positions_to_remove:
positions.remove(pos)
# 检查是否可以建仓
if len(positions) < max_positions and prev_close is not None:
# 计算当日跌幅
daily_drop = (prev_close - float(row['low'])) / prev_close
# 检查价格是否在均线区间内
price_in_range = (
current_price >= ma240_base_price * 0.7 and
current_price <= ma240_base_price * 1.3
)
# 如果跌幅超过5%且价格在均线区间内
if daily_drop >= 0.05 and price_in_range:
positions.append({
'entry_price': current_price,
'entry_date': row['timestamp']
})
trades_count += 1
# print(f"\n>>> 建仓信号 #{trades_count}")
# print(f"日期: {row['timestamp'].strftime('%Y-%m-%d')}")
# print(f"建仓价格: {current_price:.2f}")
# print(f"建仓数量: {self.total_position/current_price:.0f}股")
# print(f"当日跌幅: {daily_drop*100:.2f}%")
# print(f"距离均线: {((current_price/ma240_base_price)-1)*100:.2f}%")
# 更新前一天收盘价
prev_close = float(row['close'])
# 更新日结果
day_result['position'] = len(positions) > 0
results.append(day_result)
# 计算当前资金占用
current_capital_usage = sum(self.total_position for pos in positions)
max_capital_usage = max(max_capital_usage, current_capital_usage)
# 在最后一天强制平仓所有持仓
if positions:
final_price = (float(analysis_df.iloc[-1]['high']) + float(analysis_df.iloc[-1]['low'])) / 2
final_total_profit = 0
for pos in positions:
position_size = self.total_position / pos['entry_price']
days_held = (analysis_df.iloc[-1]['timestamp'] - pos['entry_date']).days
final_profit = (final_price - pos['entry_price']) * position_size
commission = final_profit * self.commission_rate if final_profit > 0 else 0
interest_cost = self.borrowed_capital * self.daily_interest_rate * days_held
net_profit = final_profit - commission - interest_cost
# 更新统计数据
total_holding_days += days_held
total_trades += 1
if net_profit > 0:
profitable_trades += 1
else:
loss_trades += 1
final_total_profit += net_profit
# print(f"\n>>> 到期强制平仓")
# print(f"持仓天数: {days_held}")
# print(f"持仓数量: {position_size:.0f}股")
# print(f"平仓价格: {final_price:.2f}")
# print(f"毛利润: {final_profit:,.2f}")
# print(f"手续费: {commission:,.2f}")
# print(f"利息成本: {interest_cost:,.2f}")
# print(f"净利润: {net_profit:,.2f}")
results[-1]['action'] = 'final_sell'
cumulative_profit += final_total_profit # 更新最终的累计收益
results[-1]['profit'] = cumulative_profit # 更新最后一天的累计收益
# 计算统计数据
avg_holding_days = total_holding_days / total_trades if total_trades > 0 else 0
win_rate = profitable_trades / total_trades * 100 if total_trades > 0 else 0
# print(f"\n{'='*50}")
# print(f"交易统计")
# print(f"总交易次数: {trades_count}")
# print(f"累计收益: {cumulative_profit:,.2f}")
# print(f"最大资金占用: {max_capital_usage:,.2f}")
# print(f"平均持仓天数: {avg_holding_days:.1f}")
# print(f"胜率: {win_rate:.1f}%")
# print(f"盈利交易: {profitable_trades}次")
# print(f"亏损交易: {loss_trades}次")
# print(f"{'='*50}\n")
# 将统计数据添加到结果中
if len(results) > 0:
results[-1]['stats'] = {
'total_trades': total_trades,
'profitable_trades': profitable_trades,
'loss_trades': loss_trades,
'win_rate': win_rate,
'avg_holding_days': avg_holding_days,
'max_capital_usage': max_capital_usage,
'final_profit': cumulative_profit
}
return pd.DataFrame(results)
def analyze_stock(self, symbol, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
"""分析单个股票"""
df = self.get_stock_data(symbol, start_date, end_date, include_history=True)
if len(df) < 240: # 确保总数据量足够
print(f"股票 {symbol} 总数据不足240天跳过分析")
return None
return self.calculate_trade_signals(
df,
take_profit_pct=take_profit_pct,
stop_loss_pct=stop_loss_pct,
pullback_entry_pct=pullback_entry_pct,
start_date=start_date
)
def analyze_all_stocks(self, start_date, end_date, take_profit_pct, stop_loss_pct, pullback_entry_pct):
"""分析所有股票"""
stocks = self.get_stock_list()
all_results = {}
analysis_summary = {
'processed': [],
'skipped': [],
'error': []
}
print(f"\n开始分析,共 {len(stocks)} 支股票")
for symbol in tqdm(stocks, desc="Analyzing stocks"):
try:
# 获取股票数据
df = self.get_stock_data(symbol, start_date, end_date, include_history=True)
if df.empty:
print(f"\n股票 {symbol} 数据为空,跳过分析")
analysis_summary['skipped'].append({
'symbol': symbol,
'reason': 'empty_data',
'take_profit_pct': take_profit_pct
})
continue
# 检查历史数据
historical_data = df[df['timestamp'] < start_date]
if len(historical_data) < 240:
# print(f"\n股票 {symbol} 历史数据不足240天({len(historical_data)}天),跳过分析")
analysis_summary['skipped'].append({
'symbol': symbol,
'reason': f'insufficient_history_{len(historical_data)}days',
'take_profit_pct': take_profit_pct
})
continue
# 检查分析期数据
analysis_data = df[df['timestamp'] >= start_date]
if len(analysis_data) < 10:
print(f"\n股票 {symbol} 分析期数据不足10天({len(analysis_data)}天),跳过分析")
analysis_summary['skipped'].append({
'symbol': symbol,
'reason': f'insufficient_analysis_data_{len(analysis_data)}days',
'take_profit_pct': take_profit_pct
})
continue
# 计算交易信号
result = self.calculate_trade_signals(
df,
take_profit_pct,
stop_loss_pct,
pullback_entry_pct,
start_date
)
if result is not None and not result.empty:
# 检查是否有交易记录
has_trades = any(action in ['sell_profit', 'sell_loss', 'time_limit', 'final_sell']
for action in result['action'] if pd.notna(action))
if has_trades:
all_results[symbol] = result
analysis_summary['processed'].append({
'symbol': symbol,
'take_profit_pct': take_profit_pct,
'trades': len([x for x in result['action'] if pd.notna(x) and x != 'hold']),
'profit': result['profit'].iloc[-1]
})
else:
# print(f"\n股票 {symbol} 没有产生有效的交易信号")
analysis_summary['skipped'].append({
'symbol': symbol,
'reason': 'no_valid_signals',
'take_profit_pct': take_profit_pct
})
except Exception as e:
print(f"\n处理股票 {symbol} 时出错: {str(e)}")
analysis_summary['error'].append({
'symbol': symbol,
'error': str(e),
'take_profit_pct': take_profit_pct
})
continue
# 打印分析总结
print(f"\n分析完成:")
print(f"成功分析: {len(analysis_summary['processed'])} 支股票")
print(f"跳过分析: {len(analysis_summary['skipped'])} 支股票")
print(f"出错股票: {len(analysis_summary['error'])} 支股票")
# 保存分析总结
if not os.path.exists('results/analysis_summary'):
os.makedirs('results/analysis_summary')
# 保存处理成功的股票信息
if analysis_summary['processed']:
pd.DataFrame(analysis_summary['processed']).to_csv(
f'results/analysis_summary/processed_stocks_{take_profit_pct*100:.1f}.csv',
index=False
)
# 保存跳过的股票信息
if analysis_summary['skipped']:
pd.DataFrame(analysis_summary['skipped']).to_csv(
f'results/analysis_summary/skipped_stocks_{take_profit_pct*100:.1f}.csv',
index=False
)
# 保存出错的股票信息
if analysis_summary['error']:
pd.DataFrame(analysis_summary['error']).to_csv(
f'results/analysis_summary/error_stocks_{take_profit_pct*100:.1f}.csv',
index=False
)
return all_results
def plot_results(self, results, symbol):
"""绘制单个股票的收益走势图"""
plt.figure(figsize=(12, 6))
plt.plot(results['timestamp'], results['profit'], label='Cumulative Profit')  # 'profit' 已是累计收益,再做 cumsum 会重复累加
plt.title(f'Stock {symbol} Trading Results')
plt.xlabel('Date')
plt.ylabel('Profit (CNY)')
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
return plt
def plot_all_stocks(self, all_results):
"""绘制所有股票的收益走势图"""
plt.figure(figsize=(15, 8))
# 记录所有股票的累计收益
total_profits = {}
max_profit = float('-inf')
min_profit = float('inf')
# 绘制每支股票的曲线
for symbol, results in all_results.items():
# 直接使用已经计算好的累计收益
profits = results['profit']
plt.plot(results['timestamp'], profits, label=symbol, alpha=0.5, linewidth=1)
# 更新最大最小收益
max_profit = max(max_profit, profits.max())
min_profit = min(min_profit, profits.min())
# 记录总收益(使用最后一个累计收益值)
total_profits[symbol] = profits.iloc[-1]
# 计算所有股票的中位数收益曲线
# 首先找到共同的日期范围
all_dates = set()
for results in all_results.values():
all_dates.update(results['timestamp'])
all_dates = sorted(list(all_dates))
# 创建一个包含所有日期的DataFrame
profits_df = pd.DataFrame(index=all_dates)
# 对每支股票,填充所有日期的收益
for symbol, results in all_results.items():
daily_profits = pd.Series(index=all_dates, data=np.nan)
# 直接使用已经计算好的累计收益
for date, profit in zip(results['timestamp'], results['profit']):
daily_profits[date] = profit
# 对于没有交易的日期,使用最后一个累计收益填充
daily_profits.fillna(method='ffill', inplace=True)
daily_profits.fillna(0, inplace=True) # 对开始之前的日期填充0
profits_df[symbol] = daily_profits
# 计算并绘制中位数收益曲线
median_line = profits_df.median(axis=1)
plt.plot(all_dates, median_line, 'r-', label='Median', linewidth=2)
plt.title('All Stocks Trading Results')
plt.xlabel('Date')
plt.ylabel('Profit (CNY)')
plt.grid(True)
plt.xticks(rotation=45)
# 添加利润区间标注
plt.text(0.02, 0.98,
f'Profit Range:\nMax: {max_profit:,.0f}\nMin: {min_profit:,.0f}\nMedian: {median_line.iloc[-1]:,.0f}',
transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8))
# 计算所有股票的统计数据
total_stats = {
'total_trades': 0,
'profitable_trades': 0,
'loss_trades': 0,
'total_holding_days': 0,
'max_capital_usage': 0,
'total_profit': 0,
'profitable_stocks': 0,
'loss_stocks': 0,
'total_capital_usage': 0, # 添加总资金占用
'total_win_rate': 0, # 添加总胜率
}
# 获取前5名和后5名股票
top_5 = sorted(total_profits.items(), key=lambda x: x[1], reverse=True)[:5]
bottom_5 = sorted(total_profits.items(), key=lambda x: x[1])[:5]
# 添加排名信息
rank_text = "Top 5 Stocks:\n"
for symbol, profit in top_5:
rank_text += f"{symbol}: {profit:,.0f}\n"
rank_text += "\nBottom 5 Stocks:\n"
for symbol, profit in bottom_5:
rank_text += f"{symbol}: {profit:,.0f}\n"
plt.text(1.02, 0.98, rank_text, transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8), verticalalignment='top')
# 记录每支股票的统计数据用于计算平均值
stock_stats = []
for symbol, results in all_results.items():
if 'stats' in results.iloc[-1]:
stats = results.iloc[-1]['stats']
stock_stats.append(stats) # 记录每支股票的统计数据
total_stats['total_trades'] += stats['total_trades']
total_stats['profitable_trades'] += stats['profitable_trades']
total_stats['loss_trades'] += stats['loss_trades']
total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades']
total_stats['max_capital_usage'] = max(total_stats['max_capital_usage'], stats['max_capital_usage'])
total_stats['total_capital_usage'] += stats['max_capital_usage']
total_stats['total_profit'] += stats['final_profit']
total_stats['total_win_rate'] += stats['win_rate']
if stats['final_profit'] > 0:
total_stats['profitable_stocks'] += 1
else:
total_stats['loss_stocks'] += 1
# 计算总体统计
total_stocks = total_stats['profitable_stocks'] + total_stats['loss_stocks']
avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0
win_rate = total_stats['profitable_trades'] / total_stats['total_trades'] * 100 if total_stats['total_trades'] > 0 else 0
stock_win_rate = total_stats['profitable_stocks'] / total_stocks * 100 if total_stocks > 0 else 0
# 计算平均值统计
avg_trades_per_stock = total_stats['total_trades'] / total_stocks if total_stocks > 0 else 0
avg_profit_per_stock = total_stats['total_profit'] / total_stocks if total_stocks > 0 else 0
avg_capital_usage = total_stats['total_capital_usage'] / total_stocks if total_stocks > 0 else 0
avg_win_rate = total_stats['total_win_rate'] / total_stocks if total_stocks > 0 else 0
# 添加总体统计信息
stats_text = (
f"总体统计:\n"
f"总交易次数: {total_stats['total_trades']}\n"
# f"最大资金占用: {total_stats['max_capital_usage']:,.0f}\n"
f"平均持仓天数: {avg_holding_days:.1f}\n"
f"交易胜率: {win_rate:.1f}%\n"
f"股票胜率: {stock_win_rate:.1f}%\n"
f"盈利股票数: {total_stats['profitable_stocks']}\n"
f"亏损股票数: {total_stats['loss_stocks']}\n"
f"\n平均值统计:\n"
f"平均交易次数: {avg_trades_per_stock:.1f}\n"
f"平均收益: {avg_profit_per_stock:,.0f}\n"
f"平均资金占用: {avg_capital_usage:,.0f}\n"
f"平均持仓天数: {avg_holding_days:.1f}\n"
f"平均胜率: {avg_win_rate:.1f}%"
)
plt.text(0.02, 0.6, stats_text,
transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
verticalalignment='top',
fontsize=10)
plt.tight_layout()
return plt, total_profits
def plot_profit_distribution(self, total_profits):
"""绘制所有股票的收益分布小提琴图"""
plt.figure(figsize=(12, 8))
# 将收益数据转换为数组,确保使用最终累计收益
profits = np.array([profit for profit in total_profits.values()])
# 计算统计数据
mean_profit = np.mean(profits)
median_profit = np.median(profits)
max_profit = np.max(profits)
min_profit = np.min(profits)
std_profit = np.std(profits)
# 绘制小提琴图
violin_parts = plt.violinplot(profits, positions=[0], showmeans=True, showmedians=True)
# 设置小提琴图的颜色
violin_parts['bodies'][0].set_facecolor('lightblue')
violin_parts['bodies'][0].set_alpha(0.7)
violin_parts['cmeans'].set_color('red')
violin_parts['cmedians'].set_color('blue')
# 添加统计信息
stats_text = (
f"统计信息:\n"
f"平均收益: {mean_profit:,.0f}\n"
f"中位数收益: {median_profit:,.0f}\n"
f"最大收益: {max_profit:,.0f}\n"
f"最小收益: {min_profit:,.0f}\n"
f"标准差: {std_profit:,.0f}"
)
plt.text(0.65, 0.95, stats_text,
transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
verticalalignment='top',
fontsize=10)
# 设置图表样式
plt.title('股票收益分布图', fontsize=14, pad=20)
plt.ylabel('收益 (元)', fontsize=12)
plt.xticks([0], ['所有股票'], fontsize=10)
plt.grid(True, axis='y', alpha=0.3)
# 添加零线
plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
# 计算盈利和亏损的股票数量
profit_count = np.sum(profits > 0)
loss_count = np.sum(profits < 0)
total_count = len(profits)
# 添加盈亏比例信息
ratio_text = (
f"盈亏比例:\n"
f"盈利: {profit_count}支 ({profit_count/total_count*100:.1f}%)\n"
f"亏损: {loss_count}支 ({loss_count/total_count*100:.1f}%)\n"
f"总计: {total_count}"
)
plt.text(0.65, 0.6, ratio_text,
transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
verticalalignment='top',
fontsize=10)
plt.tight_layout()
return plt
def plot_profit_matrix(self, df):
"""绘制止盈比例分析矩阵的结果图"""
# 创建子图
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
# 1. 总收益曲线
ax1.plot(df.index, df['total_profit'], 'b-', linewidth=2)
ax1.set_title('总收益 vs 止盈比例')
ax1.set_xlabel('止盈比例 (%)')
ax1.set_ylabel('总收益 (元)')
ax1.grid(True)
# 标记最优止盈比例点
best_profit_pct = df['total_profit'].idxmax()
best_profit = df['total_profit'].max()
ax1.plot(best_profit_pct, best_profit, 'ro')
ax1.annotate(f'最优: {best_profit_pct:.1f}%\n{best_profit:,.0f}',
(best_profit_pct, best_profit),
xytext=(10, 10), textcoords='offset points')
# 2. 平均每笔收益和中位数收益
ax2.plot(df.index, df['avg_profit_per_trade'], 'g-', linewidth=2, label='平均收益')
ax2.plot(df.index, df['median_profit_per_trade'], 'r--', linewidth=2, label='中位数收益')
ax2.set_title('每笔交易收益 vs 止盈比例')
ax2.set_xlabel('止盈比例 (%)')
ax2.set_ylabel('收益 (元)')
ax2.legend()
ax2.grid(True)
# 3. 平均持仓天数曲线
ax3.plot(df.index, df['avg_holding_days'], 'r-', linewidth=2)
ax3.set_title('平均持仓天数 vs 止盈比例')
ax3.set_xlabel('止盈比例 (%)')
ax3.set_ylabel('持仓天数')
ax3.grid(True)
# 4. 胜率对比
ax4.plot(df.index, df['trade_win_rate'], 'c-', linewidth=2, label='交易胜率')
ax4.plot(df.index, df['stock_win_rate'], 'm--', linewidth=2, label='股票胜率')
ax4.set_title('胜率 vs 止盈比例')
ax4.set_xlabel('止盈比例 (%)')
ax4.set_ylabel('胜率 (%)')
ax4.legend()
ax4.grid(True)
# 调整布局
plt.tight_layout()
return plt
def analyze_profit_matrix(self, start_date, end_date, stop_loss_pct, pullback_entry_pct):
"""分析不同止盈比例的表现矩阵"""
# 创建results目录和临时文件目录
if not os.path.exists('results/temp'):
os.makedirs('results/temp')
# 从30%到5%每次减少1%
profit_percentages = list(np.arange(0.30, 0.04, -0.01))
for take_profit_pct in profit_percentages:
# 检查是否已经分析过这个止盈点
temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv'
if os.path.exists(temp_file):
print(f"\n止盈比例 {take_profit_pct*100:.1f}% 已分析,跳过")
continue
print(f"\n分析止盈比例: {take_profit_pct*100:.1f}%")
try:
# 运行分析
results = self.analyze_all_stocks(
start_date,
end_date,
take_profit_pct,
stop_loss_pct,
pullback_entry_pct
)
# 过滤掉空结果
valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty}
if valid_results:
# 计算统计数据
total_stats = {
'total_trades': 0,
'profitable_trades': 0,
'loss_trades': 0,
'total_holding_days': 0,
'total_profit': 0,
'total_capital_usage': 0,
'total_win_rate': 0,
'stock_count': len(valid_results),
'profitable_stocks': 0,
'all_trade_profits': []
}
# 收集每支股票的统计数据
for symbol, result in valid_results.items():
if 'stats' in result.iloc[-1]:
stats = result.iloc[-1]['stats']
total_stats['total_trades'] += stats['total_trades']
total_stats['profitable_trades'] += stats['profitable_trades']
total_stats['loss_trades'] += stats['loss_trades']
total_stats['total_holding_days'] += stats['avg_holding_days'] * stats['total_trades']
total_stats['total_profit'] += stats['final_profit']
total_stats['total_capital_usage'] += stats['max_capital_usage']
total_stats['total_win_rate'] += stats['win_rate']
if stats['final_profit'] > 0:
total_stats['profitable_stocks'] += 1
if 'trade_profits' in stats:  # stats is a plain dict, so use key membership rather than hasattr
total_stats['all_trade_profits'].extend(stats['trade_profits'])
# 计算统计数据
avg_trades = total_stats['total_trades'] / total_stats['stock_count']
avg_holding_days = total_stats['total_holding_days'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0
avg_profit_per_trade = total_stats['total_profit'] / total_stats['total_trades'] if total_stats['total_trades'] > 0 else 0
median_profit_per_trade = np.median(total_stats['all_trade_profits']) if total_stats['all_trade_profits'] else 0
total_profit = total_stats['total_profit']
trade_win_rate = (total_stats['profitable_trades'] / total_stats['total_trades'] * 100) if total_stats['total_trades'] > 0 else 0
stock_win_rate = (total_stats['profitable_stocks'] / total_stats['stock_count'] * 100) if total_stats['stock_count'] > 0 else 0
# 创建单个止盈点的结果DataFrame
result_df = pd.DataFrame([{
'take_profit_pct': take_profit_pct * 100,
'total_profit': total_profit,
'avg_profit_per_trade': avg_profit_per_trade,
'median_profit_per_trade': median_profit_per_trade,
'avg_holding_days': avg_holding_days,
'trade_win_rate': trade_win_rate,
'stock_win_rate': stock_win_rate,
'total_trades': total_stats['total_trades'],
'stock_count': total_stats['stock_count']
}])
# 保存单个止盈点的结果
result_df.to_csv(temp_file, index=False)
print(f"保存止盈点 {take_profit_pct*100:.1f}% 的分析结果")
# 清理内存
del valid_results
del results
if 'total_stats' in locals():
del total_stats
except Exception as e:
print(f"处理止盈点 {take_profit_pct*100:.1f}% 时出错: {str(e)}")
continue
# 合并所有结果
print("\n合并所有分析结果...")
all_results = []
for take_profit_pct in profit_percentages:
temp_file = f'results/temp/profit_{take_profit_pct*100:.1f}.csv'
if os.path.exists(temp_file):
try:
df = pd.read_csv(temp_file)
all_results.append(df)
except Exception as e:
print(f"读取文件 {temp_file} 时出错: {str(e)}")
if all_results:
# 合并所有结果
df = pd.concat(all_results, ignore_index=True)
df = df.round({
'take_profit_pct': 1,
'total_profit': 0,
'avg_profit_per_trade': 0,
'median_profit_per_trade': 0,
'avg_holding_days': 1,
'trade_win_rate': 1,
'stock_win_rate': 1,
'total_trades': 0,
'stock_count': 0
})
# 设置止盈比例为索引并排序
df.set_index('take_profit_pct', inplace=True)
df.sort_index(ascending=False, inplace=True)
# 保存最终结果
df.to_csv('results/profit_matrix_analysis.csv')
# 打印结果
print("\n止盈比例分析矩阵:")
print("=" * 120)
print(df.to_string())
print("=" * 120)
try:
# 绘制并保存矩阵分析图表
plt = self.plot_profit_matrix(df)
plt.savefig('results/profit_matrix_analysis.png', bbox_inches='tight', dpi=300)
plt.close()
except Exception as e:
print(f"绘制图表时出错: {str(e)}")
return df
else:
print("没有找到任何分析结果")
return pd.DataFrame()
def main():
# 数据库连接配置
db_config = {
'host': '192.168.18.199',
'port': 3306,
'user': 'root',
'password': 'Chlry#$.8',
'database': 'db_gp_cj'
}
try:
connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}"
# 创建results目录
if not os.path.exists('results'):
os.makedirs('results')
# 创建分析器实例
analyzer = StockAnalyzer(connection_string)
# 设置分析参数
start_date = '2022-05-05'
end_date = '2023-08-28'
stop_loss_pct = -0.99 # -99%止损
pullback_entry_pct = -0.05 # -5%回调建仓
# 运行止盈比例分析
profit_matrix = analyzer.analyze_profit_matrix(
start_date,
end_date,
stop_loss_pct,
pullback_entry_pct
)
if not profit_matrix.empty:
# 使用最优止盈比例进行完整分析
take_profit_pct = profit_matrix['total_profit'].idxmax() / 100
print(f"\n使用最优止盈比例 {take_profit_pct*100:.1f}% 运行详细分析")
# 运行分析并只生成所需的图表
results = analyzer.analyze_all_stocks(
start_date,
end_date,
take_profit_pct,
stop_loss_pct,
pullback_entry_pct
)
# 过滤掉空结果
valid_results = {symbol: result for symbol, result in results.items() if result is not None and not result.empty}
if valid_results:
# 只生成所需的两个图表
plt, _ = analyzer.plot_all_stocks(valid_results)
plt.savefig('results/all_stocks_analysis.png', bbox_inches='tight', dpi=300)
plt.close()
except Exception as e:
print(f"程序运行出错: {str(e)}")
raise
if __name__ == "__main__":
main()
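
For quick experiments outside the backtester, the entry rule above (a drop of at least 5% from the previous close measured against the day's low, with the high/low midpoint inside ±30% of the 240-day moving-average baseline) can be written as a standalone helper. This is a minimal sketch of that single rule only; it deliberately ignores the position-count, capital and holding-period constraints enforced by the full calculate_trade_signals loop, and the column names simply follow the gp_day_data schema used above.

```python
import pandas as pd

def entry_signal(df: pd.DataFrame, ma240_base_price: float,
                 drop_pct: float = 0.05, band: float = 0.30) -> pd.Series:
    """Vectorized sketch of the entry rule in StockAnalyzer.calculate_trade_signals.

    df needs 'high', 'low' and 'close' columns ordered by date; returns a boolean
    Series that is True on days where a new position could be opened.
    """
    prev_close = df['close'].shift(1)
    daily_drop = (prev_close - df['low']) / prev_close   # drop measured against the day's low
    mid_price = (df['high'] + df['low']) / 2              # same fill-price proxy as the backtest
    in_band = mid_price.between(ma240_base_price * (1 - band),
                                ma240_base_price * (1 + band))
    return (daily_drop >= drop_pct) & in_band
```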

File diff suppressed because it is too large

1249
src/stock_simulation.py Normal file

File diff suppressed because it is too large

81
src/tq_demo.py Normal file
View File

@ -0,0 +1,81 @@
# 用前须知
## xtdata提供和MiniQmt的交互接口本质是和MiniQmt建立连接由MiniQmt处理行情数据请求再把结果回传返回到python层。使用的行情服务器以及能获取到的行情数据和MiniQmt是一致的要检查数据或者切换连接时直接操作MiniQmt即可。
## 对于数据获取接口使用时需要先确保MiniQmt已有所需要的数据如果不足可以通过补充数据接口补充再调用数据获取接口获取。
## 对于订阅接口,直接设置数据回调,数据到来时会由回调返回。订阅接收到的数据一般会保存下来,同种数据不需要再单独补充。
# 代码讲解
# 从本地python导入xtquant库如果出现报错则说明安装失败
from xtquant import xtdata
import time
# 设定一个标的列表
code_list = ["000001.SZ"]
# 设定获取数据的周期
period = "1d"
# 下载标的行情数据
if 1:
## 为了方便用户进行数据管理xtquant的大部分历史数据都是以压缩形式存储在本地的
## 比如行情数据需要通过download_history_data下载,财务数据需要通过download_financial_data下载,板块数据需要通过download_sector_data下载
## 所以在取历史数据之前,我们需要调用数据下载接口,将数据下载到本地
for i in code_list:
xtdata.download_history_data(i, period=period, incrementally=True) # 增量下载行情数据(开高低收,等等)到本地
xtdata.download_financial_data(code_list) # 下载财务数据到本地
xtdata.download_sector_data() # 下载板块数据到本地
# 更多数据的下载方式可以通过数据字典查询
# 读取本地历史行情数据
history_data = xtdata.get_market_data_ex([], code_list, period=period, count=-1)
print(history_data)
print("=" * 20)
# 如果需要盘中的实时行情,需要向服务器进行订阅后才能获取
# 订阅后get_market_data函数与get_market_data_ex函数将会自动拼接本地历史行情与服务器实时行情
# 向服务器订阅数据
for i in code_list:
xtdata.subscribe_quote(i, period=period, count=-1) # 设置count = -1来取到当天所有实时行情
# 等待订阅完成
time.sleep(1)
# 获取订阅后的行情
kline_data = xtdata.get_market_data_ex([], code_list, period=period)
print(kline_data)
# 获取订阅后的行情,并以固定间隔进行刷新,预期会循环打印10次
for i in range(10):
# 这边做演示就用for来循环了实际使用中可以用while True
kline_data = xtdata.get_market_data_ex([], code_list, period=period)
print(kline_data)
time.sleep(3) # 三秒后再次获取行情
# 如果不想用固定间隔触发,可以以用订阅后的回调来执行
# 这种模式下当订阅的callback回调函数将会异步的执行每当订阅的标的tick发生变化更新callback回调函数就会被调用一次
# 本地已有的数据不会触发callback
# 定义的回调函数
## 回调函数中data是本次触发回调的数据只有一条
def f(data):
# print(data)
code_list = list(data.keys()) # 获取到本次触发的标的代码
kline_in_callback = xtdata.get_market_data_ex([], code_list, period=period) # 在回调中获取klines数据
print(kline_in_callback)
for i in code_list:
xtdata.subscribe_quote(i, period=period, count=-1, callback=f) # 订阅时设定回调函数
# 使用回调时必须要同时使用xtdata.run()来阻塞程序,否则程序运行到最后一行就直接结束退出了。
xtdata.run()