新增:标签分析逻辑。
This commit is contained in:
parent
16f3efe3ae
commit
3aa280090e
243
src/app.py
243
src/app.py
|
|
@ -123,6 +123,15 @@ fear_greed_manager = FearGreedIndexManager()
|
|||
# 创建指数分析器实例
|
||||
index_analyzer = IndexAnalyzer()
|
||||
|
||||
# 创建标签关联分析API实例
|
||||
try:
    # Imported inside the guard so that a failure while importing or
    # constructing the tag-relation API leaves `tag_relation_api` as None
    # (the tag routes check for that) instead of aborting module import.
    from src.stock_tag_analysis.tag_relation_api import TagRelationAPI
    tag_relation_api = TagRelationAPI()
    logger.info("标签关联分析API初始化成功")
except Exception as e:
    # exc_info=True keeps the traceback; the message alone rarely identifies
    # which import or connection step failed.
    logger.error(f"标签关联分析API初始化失败: {str(e)}", exc_info=True)
    tag_relation_api = None
|
||||
|
||||
# 获取项目根目录
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports')
|
||||
|
|
@ -3887,6 +3896,240 @@ def run_tech_fundamental_strategy_batch():
|
|||
return jsonify({"status": "error", "message": str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/api/tag/process', methods=['POST'])
def process_tag_relation():
    """Run the tag-relation processing for one tag.

    Request body (JSON object or form fields):
        - tag_name: tag name (required)
        - tag_type: tag type (optional, defaults to None)
        - model_type: model key (optional, defaults to 'ds3_2')
        - enable_web_search: enable web search (optional, defaults to True;
          the strings 'true', '1', 'yes' count as True, any other string
          as False)
        - temperature: sampling temperature (optional, defaults to 0.7)
        - max_tokens: maximum token count (optional, defaults to 4096)

    Responses:
        - 200: {"status": "success", "data": <result of process_tag>}
        - 400: tag_name missing, or temperature / max_tokens not numeric
        - 500: API not initialised, process_tag reported failure, or an
          unexpected exception

    Example success body:
        {
            "status": "success",
            "data": {
                "success": true,
                "tag_code": "123",
                "tag_type": "AI标签",
                "source": "ai",
                "stock_count": 0,
                "message": "AI分析任务已加入队列,共5000只股票待处理",
                "error": null,
                "status": "processing"
            }
        }
    """
    try:
        # The API object is None when its start-up initialisation failed.
        if tag_relation_api is None:
            return jsonify({
                "status": "error",
                "message": "标签关联分析API未初始化"
            }), 500

        # silent=True turns a malformed JSON body into None instead of raising;
        # anything that is not a mapping (None, list, scalar) is treated as an
        # empty request so the caller gets the 400 below rather than a 500.
        data = request.get_json(silent=True) if request.is_json else request.form
        if not hasattr(data, 'get'):
            data = {}

        # Validate the required parameter before touching optional ones.
        tag_name = data.get('tag_name')
        if not tag_name:
            return jsonify({
                "status": "error",
                "message": "缺少必要参数: tag_name"
            }), 400

        tag_type = data.get('tag_type')
        model_type = data.get('model_type', 'ds3_2')

        # Form fields arrive as strings, and the string "false" is truthy;
        # interpret string flags the same way delete_tag_relation does.
        enable_web_search = data.get('enable_web_search', True)
        if isinstance(enable_web_search, str):
            enable_web_search = enable_web_search.lower() in ('true', '1', 'yes')

        # Bad numeric input is a client error, not a server failure.
        try:
            temperature = float(data.get('temperature', 0.7))
            max_tokens = int(data.get('max_tokens', 4096))
        except (TypeError, ValueError):
            return jsonify({
                "status": "error",
                "message": "参数格式错误: temperature 必须为数字,max_tokens 必须为整数"
            }), 400

        result = tag_relation_api.process_tag(
            tag_name=tag_name,
            tag_type=tag_type,
            model_type=model_type,
            enable_web_search=enable_web_search,
            temperature=temperature,
            max_tokens=max_tokens
        )

        if result.get('success'):
            return jsonify({
                "status": "success",
                "data": result
            })

        return jsonify({
            "status": "error",
            "message": result.get('message', '处理失败'),
            "data": result
        }), 500

    except Exception as e:
        logger.error(f"处理标签关联分析失败: {str(e)}", exc_info=True)
        return jsonify({
            "status": "error",
            "message": f"处理标签关联分析失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/tag/progress', methods=['GET'])
def get_tag_progress():
    """Report the processing progress of a tag.

    Query parameters:
        - tag_name: tag name (required)

    Responses:
        - 200: {"status": "success", "data": <result of get_tag_progress>}
        - 400: tag_name missing
        - 404: the result contains an 'error' key
        - 500: API not initialised, or an unexpected exception

    Example success body:
        {
            "status": "success",
            "data": {
                "tag_code": "123",
                "tag_name": "商业航天",
                "status": "processing",
                "progress": 45.5,
                "total_stocks": 5000,
                "processed": 2275,
                "remaining": 2725,
                "queue_length": 5000
            }
        }
    """
    def _failure(message, http_code):
        # Every error response of this endpoint shares the same envelope.
        return jsonify({
            "status": "error",
            "message": message
        }), http_code

    try:
        # The API object is None when its start-up initialisation failed.
        if tag_relation_api is None:
            return _failure("标签关联分析API未初始化", 500)

        requested_tag = request.args.get('tag_name')
        if not requested_tag:
            return _failure("缺少必要参数: tag_name", 400)

        progress = tag_relation_api.get_tag_progress(tag_name=requested_tag)

        # The mere presence of an 'error' key marks the lookup as failed.
        if 'error' in progress:
            return _failure(progress.get('error', '获取进度失败'), 404)

        return jsonify({
            "status": "success",
            "data": progress
        })

    except Exception as exc:
        logger.error(f"获取标签进度失败: {str(exc)}", exc_info=True)
        return _failure(f"获取标签进度失败: {str(exc)}", 500)
|
||||
|
||||
|
||||
@app.route('/api/tag/delete', methods=['POST', 'DELETE'])
def delete_tag_relation():
    """Delete a tag and its related data.

    Parameters (JSON object, form fields, or - when no form data is sent -
    query-string parameters, which is how many clients issue DELETE):
        - tag_name: tag name (required)
        - delete_analysis: also delete analysis results (optional, defaults
          to False; the strings 'true', '1', 'yes' count as True)

    Responses:
        - 200: {"status": "success", "data": <result of delete_tag_by_name>}
        - 400: tag_name missing
        - 404: deletion failed and the result message contains "不存在"
        - 500: API not initialised, other deletion failure, or an
          unexpected exception

    Example success body:
        {
            "status": "success",
            "data": {
                "success": true,
                "tag_code": "123",
                "deleted_tag": true,
                "deleted_relations": 50,
                "deleted_analyses": 0,
                "deleted_redis_keys": 2,
                "message": "删除成功: 标签=true, 关联关系=50条, Redis键=2个",
                "error": null
            }
        }
    """
    try:
        # The API object is None when its start-up initialisation failed.
        if tag_relation_api is None:
            return jsonify({
                "status": "error",
                "message": "标签关联分析API未初始化"
            }), 500

        if request.is_json:
            # silent=True: a malformed body becomes None instead of raising,
            # so the caller gets the 400 below rather than a 500.
            data = request.get_json(silent=True)
        else:
            # Form data keeps priority; the query string is only a fallback.
            data = request.form if request.form else request.args
        if not hasattr(data, 'get'):
            data = {}

        tag_name = data.get('tag_name')
        delete_analysis = data.get('delete_analysis', False)

        # The flag may arrive as the string "true"/"false".
        if isinstance(delete_analysis, str):
            delete_analysis = delete_analysis.lower() in ('true', '1', 'yes')
        else:
            delete_analysis = bool(delete_analysis)

        if not tag_name:
            return jsonify({
                "status": "error",
                "message": "缺少必要参数: tag_name"
            }), 400

        result = tag_relation_api.delete_tag_by_name(
            tag_name=tag_name,
            delete_analysis=delete_analysis
        )

        if result.get('success'):
            return jsonify({
                "status": "success",
                "data": result
            })

        # A message mentioning "不存在" means the tag was not found -> 404.
        # `or ''` guards against an explicit None message, which would make
        # the substring test raise TypeError.
        if "不存在" in (result.get('message') or ''):
            return jsonify({
                "status": "error",
                "message": result.get('message', '标签不存在'),
                "data": result
            }), 404

        return jsonify({
            "status": "error",
            "message": result.get('message', '删除失败'),
            "data": result
        }), 500

    except Exception as e:
        logger.error(f"删除标签失败: {str(e)}", exc_info=True)
        return jsonify({
            "status": "error",
            "message": f"删除标签失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# 启动Web服务器
|
||||
|
|
|
|||
|
|
@ -63,7 +63,8 @@ MODEL_CONFIGS = {
|
|||
"models": {
|
||||
"offline_model": "ep-20250326090920-v7wns",
|
||||
"online_bot": "bot-20250325102825-h9kpq",
|
||||
"doubao": "doubao-1-5-pro-32k-250115"
|
||||
"doubao": "doubao-1-5-pro-32k-250115",
|
||||
"ds3_2": "ep-20260104113811-m2h2f"
|
||||
}
|
||||
},
|
||||
# 谷歌Gemini
|
||||
|
|
|
|||
|
|
@ -0,0 +1,276 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# 设置日志记录
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取项目根目录
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# 导入配置
|
||||
sys.path.append(os.path.dirname(ROOT_DIR))
|
||||
sys.path.append(ROOT_DIR)
|
||||
|
||||
try:
|
||||
from scripts.config import get_random_api_key, get_model
|
||||
except ImportError:
|
||||
from src.scripts.config import get_random_api_key, get_model
|
||||
|
||||
|
||||
class ChatBot:
    """Chat client built on the OpenAI SDK's streaming `responses` API.

    Keeps a simple in-memory conversation history and can optionally attach a
    web-search tool definition to every request.
    """

    def __init__(self, model_type: str = "online_bot", enable_web_search: bool = True):
        """Create the client and the per-instance conversation state.

        Args:
            model_type: key passed to get_model() to resolve the model id,
                defaults to "online_bot".
            enable_web_search: attach the web-search tool to requests,
                defaults to True.
        """
        self.api_key = get_random_api_key()
        self.model = get_model(model_type)
        self.enable_web_search = enable_web_search

        logger.info(f"初始化ChatBot,模型: {self.model}, 联网搜索: {self.enable_web_search}")

        self.client = OpenAI(
            base_url="https://ark.cn-beijing.volces.com/api/v3",
            api_key=self.api_key
        )

        # Web-search tool definition; None means "send no tools at all".
        self.tools = [{
            "type": "web_search",
            "limit": 10,
            "sources": ["douyin", "moji", "toutiao"],
            "user_location": {
                "type": "approximate",
                "country": "中国",
                "region": "北京",
                "city": "北京"
            }
        }] if enable_web_search else None

        # System prompt sent as the first message of every request.
        self.system_message = """你是一个专业的股票分析助手,擅长进行深入的基本面分析。你的分析应该:

1. 专业严谨 - 使用准确的专业术语,引用可靠的数据来源,分析逻辑清晰,结论有理有据
2. 全面细致 - 深入分析问题的各个方面,关注细节和关键信息,考虑多个影响因素
3. 客观中立 - 保持独立判断,不夸大或贬低,平衡利弊分析,指出潜在风险
4. 实用性强 - 分析结论具体明确,建议具有可操作性,关注实际投资价值
5. 及时更新 - 关注最新信息,反映市场变化,动态调整分析,保持信息时效性

请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。"""

        # Alternating user / assistant messages, oldest first.
        self.conversation_history = []

    def format_reference(self, ref):
        """Render one reference item as display text.

        Strings are returned unchanged; dicts are rendered as one line per
        non-empty field among title, summary, url and publish_time; anything
        else is passed through str().
        """
        if isinstance(ref, str):
            return ref
        if isinstance(ref, dict):
            labelled_fields = (
                ("标题", "title"),
                ("摘要", "summary"),
                ("链接", "url"),
                ("发布时间", "publish_time"),
            )
            return "\n".join(
                f"{label}:{ref[key]}" for label, key in labelled_fields if ref.get(key)
            )
        return str(ref)

    def chat(self, user_input: str, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]:
        """Send one user message and stream the reply to stdout.

        Args:
            user_input: the user's question.
            temperature, top_p, max_tokens, frequency_penalty: accepted for
                interface compatibility. NOTE(review): none of them is
                forwarded to the API request built below, so they currently
                have no effect on the output.

        Returns:
            Dict with keys:
                - response: reply text (the part before "推理过程:" if the
                  reply contains that marker)
                - reasoning_process: streamed reasoning summary, or the text
                  between "推理过程:" and "参考资料:" when the reply has it
                - references: always an empty list in this implementation
                - tool_usage / tool_usage_details: taken from the usage
                  object of the completion event when present, else None
                - search_keywords: queries of finished web-search calls
            On any exception an apology text is returned in `response` and
            the other fields are empty.
        """
        # Remember where this turn starts so a failed call can be rolled back
        # instead of leaving an unanswered user message in the history.
        history_mark = len(self.conversation_history)
        try:
            self.conversation_history.append({
                "role": "user",
                "content": [{"type": "input_text", "text": user_input}]
            })

            # System prompt first, then the last five history messages
            # (the window counts messages, not user/assistant pairs).
            input_messages = [
                {
                    "role": "system",
                    "content": [{"type": "input_text", "text": self.system_message}]
                }
            ]
            input_messages.extend(self.conversation_history[-5:])

            request_kwargs = {
                "model": self.model,
                "input": input_messages,
                "stream": True,
            }
            # Only send `tools` when there is a tool list; with search
            # disabled the parameter is omitted rather than sent as null.
            if self.tools:
                request_kwargs["tools"] = self.tools

            response = self.client.responses.create(**request_kwargs)

            full_response = ""
            reasoning_text = ""
            references = []
            tool_usage = None
            tool_usage_details = None
            search_keywords = []

            print("\nAI: ", end="", flush=True)

            for chunk in response:
                chunk_type = getattr(chunk, "type", "")

                # Reasoning summary fragments.
                if chunk_type == "response.reasoning_summary_text.delta":
                    delta = getattr(chunk, "delta", "")
                    if delta:
                        reasoning_text += delta

                # Web-search status events.
                elif "web_search_call" in chunk_type:
                    if "in_progress" in chunk_type:
                        logger.debug("开始搜索")
                    elif "completed" in chunk_type:
                        logger.debug("搜索完成")

                # A finished output item whose id starts with "ws_" carries
                # the search query that was used.
                elif (chunk_type == "response.output_item.done"
                        and hasattr(chunk, "item")
                        and str(getattr(chunk.item, "id", "")).startswith("ws_")):
                    if hasattr(chunk.item, "action") and hasattr(chunk.item.action, "query"):
                        search_keyword = chunk.item.action.query
                        search_keywords.append(search_keyword)
                        logger.debug(f"搜索关键词: {search_keyword}")

                # Answer text fragments: echo and accumulate.
                elif chunk_type == "response.output_text.delta":
                    delta = getattr(chunk, "delta", "")
                    if delta:
                        print(delta, end="", flush=True)
                        full_response += delta
                        time.sleep(0.01)

                # Completion event: pick up tool-usage statistics.
                elif chunk_type == "response.completed":
                    if hasattr(chunk, "response") and hasattr(chunk.response, "usage"):
                        usage = chunk.response.usage
                        if hasattr(usage, "tool_usage"):
                            tool_usage = usage.tool_usage
                        if hasattr(usage, "tool_usage_details"):
                            tool_usage_details = usage.tool_usage_details

            print()  # newline after the streamed answer

            main_response = full_response
            reasoning_process = reasoning_text

            # If the reply embeds its own "推理过程:" section, split it off:
            # the answer is what precedes the marker, the reasoning is what
            # follows it up to an optional "参考资料:" marker.
            if "推理过程:" in full_response:
                parts = full_response.split("推理过程:")
                main_response = parts[0].strip()
                reasoning_process = parts[1].split("参考资料:")[0].strip()

            self.conversation_history.append({
                "role": "assistant",
                "content": [{"type": "input_text", "text": main_response}]
            })

            return {
                "response": main_response,
                "reasoning_process": reasoning_process,
                "references": references,
                "tool_usage": tool_usage,
                "tool_usage_details": tool_usage_details,
                "search_keywords": search_keywords
            }

        except Exception as e:
            # Drop whatever this failed turn added to the history.
            del self.conversation_history[history_mark:]
            logger.error(f"对话失败: {str(e)}", exc_info=True)
            return {
                "response": f"抱歉,处理您的请求时出现错误: {str(e)}",
                "reasoning_process": "",
                "references": [],
                "tool_usage": None,
                "tool_usage_details": None,
                "search_keywords": []
            }

    def clear_history(self):
        """Forget all previous turns and announce it on stdout."""
        self.conversation_history = []
        print("对话历史已清除")

    def run(self):
        """Interactive console loop; 'quit' exits, 'clear' resets the history."""
        print("欢迎使用AI助手!输入 'quit' 退出,输入 'clear' 清除对话历史。")
        if self.enable_web_search:
            print("该版本支持联网搜索,可以回答实时信息。")
        print("-" * 50)

        while True:
            try:
                user_input = input("\n你: ").strip()

                if user_input.lower() == 'quit':
                    print("感谢使用,再见!")
                    break

                if user_input.lower() == 'clear':
                    self.clear_history()
                    continue

                if not user_input:
                    continue

                self.chat(user_input)
                print("-" * 50)

            except KeyboardInterrupt:
                print("\n感谢使用,再见!")
                break
            except Exception as e:
                logger.error(f"运行错误: {str(e)}")
                print(f"发生错误: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: console logging, then the interactive chat loop
    # with web search enabled.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)

    bot = ChatBot(enable_web_search=True)
    bot.run()
|
||||
|
|
@ -0,0 +1,323 @@
|
|||
"""
|
||||
个股标签关联分析器
|
||||
使用联网大模型分析个股与标签的关联性,并输出结构化结果
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
# 导入联网大模型工具
|
||||
import sys
|
||||
import os
|
||||
|
||||
try:
|
||||
from .chat_bot import ChatBot
|
||||
except ImportError:
|
||||
try:
|
||||
from src.stock_tag_analysis.chat_bot import ChatBot
|
||||
except ImportError:
|
||||
# 如果相对导入失败,添加路径并使用绝对导入
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, current_dir)
|
||||
from chat_bot import ChatBot
|
||||
|
||||
# 设置日志记录
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TagRelationAnalyzer:
    """Asks the chat model whether a stock is related to a tag and returns
    the answer as a structured dict."""

    def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True):
        """Create the underlying ChatBot.

        Args:
            model_type: model key forwarded to ChatBot, defaults to "ds3_2".
            enable_web_search: forwarded to ChatBot, defaults to True.
        """
        self.chat_bot = ChatBot(model_type=model_type, enable_web_search=enable_web_search)
        logger.info(f"TagRelationAnalyzer 初始化完成,模型: {model_type}, 联网搜索: {enable_web_search}")

    def _build_analysis_prompt(self, stock_code: str, stock_name: str, tag_name: str) -> str:
        """Build the analysis prompt.

        Args:
            stock_code: stock code.
            stock_name: stock name.
            tag_name: tag name.

        Returns:
            The prompt text, which asks for a JSON object with the fields
            read back by analyze_tag_relation().
        """
        prompt = f"""请分析股票"{stock_name}"(股票代码:{stock_code})与标签"{tag_name}"是否存在关联性。

请通过联网搜索获取最新的信息,包括但不限于:
1. 公司主营业务是否与标签相关
2. 公司产品、服务是否涉及标签领域
3. 公司是否在标签相关行业有布局
4. 公司公告、新闻中是否提及标签相关内容
5. 公司财务数据是否显示与标签相关的业务占比

请按照以下JSON格式输出分析结果:
{{
"has_relation": true/false, // 是否存在关联性
"relation_score": 0-100, // 关联度评分(0-100,数值越高关联度越高)
"relation_type": "直接关联/间接关联/无关联", // 关联类型
"analysis_summary": "简要分析总结(100字以内)",
"detailed_analysis": "详细分析内容(包括关联依据、业务占比、市场表现等)",
"key_evidence": [ // 关键证据列表
{{
"evidence_type": "主营业务/产品服务/行业布局/公告新闻/财务数据",
"description": "证据描述",
"relevance": "高/中/低"
}}
],
"business_ratio": "与标签相关的业务占比(如:30%、主要业务、次要业务等,如无相关信息则填写'未知')",
"market_performance": "市场表现相关描述(如:该标签概念股表现、行业地位等,如无相关信息则填写'未知')",
"conclusion": "最终结论(50字以内)"
}}

请确保输出的是有效的JSON格式,不要包含任何其他文字说明。"""
        return prompt

    def _parse_json_response(self, response_text: str) -> Optional[Dict[str, Any]]:
        """Extract the JSON object from a model reply.

        The content of a ```json fenced block is tried first, then the whole
        reply. In each candidate, decoding starts at the first '{' and stops
        at the end of that object, so explanatory text after the object
        (even text containing braces) does not break parsing.

        Args:
            response_text: raw reply text.

        Returns:
            The decoded dict, or None when no JSON object could be decoded.
        """
        try:
            candidates = []
            fenced = re.search(r'```json\s*(\{.*?\})\s*```', response_text, re.DOTALL)
            if fenced:
                candidates.append(fenced.group(1))
            candidates.append(response_text)

            decoder = json.JSONDecoder()
            last_error = None
            for text in candidates:
                start = text.find('{')
                if start == -1:
                    continue
                try:
                    # raw_decode parses one value and ignores what follows it.
                    value, _ = decoder.raw_decode(text, start)
                except json.JSONDecodeError as e:
                    last_error = e
                    continue
                # Callers use .get() on the result, so only objects qualify.
                if isinstance(value, dict):
                    return value

            reason = str(last_error) if last_error else "未找到JSON对象"
            logger.error(f"JSON解析失败: {reason}")
            logger.error(f"响应文本: {response_text[:500]}")
            return None
        except Exception as e:
            logger.error(f"解析响应时出错: {str(e)}")
            return None

    def _failure_result(
        self,
        stock_code: str,
        stock_name: str,
        tag_name: str,
        error: str,
        detailed_analysis: Optional[str] = None,
        raw_response: Optional[str] = None,
        reasoning_process: Optional[str] = None,
        references: Optional[List[Any]] = None
    ) -> Dict[str, Any]:
        """Build the result dict returned when an analysis did not succeed."""
        return {
            "success": False,
            "stock_code": stock_code,
            "stock_name": stock_name,
            "tag_name": tag_name,
            "has_relation": None,
            "relation_score": None,
            "relation_type": None,
            "analysis_summary": None,
            "detailed_analysis": detailed_analysis,
            "key_evidence": [],
            "business_ratio": None,
            "market_performance": None,
            "conclusion": None,
            "raw_response": raw_response,
            "reasoning_process": reasoning_process,
            "references": [] if references is None else references,
            "analysis_time": datetime.now().isoformat(),
            "error": error
        }

    def analyze_tag_relation(
        self,
        stock_code: str,
        stock_name: str,
        tag_name: str,
        temperature: float = 0.7,
        max_tokens: int = 4096
    ) -> Dict[str, Any]:
        """Analyse how one stock relates to a tag.

        Args:
            stock_code: stock code (e.g. 300750.SZ).
            stock_name: stock name.
            tag_name: tag name.
            temperature: forwarded to ChatBot.chat, defaults to 0.7.
            max_tokens: forwarded to ChatBot.chat, defaults to 4096.

        Returns:
            Dict with keys success, stock_code, stock_name, tag_name,
            has_relation, relation_score, relation_type, analysis_summary,
            detailed_analysis, key_evidence, business_ratio,
            market_performance, conclusion, raw_response, reasoning_process,
            references, analysis_time and error. When the reply cannot be
            parsed, success is False and the raw reply is placed in
            detailed_analysis and raw_response. Exceptions are caught and
            reported through the error field.

        NOTE(review): this method does not clear the chat history; only
        batch_analyze_tag_relation() does so between stocks.
        """
        try:
            logger.info(f"开始分析 {stock_name}({stock_code}) 与标签 {tag_name} 的关联性")

            prompt = self._build_analysis_prompt(stock_code, stock_name, tag_name)

            result = self.chat_bot.chat(
                user_input=prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            response_text = result.get("response", "")
            reasoning_process = result.get("reasoning_process", "")
            references = result.get("references", [])

            parsed_result = self._parse_json_response(response_text)

            if parsed_result is None:
                # Keep the raw reply so the caller can inspect it.
                logger.warning("JSON解析失败,返回原始响应")
                return self._failure_result(
                    stock_code,
                    stock_name,
                    tag_name,
                    "JSON解析失败,请检查响应格式",
                    detailed_analysis=response_text,
                    raw_response=response_text,
                    reasoning_process=reasoning_process,
                    references=references
                )

            analysis_result = {
                "success": True,
                "stock_code": stock_code,
                "stock_name": stock_name,
                "tag_name": tag_name,
                "has_relation": parsed_result.get("has_relation", False),
                "relation_score": parsed_result.get("relation_score", 0),
                "relation_type": parsed_result.get("relation_type", "无关联"),
                "analysis_summary": parsed_result.get("analysis_summary", ""),
                "detailed_analysis": parsed_result.get("detailed_analysis", ""),
                "key_evidence": parsed_result.get("key_evidence", []),
                "business_ratio": parsed_result.get("business_ratio", "未知"),
                "market_performance": parsed_result.get("market_performance", "未知"),
                "conclusion": parsed_result.get("conclusion", ""),
                "raw_response": response_text,
                "reasoning_process": reasoning_process,
                "references": references,
                "analysis_time": datetime.now().isoformat(),
                "error": None
            }

            logger.info(f"分析完成: {stock_name}({stock_code}) 与 {tag_name} 的关联度评分为 {analysis_result['relation_score']}")

            return analysis_result

        except Exception as e:
            logger.error(f"分析过程中出错: {str(e)}", exc_info=True)
            return self._failure_result(stock_code, stock_name, tag_name, str(e))

    def batch_analyze_tag_relation(
        self,
        stock_list: List[Dict[str, str]],
        tag_name: str,
        temperature: float = 0.7,
        max_tokens: int = 4096
    ) -> List[Dict[str, Any]]:
        """Analyse several stocks against one tag, one request per stock.

        Args:
            stock_list: items with stock_code and stock_name; items missing
                either value are skipped with a warning.
            tag_name: tag name.
            temperature: forwarded to analyze_tag_relation.
            max_tokens: forwarded to analyze_tag_relation.

        Returns:
            One result dict per analysed stock, in input order.
        """
        results = []
        for stock in stock_list:
            stock_code = stock.get("stock_code")
            stock_name = stock.get("stock_name")

            if not stock_code or not stock_name:
                logger.warning(f"跳过无效股票信息: {stock}")
                continue

            results.append(self.analyze_tag_relation(
                stock_code=stock_code,
                stock_name=stock_name,
                tag_name=tag_name,
                temperature=temperature,
                max_tokens=max_tokens
            ))

            # Reset the conversation so one stock's context does not leak
            # into the next stock's prompt.
            self.chat_bot.clear_history()

        return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: analyse one stock against one tag and dump the
    # result as JSON.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)

    analyzer = TagRelationAnalyzer()

    result = analyzer.analyze_tag_relation(
        stock_code="300750.SZ",
        stock_name="宁德时代",
        tag_name="医药"
    )

    separator = "=" * 50
    print("\n" + separator)
    print("分析结果:")
    print(separator)
    print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,448 @@
|
|||
"""
|
||||
标签关联分析数据库操作模块
|
||||
用于连接 MongoDB 并存储分析结果
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pymongo
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
# 导入配置
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_DIR)
|
||||
|
||||
try:
|
||||
from valuation_analysis.config import MONGO_CONFIG222
|
||||
except ImportError:
|
||||
try:
|
||||
from src.valuation_analysis.config import MONGO_CONFIG222
|
||||
except ImportError:
|
||||
# 尝试从当前目录导入
|
||||
sys.path.append(os.path.dirname(ROOT_DIR))
|
||||
from src.valuation_analysis.config import MONGO_CONFIG222
|
||||
|
||||
# 设置日志记录
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TagRelationDatabase:
|
||||
"""标签关联分析数据库操作类"""
|
||||
|
||||
def __init__(self, collection_name: Optional[str] = None):
    """Set up the instance and connect to MongoDB.

    Args:
        collection_name: collection for analysis results; when omitted the
            'collection' entry of MONGO_CONFIG222 is used, falling back to
            'tag_relation_analysis'.
    """
    # Connection handles start out unset; connect_mongodb() fills them in.
    self.mongo_client = None
    self.db = None

    configured_name = collection_name or MONGO_CONFIG222.get('collection', 'tag_relation_analysis')
    self.collection_name = configured_name
    self.collection = None

    # Handles for the tag collection and the tag-to-stock relation collection.
    self.tags_collection = None
    self.tag_stock_relations_collection = None

    self.connect_mongodb()
|
||||
|
||||
def connect_mongodb(self):
    """Connect to MongoDB and bind the collection handles.

    The instance attributes are only assigned after the server answered a
    ping, and a client that was created but could not be verified is closed
    before the error is re-raised, so a failed connection does not leave an
    open client behind.

    Raises:
        Exception: whatever the driver or the configuration lookup raised;
            it is logged and re-raised unchanged.
    """
    client = None
    try:
        client = pymongo.MongoClient(
            host=MONGO_CONFIG222['host'],
            port=MONGO_CONFIG222['port'],
            username=MONGO_CONFIG222['username'],
            password=MONGO_CONFIG222['password'],
            authSource='admin'  # authentication database
        )
        database = client[MONGO_CONFIG222['db']]

        # Round-trip to the server to verify the connection actually works.
        client.admin.command('ping')

    except Exception as e:
        logger.error(f"MongoDB连接失败: {str(e)}")
        if client is not None:
            client.close()
        raise

    self.mongo_client = client
    self.db = database
    self.collection = database[self.collection_name]

    # Stock-pool tag collection and tag-to-stock relation collection.
    self.tags_collection = database['stock_pool_tags']
    self.tag_stock_relations_collection = database['tag_stock_relations']

    logger.info(f"MongoDB连接成功,数据库: {MONGO_CONFIG222['db']}, 集合: {self.collection_name}")
|
||||
|
||||
def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool:
    """Insert or update one analysis result, keyed by stock_code + tag_name.

    The write is a single upsert, so two concurrent saves of the same
    stock/tag pair cannot both take the "insert" path. The caller's dict is
    not modified: a copy is stored (with a fresh save_time), which also keeps
    MongoDB's ObjectId `_id` out of the caller's data.

    Args:
        analysis_result: result dict containing the analysis fields.

    Returns:
        True when a document was inserted or an existing one was matched
        for update; False for empty input or on any error (which is logged).
    """
    try:
        if not analysis_result:
            logger.warning("分析结果为空,无法保存")
            return False

        document = dict(analysis_result)
        # `_id` is immutable on an existing document; never send it in $set.
        document.pop('_id', None)
        document['save_time'] = datetime.now().isoformat()

        stock_code = document.get('stock_code')
        tag_name = document.get('tag_name')

        write_result = self.collection.update_one(
            {'stock_code': stock_code, 'tag_name': tag_name},
            {'$set': document},
            upsert=True
        )

        # upserted_id is only set when the upsert created a new document.
        if write_result.upserted_id is not None:
            logger.info(f"插入新分析结果: {stock_code} - {tag_name}, ID: {write_result.upserted_id}")
            return True

        logger.info(f"更新分析结果: {stock_code} - {tag_name}")
        return write_result.matched_count > 0

    except Exception as e:
        logger.error(f"保存分析结果失败: {str(e)}", exc_info=True)
        return False
|
||||
|
||||
def get_analysis_result(self, stock_code: str, tag_name: str) -> Optional[Dict[str, Any]]:
    """Fetch the stored analysis for one stock/tag pair.

    Args:
        stock_code: stock code.
        tag_name: tag name.

    Returns:
        The stored document without its `_id` field, or None when nothing
        matches or the lookup fails (the failure is logged).
    """
    lookup = {'stock_code': stock_code, 'tag_name': tag_name}
    try:
        record = self.collection.find_one(lookup)
        if not record:
            return record
        # Drop MongoDB's internal identifier before handing the record out.
        record.pop('_id', None)
        return record
    except Exception as exc:
        logger.error(f"获取分析结果失败: {str(exc)}")
        return None
|
||||
|
||||
def get_stock_analyses(self, stock_code: str) -> list:
    """Fetch every stored tag analysis for one stock.

    Args:
        stock_code: stock code.

    Returns:
        The matching documents without their `_id` fields; an empty list
        when the lookup fails (the failure is logged).
    """
    try:
        documents = []
        for document in self.collection.find({'stock_code': stock_code}):
            # Drop MongoDB's internal identifier.
            document.pop('_id', None)
            documents.append(document)
        return documents
    except Exception as exc:
        logger.error(f"获取股票分析结果失败: {str(exc)}")
        return []
|
||||
|
||||
def get_tag_analyses(self, tag_name: str) -> list:
    """Fetch every stored stock analysis for one tag.

    Args:
        tag_name: tag name.

    Returns:
        The matching documents without their `_id` fields; an empty list
        when the lookup fails (the failure is logged).
    """
    try:
        documents = []
        for document in self.collection.find({'tag_name': tag_name}):
            # Drop MongoDB's internal identifier.
            document.pop('_id', None)
            documents.append(document)
        return documents
    except Exception as exc:
        logger.error(f"获取标签分析结果失败: {str(exc)}")
        return []
|
||||
|
||||
def save_tag(self, tag_data: Dict[str, Any]) -> Optional[str]:
    """Insert a tag, or update the existing tag with the same tag_name.

    When a new tag has no tag_code, one is generated: the largest numeric
    tag_code in the collection plus one. The maximum is computed numerically
    in Python because tag codes are stored as strings, and a string sort
    ranks "9" above "10", which would hand out an already-used code. If
    codes exist but none is numeric, a millisecond timestamp is used; if the
    collection has no codes at all, the first code is "1".

    NOTE(review): code generation is not atomic; two concurrent inserts can
    still compute the same code.

    Args:
        tag_data: tag fields; tag_name identifies the tag. On insert this
            dict is updated in place with tag_code and create_time.

    Returns:
        The tag_code of the stored tag (the existing document's code on
        update), or None when saving fails (the failure is logged).
    """
    try:
        tag_name = tag_data.get('tag_name')
        existing_tag = self.tags_collection.find_one({'tag_name': tag_name})

        if existing_tag:
            tag_code = existing_tag.get('tag_code')
            updates = dict(tag_data)
            # An empty/None tag_code in the input must not wipe the stored one.
            if not updates.get('tag_code'):
                updates.pop('tag_code', None)
            self.tags_collection.update_one(
                {'tag_name': tag_name},
                {'$set': updates}
            )
            logger.info(f"更新标签: {tag_name}, tag_code: {tag_code}")
            return tag_code

        if not tag_data.get('tag_code'):
            numeric_codes = []
            has_any_code = False
            for doc in self.tags_collection.find({}, {'tag_code': 1}):
                code = doc.get('tag_code')
                if not code:
                    continue
                has_any_code = True
                try:
                    numeric_codes.append(int(code))
                except (TypeError, ValueError):
                    continue

            if numeric_codes:
                tag_code = str(max(numeric_codes) + 1)
            elif has_any_code:
                tag_code = str(int(datetime.now().timestamp() * 1000))
            else:
                tag_code = "1"
            tag_data['tag_code'] = tag_code

        tag_data['create_time'] = datetime.now().isoformat()
        self.tags_collection.insert_one(tag_data)
        logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}")
        return tag_data.get('tag_code')

    except Exception as e:
        logger.error(f"保存标签失败: {str(e)}", exc_info=True)
        return None
|
||||
|
||||
def get_tag_by_name(self, tag_name: str) -> Optional[Dict[str, Any]]:
    """Look up a tag by its name.

    Args:
        tag_name: tag name.

    Returns:
        The tag document without its `_id` field, or None when no tag
        matches or the lookup fails (the failure is logged).
    """
    try:
        tag_document = self.tags_collection.find_one({'tag_name': tag_name})
        if not tag_document:
            return tag_document
        # Drop MongoDB's internal identifier.
        tag_document.pop('_id', None)
        return tag_document
    except Exception as exc:
        logger.error(f"获取标签失败: {str(exc)}")
        return None
|
||||
|
||||
def save_tag_stock_relations(self, tag_code: str, stock_codes: List[str], replace_all: bool = False) -> bool:
    """Persist the relations between one tag and a list of stock codes.

    Duplicate entries in ``stock_codes`` are collapsed (first occurrence
    wins) so that a single call can never insert the same relation twice.
    Relations that are already stored for the tag are left untouched.

    Args:
        tag_code: Code of the tag the stocks are attached to.
        stock_codes: Stock codes to relate to the tag.
        replace_all: When True, every existing relation of the tag is
            deleted first (full replacement); when False (default), only
            missing relations are added (incremental).

    Returns:
        True when the relations were stored (or nothing needed storing),
        False when ``stock_codes`` is empty or any operation raised.
    """
    try:
        if not stock_codes:
            logger.warning("股票代码列表为空")
            return False

        collection = self.tag_stock_relations_collection

        # Collapse repeated codes while keeping first-seen order; the
        # existence check below runs before any insert, so duplicates in the
        # input would otherwise each be written.
        unique_codes = list(dict.fromkeys(stock_codes))

        if replace_all:
            # Everything for this tag was just removed, so nothing can
            # already exist and no lookup is needed.
            collection.delete_many({'tag_code': tag_code})
            already_linked = set()
        else:
            # One query for all candidate codes instead of one per code.
            cursor = collection.find(
                {'tag_code': tag_code, 'stock_code': {'$in': unique_codes}},
                {'stock_code': 1},
            )
            already_linked = {doc.get('stock_code') for doc in cursor}

        # All relations written by one call share the same creation time.
        created_at = datetime.now().isoformat()
        relations = [
            {
                'tag_code': tag_code,
                'stock_code': stock_code,
                'create_time': created_at,
            }
            for stock_code in unique_codes
            if stock_code not in already_linked
        ]

        if relations:
            collection.insert_many(relations)
            logger.info(f"保存标签关联关系: tag_code={tag_code}, 新增股票数量={len(relations)}")

        return True

    except Exception as e:
        logger.error(f"保存标签关联关系失败: {str(e)}", exc_info=True)
        return False
|
||||
|
||||
def get_tag_stocks(self, tag_code: str) -> List[str]:
    """Return the stock codes stored as relations of the given tag.

    Args:
        tag_code: Code of the tag whose relations are queried.

    Returns:
        The non-empty ``stock_code`` values of all matching relation
        documents; an empty list when the query raises.
    """
    try:
        matches = self.tag_stock_relations_collection.find({'tag_code': tag_code})
        codes = []
        for relation in matches:
            code = relation.get('stock_code')
            # Relations without a usable stock code are ignored.
            if code:
                codes.append(code)
        return codes
    except Exception as exc:
        logger.error(f"获取标签关联股票失败: {str(exc)}")
        return []
|
||||
|
||||
def update_tag_status(self, tag_code: str, status: str, progress: Optional[float] = None):
    """Write the processing status (and optionally progress) of a tag.

    Args:
        tag_code: Code of the tag to update.
        status: New processing status (pending/processing/completed/failed).
        progress: Processing progress (0-100); the stored value is only
            touched when this is not None.

    Errors are logged and swallowed; nothing is returned.
    """
    try:
        changes = {
            'status': status,
            'update_time': datetime.now().isoformat(),
            # Only include progress when the caller supplied one.
            **({} if progress is None else {'progress': progress}),
        }
        selector = {'tag_code': tag_code}
        self.tags_collection.update_one(selector, {'$set': changes})
        logger.info(f"更新标签状态: tag_code={tag_code}, status={status}, progress={progress}")
    except Exception as exc:
        logger.error(f"更新标签状态失败: {str(exc)}")
|
||||
|
||||
def delete_tag(self, tag_code: str, tag_name: str, delete_analysis: bool = False) -> Dict[str, Any]:
    """Remove a tag together with the data that hangs off it.

    The steps run in this order: tag/stock relations, the tag document
    itself, and (only when requested) the analysis results stored under the
    tag's name.

    Args:
        tag_code: Code of the tag to delete.
        tag_name: Name of the tag (used to select analysis results).
        delete_analysis: Whether analysis results are deleted too.

    Returns:
        A dict with the keys ``success``, ``deleted_tag``,
        ``deleted_relations``, ``deleted_analyses`` and ``error``. When a
        step raises, ``success`` stays False, ``error`` holds the exception
        text and counters reflect the steps completed before the failure.
    """
    outcome = {
        'success': False,
        'deleted_tag': False,
        'deleted_relations': 0,
        'deleted_analyses': 0,
        'error': None
    }
    by_code = {'tag_code': tag_code}

    try:
        # Step 1: relations between the tag and stocks.
        removed_relations = self.tag_stock_relations_collection.delete_many(by_code).deleted_count
        outcome['deleted_relations'] = removed_relations
        logger.info(f"删除标签关联关系: tag_code={tag_code}, 删除数量={removed_relations}")

        # Step 2: the tag document itself.
        tag_removed = self.tags_collection.delete_one({'tag_code': tag_code}).deleted_count > 0
        outcome['deleted_tag'] = tag_removed
        logger.info(f"删除标签: tag_code={tag_code}, tag_name={tag_name}, 是否删除={tag_removed}")

        # Step 3 (optional): analysis results keyed by tag name.
        if delete_analysis:
            removed_analyses = self.collection.delete_many({'tag_name': tag_name}).deleted_count
            outcome['deleted_analyses'] = removed_analyses
            logger.info(f"删除分析结果: tag_name={tag_name}, 删除数量={removed_analyses}")

        outcome['success'] = True
    except Exception as exc:
        logger.error(f"删除标签失败: {str(exc)}", exc_info=True)
        outcome['error'] = str(exc)

    return outcome
|
||||
|
||||
def delete_tag_by_name(self, tag_name: str, delete_analysis: bool = False) -> Dict[str, Any]:
    """Resolve a tag by name and delete it along with its related data.

    The tag is looked up with ``get_tag_by_name``; the actual deletion is
    delegated to ``delete_tag`` and its result is merged into the returned
    dict.

    Args:
        tag_name: Name of the tag to delete.
        delete_analysis: Whether analysis results are deleted too.

    Returns:
        A dict with the keys ``success``, ``tag_code`` (set once the tag was
        found), ``deleted_tag``, ``deleted_relations``, ``deleted_analyses``
        and ``error``. When the tag does not exist or an exception is
        raised, ``success`` is False and ``error`` describes the problem.
    """
    summary = {
        'success': False,
        'tag_code': None,
        'deleted_tag': False,
        'deleted_relations': 0,
        'deleted_analyses': 0,
        'error': None
    }

    try:
        tag_doc = self.get_tag_by_name(tag_name)

        if tag_doc:
            code = tag_doc.get('tag_code')
            summary['tag_code'] = code
            # Delegate to the code-based deletion and fold its result in.
            summary.update(self.delete_tag(code, tag_name, delete_analysis))
        else:
            summary['error'] = f"标签不存在: {tag_name}"
            logger.warning(summary['error'])
    except Exception as exc:
        logger.error(f"根据标签名称删除标签失败: {str(exc)}", exc_info=True)
        summary['error'] = str(exc)

    return summary
|
||||
|
||||
def close_connection(self):
    """Close the MongoDB client held by this instance, if one is set.

    Failures while closing are logged and swallowed.
    """
    try:
        if not self.mongo_client:
            return
        self.mongo_client.close()
        logger.info("MongoDB连接已关闭")
    except Exception as exc:
        logger.error(f"关闭MongoDB连接失败: {str(exc)}")
|
||||
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
"""
|
||||
标签关联分析服务类
|
||||
整合分析器和数据库操作,提供完整的分析+存储功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
# 处理相对导入和绝对导入
|
||||
try:
|
||||
from .tag_relation_analyzer import TagRelationAnalyzer
|
||||
from .tag_relation_database import TagRelationDatabase
|
||||
except ImportError:
|
||||
# 如果相对导入失败,使用绝对导入
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from stock_tag_analysis.tag_relation_analyzer import TagRelationAnalyzer
|
||||
from stock_tag_analysis.tag_relation_database import TagRelationDatabase
|
||||
|
||||
# 设置日志记录
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TagRelationService:
    """Service facade that runs tag-relation analyses and stores their results.

    It owns one ``TagRelationAnalyzer`` (exposed as ``analyzer``) and one
    ``TagRelationDatabase`` (exposed as ``database``) and forwards work to
    them.
    """

    def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True, collection_name: Optional[str] = None):
        """Create the analyzer first, then the database wrapper.

        Args:
            model_type: Large-model type handed to the analyzer
                (default ``ds3_2``).
            enable_web_search: Whether the analyzer may use web search.
            collection_name: MongoDB collection name handed to the database
                wrapper; None is passed through unchanged.
        """
        self.analyzer = TagRelationAnalyzer(model_type=model_type, enable_web_search=enable_web_search)
        self.database = TagRelationDatabase(collection_name=collection_name)
        logger.info("TagRelationService 初始化完成")

    def analyze_and_save(
        self,
        stock_code: str,
        stock_name: str,
        tag_name: str,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        save_to_db: bool = True
    ) -> Dict[str, Any]:
        """Analyze how one stock relates to a tag and optionally store the result.

        Args:
            stock_code: Stock code (e.g. 300750.SZ).
            stock_name: Stock name.
            tag_name: Tag name.
            temperature: Temperature parameter forwarded to the analyzer.
            max_tokens: Maximum token count forwarded to the analyzer.
            save_to_db: Whether a successful analysis is written to the
                database.

        Returns:
            The analyzer's result dict with an added ``saved_to_db`` entry.
            If anything raises, a failure dict with ``success`` False, the
            inputs, the error text and ``saved_to_db`` False is returned.
        """
        try:
            logger.info(f"开始分析 {stock_name}({stock_code}) 与标签 {tag_name} 的关联性")
            outcome = self.analyzer.analyze_tag_relation(
                stock_code=stock_code,
                stock_name=stock_name,
                tag_name=tag_name,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Caller opted out of persistence.
            if not save_to_db:
                outcome['saved_to_db'] = False
                logger.info(f"跳过数据库保存: {stock_code} - {tag_name}")
                return outcome

            # Unsuccessful analyses are never written to the database.
            if not outcome.get('success'):
                outcome['saved_to_db'] = False
                return outcome

            persisted = self.database.save_analysis_result(outcome)
            outcome['saved_to_db'] = persisted
            if persisted:
                logger.info(f"分析结果已保存到数据库: {stock_code} - {tag_name}")
            else:
                logger.warning(f"分析结果保存失败: {stock_code} - {tag_name}")
            return outcome

        except Exception as exc:
            logger.error(f"分析并保存过程中出错: {str(exc)}", exc_info=True)
            failure = {
                "success": False,
                "stock_code": stock_code,
                "stock_name": stock_name,
                "tag_name": tag_name,
                "error": str(exc),
                "saved_to_db": False
            }
            return failure

    def batch_analyze_and_save(
        self,
        stock_list: List[Dict[str, str]],
        tag_name: str,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        save_to_db: bool = True
    ) -> List[Dict[str, Any]]:
        """Run ``analyze_and_save`` for every valid entry of a stock list.

        Args:
            stock_list: Entries with ``stock_code`` and ``stock_name``;
                entries missing either value are skipped with a warning.
            tag_name: Tag name.
            temperature: Temperature parameter forwarded per analysis.
            max_tokens: Maximum token count forwarded per analysis.
            save_to_db: Whether successful analyses are stored.

        Returns:
            One result dict per processed (non-skipped) entry, in input
            order.
        """
        total = len(stock_list)
        results = []

        for idx, stock in enumerate(stock_list, start=1):
            stock_code = stock.get("stock_code")
            stock_name = stock.get("stock_name")

            if stock_code and stock_name:
                logger.info(f"处理进度: {idx}/{total} - {stock_name}({stock_code})")
                results.append(
                    self.analyze_and_save(
                        stock_code=stock_code,
                        stock_name=stock_name,
                        tag_name=tag_name,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        save_to_db=save_to_db
                    )
                )
            else:
                logger.warning(f"跳过无效股票信息: {stock}")

        # Summary counters for the completion log line.
        success_count = len([item for item in results if item.get('success')])
        saved_count = len([item for item in results if item.get('saved_to_db')])
        logger.info(f"批量分析完成 - 总数: {len(results)}, 成功: {success_count}, 已保存: {saved_count}")

        return results

    def get_analysis_result(self, stock_code: str, tag_name: str) -> Optional[Dict[str, Any]]:
        """Return the database wrapper's stored result for a stock/tag pair.

        Args:
            stock_code: Stock code.
            tag_name: Tag name.
        """
        stored = self.database.get_analysis_result(stock_code, tag_name)
        return stored

    def get_stock_analyses(self, stock_code: str) -> list:
        """Return the database wrapper's analysis results for one stock.

        Args:
            stock_code: Stock code.
        """
        analyses = self.database.get_stock_analyses(stock_code)
        return analyses

    def get_tag_analyses(self, tag_name: str) -> list:
        """Return the database wrapper's analysis results for one tag.

        Args:
            tag_name: Tag name.
        """
        analyses = self.database.get_tag_analyses(tag_name)
        return analyses

    def close(self):
        """Close the underlying database connection."""
        self.database.close_connection()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run one analysis end-to-end and print the key fields.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    service = TagRelationService()
    try:
        # Example: analyze how 航天发展 (000547.SZ) relates to the "商业航天"
        # tag and store the result.
        result = service.analyze_and_save(
            stock_code="000547.SZ",
            stock_name="航天发展",
            tag_name="商业航天"
        )

        print("\n" + "="*50)
        print("分析结果:")
        print("="*50)
        print(f"成功: {result.get('success')}")
        print(f"已保存: {result.get('saved_to_db')}")
        print(f"关联度评分: {result.get('relation_score')}")
        print(f"结论: {result.get('conclusion')}")
    finally:
        # Always release the database connection, even when the run above
        # fails or is interrupted.
        service.close()
|
||||
|
||||
|
|
@ -468,7 +468,7 @@ if __name__ == "__main__":
|
|||
|
||||
# 4. 采集指定日期范围的数据
|
||||
collect_chip_distribution(db_url, tushare_token, mode='full',
|
||||
start_date='2021-01-03', end_date='2021-09-30')
|
||||
start_date='2025-11-28', end_date='2025-12-09')
|
||||
|
||||
# 5. 调整批量入库大小(默认100只股票一批)
|
||||
# collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200)
|
||||
|
|
|
|||
|
|
@ -43,6 +43,27 @@ MONGO_CONFIG2 = {
|
|||
'collection': 'wind_financial_analysis'
|
||||
}
|
||||
|
||||
# --- MySQL settings for the stock tag-relation analysis module.
# SECURITY NOTE: the literal defaults below (including the password) are kept
# only for backward compatibility. Supply real credentials through the
# TAG_ANALYSIS_DB_* environment variables and rotate/remove these defaults.
from urllib.parse import quote_plus

DB_CONFIG_TAG_ANALYSIS = {
    'host': os.environ.get('TAG_ANALYSIS_DB_HOST', '192.168.18.199'),
    'port': int(os.environ.get('TAG_ANALYSIS_DB_PORT', '3306')),
    'user': os.environ.get('TAG_ANALYSIS_DB_USER', 'root'),
    'password': os.environ.get('TAG_ANALYSIS_DB_PASSWORD', 'Chlry#$.8'),
    'database': os.environ.get('TAG_ANALYSIS_DB_NAME', 'db_gp_cj')
}

# Database connection URL for the tag-relation analysis module.
# User and password are percent-encoded so that reserved characters such as
# '#', '$', '@' or '/' cannot corrupt the URL structure.
DB_URL_TAG_ANALYSIS = (
    f"mysql+pymysql://{quote_plus(DB_CONFIG_TAG_ANALYSIS['user'])}"
    f":{quote_plus(DB_CONFIG_TAG_ANALYSIS['password'])}"
    f"@{DB_CONFIG_TAG_ANALYSIS['host']}:{DB_CONFIG_TAG_ANALYSIS['port']}"
    f"/{DB_CONFIG_TAG_ANALYSIS['database']}"
)
|
||||
|
||||
# MongoDB settings for the stock tag-relation analysis module.
# SECURITY NOTE: the literal defaults below (including the password) are kept
# only for backward compatibility. Supply real credentials through the
# TAG_ANALYSIS_MONGO_* environment variables and rotate/remove these defaults.
MONGO_CONFIG222 = {
    'host': os.environ.get('TAG_ANALYSIS_MONGO_HOST', '192.168.16.222'),
    'port': int(os.environ.get('TAG_ANALYSIS_MONGO_PORT', '27017')),
    'db': os.environ.get('TAG_ANALYSIS_MONGO_DB', 'stock_predictions'),
    'username': os.environ.get('TAG_ANALYSIS_MONGO_USER', 'stockvix'),
    'password': os.environ.get('TAG_ANALYSIS_MONGO_PASSWORD', '!stock@vix2'),
    'collection': 'tag_relation_analysis'  # default collection name
}
|
||||
|
||||
# Project root directory: obtained by applying dirname three times to this
# file's absolute path, wrapped in a Path object.
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue