Add: tag analysis logic.

liao 2026-01-09 09:48:50 +08:00
parent 16f3efe3ae
commit 3aa280090e
9 changed files with 2538 additions and 2 deletions


@@ -123,6 +123,15 @@ fear_greed_manager = FearGreedIndexManager()
# 创建指数分析器实例
index_analyzer = IndexAnalyzer()
# 创建标签关联分析API实例
try:
from src.stock_tag_analysis.tag_relation_api import TagRelationAPI
tag_relation_api = TagRelationAPI()
logger.info("标签关联分析API初始化成功")
except Exception as e:
logger.error(f"标签关联分析API初始化失败: {str(e)}")
tag_relation_api = None
# 获取项目根目录
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports')
@@ -3887,6 +3896,240 @@ def run_tech_fundamental_strategy_batch():
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/tag/process', methods=['POST'])
def process_tag_relation():
"""处理标签关联分析接口
请求体参数:
- tag_name: 标签名称必填
- tag_type: 标签类型可选默认为None会自动判断
- model_type: 大模型类型可选默认为ds3_2
- enable_web_search: 是否启用联网搜索可选默认为True
- temperature: 大模型温度参数可选默认为0.7
- max_tokens: 最大token数可选默认为4096
返回:
{
"status": "success",
"data": {
"success": true,
"tag_code": "123",
"tag_type": "AI标签",
"source": "ai",
"stock_count": 0,
"message": "AI分析任务已加入队列共5000只股票待处理",
"error": null,
"status": "processing"
}
}
"""
try:
# 检查API是否初始化成功
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
# 获取参数
data = request.get_json() if request.is_json else request.form
tag_name = data.get('tag_name')
tag_type = data.get('tag_type')
model_type = data.get('model_type', 'ds3_2')
enable_web_search = data.get('enable_web_search', True)
# Normalize string booleans from form submissions, mirroring the delete endpoint below
if isinstance(enable_web_search, str):
    enable_web_search = enable_web_search.lower() in ('true', '1', 'yes')
temperature = float(data.get('temperature', 0.7))
max_tokens = int(data.get('max_tokens', 4096))
# 验证必要参数
if not tag_name:
return jsonify({
"status": "error",
"message": "缺少必要参数: tag_name"
}), 400
# 调用处理方法
result = tag_relation_api.process_tag(
tag_name=tag_name,
tag_type=tag_type,
model_type=model_type,
enable_web_search=enable_web_search,
temperature=temperature,
max_tokens=max_tokens
)
# 返回结果
if result.get('success'):
return jsonify({
"status": "success",
"data": result
})
else:
return jsonify({
"status": "error",
"message": result.get('message', '处理失败'),
"data": result
}), 500
except Exception as e:
logger.error(f"处理标签关联分析失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"处理标签关联分析失败: {str(e)}"
}), 500
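For reference, a minimal client-side sketch of calling this endpoint. The base URL is a placeholder and the requests package is assumed; parameter names and defaults follow the docstring above.

import requests

BASE_URL = "http://127.0.0.1:5000"  # placeholder; replace with the actual deployment address

resp = requests.post(
    f"{BASE_URL}/api/tag/process",
    json={
        "tag_name": "商业航天",     # required
        "model_type": "ds3_2",      # optional, default ds3_2
        "enable_web_search": True,  # optional, default True
        "temperature": 0.7,
        "max_tokens": 4096,
    },
    timeout=30,
)
print(resp.status_code, resp.json())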
@app.route('/api/tag/progress', methods=['GET'])
def get_tag_progress():
"""获取标签处理进度接口
参数:
- tag_name: 标签名称必填
返回:
{
"status": "success",
"data": {
"tag_code": "123",
"tag_name": "商业航天",
"status": "processing",
"progress": 45.5,
"total_stocks": 5000,
"processed": 2275,
"remaining": 2725,
"queue_length": 5000
}
}
"""
try:
# 检查API是否初始化成功
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
# 获取参数
tag_name = request.args.get('tag_name')
# 验证参数
if not tag_name:
return jsonify({
"status": "error",
"message": "缺少必要参数: tag_name"
}), 400
# 调用获取进度方法
result = tag_relation_api.get_tag_progress(tag_name=tag_name)
# 检查是否有错误
if 'error' in result:
return jsonify({
"status": "error",
"message": result.get('error', '获取进度失败')
}), 404
# 返回结果
return jsonify({
"status": "success",
"data": result
})
except Exception as e:
logger.error(f"获取标签进度失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"获取标签进度失败: {str(e)}"
}), 500
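Likewise, a small sketch for polling the progress endpoint (same placeholder base URL as above):

import requests

BASE_URL = "http://127.0.0.1:5000"  # placeholder deployment address

resp = requests.get(
    f"{BASE_URL}/api/tag/progress",
    params={"tag_name": "商业航天"},
    timeout=10,
)
body = resp.json()
if resp.status_code == 200:
    print(body["data"]["progress"], body["data"]["status"])
else:
    print(body["message"])  # 404 when the tag is unknown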
@app.route('/api/tag/delete', methods=['POST', 'DELETE'])
def delete_tag_relation():
"""删除标签及其相关数据接口
请求参数:
- tag_name: 标签名称必填
- delete_analysis: 是否删除分析结果可选默认为False只删除标签和关联关系
返回:
{
"status": "success",
"data": {
"success": true,
"tag_code": "123",
"deleted_tag": true,
"deleted_relations": 50,
"deleted_analyses": 0,
"deleted_redis_keys": 2,
"message": "删除成功: 标签=true, 关联关系=50条, Redis键=2个",
"error": null
}
}
"""
try:
# 检查API是否初始化成功
if tag_relation_api is None:
return jsonify({
"status": "error",
"message": "标签关联分析API未初始化"
}), 500
# 获取参数
if request.is_json:
data = request.get_json()
else:
data = request.form
tag_name = data.get('tag_name')
delete_analysis = data.get('delete_analysis', False)
# 处理delete_analysis参数可能是字符串"true"/"false"
if isinstance(delete_analysis, str):
delete_analysis = delete_analysis.lower() in ('true', '1', 'yes')
else:
delete_analysis = bool(delete_analysis)
# 验证必要参数
if not tag_name:
return jsonify({
"status": "error",
"message": "缺少必要参数: tag_name"
}), 400
# 调用删除方法
result = tag_relation_api.delete_tag_by_name(
tag_name=tag_name,
delete_analysis=delete_analysis
)
# 返回结果
if result.get('success'):
return jsonify({
"status": "success",
"data": result
})
else:
# 如果是标签不存在返回404
if "不存在" in result.get('message', ''):
return jsonify({
"status": "error",
"message": result.get('message', '标签不存在'),
"data": result
}), 404
else:
return jsonify({
"status": "error",
"message": result.get('message', '删除失败'),
"data": result
}), 500
except Exception as e:
logger.error(f"删除标签失败: {str(e)}", exc_info=True)
return jsonify({
"status": "error",
"message": f"删除标签失败: {str(e)}"
}), 500
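And a sketch for the delete endpoint; delete_analysis may be sent as a real boolean in JSON or as a string such as "true" in form data, since the handler above normalizes both:

import requests

BASE_URL = "http://127.0.0.1:5000"  # placeholder deployment address

resp = requests.post(
    f"{BASE_URL}/api/tag/delete",
    json={"tag_name": "商业航天", "delete_analysis": False},
    timeout=10,
)
print(resp.status_code, resp.json())  # 404 if the tag does not exist, 500 on other failures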
if __name__ == '__main__':
# 启动Web服务器


@@ -63,7 +63,8 @@ MODEL_CONFIGS = {
"models": {
    "offline_model": "ep-20250326090920-v7wns",
    "online_bot": "bot-20250325102825-h9kpq",
    "doubao": "doubao-1-5-pro-32k-250115",
    "ds3_2": "ep-20260104113811-m2h2f"
}
},
# 谷歌Gemini # 谷歌Gemini
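The new ds3_2 entry is resolved through the existing config helpers that ChatBot (below) already imports; a minimal sketch, assuming get_model() looks entries up in MODEL_CONFIGS by key:

from scripts.config import get_random_api_key, get_model

api_key = get_random_api_key()   # rotates over the configured API keys
model_id = get_model("ds3_2")    # expected to resolve to "ep-20260104113811-m2h2f"
print(model_id)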


@@ -0,0 +1,276 @@
import logging
import os
import sys
import time
from typing import Dict, Any
from datetime import datetime
from openai import OpenAI
# 设置日志记录
logger = logging.getLogger(__name__)
# 获取项目根目录
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 导入配置
sys.path.append(os.path.dirname(ROOT_DIR))
sys.path.append(ROOT_DIR)
try:
from scripts.config import get_random_api_key, get_model
except ImportError:
from src.scripts.config import get_random_api_key, get_model
class ChatBot:
def __init__(self, model_type: str = "online_bot", enable_web_search: bool = True):
"""初始化聊天机器人
Args:
model_type: 模型类型默认为 online_bot
enable_web_search: 是否启用联网搜索默认True
"""
self.api_key = get_random_api_key()
self.model = get_model(model_type)
self.enable_web_search = enable_web_search
logger.info(f"初始化ChatBot模型: {self.model}, 联网搜索: {self.enable_web_search}")
# 初始化客户端
self.client = OpenAI(
base_url="https://ark.cn-beijing.volces.com/api/v3",
api_key=self.api_key
)
# 联网搜索工具配置
self.tools = [{
"type": "web_search",
"limit": 10,
"sources": ["douyin", "moji", "toutiao"],
"user_location": {
"type": "approximate",
"country": "中国",
"region": "北京",
"city": "北京"
}
}] if enable_web_search else None
# 系统提示语
self.system_message = """你是一个专业的股票分析助手,擅长进行深入的基本面分析。你的分析应该:
1. 专业严谨 - 使用准确的专业术语,引用可靠的数据来源,分析逻辑清晰,结论有理有据
2. 全面细致 - 深入分析问题的各个方面,关注细节和关键信息,考虑多个影响因素
3. 客观中立 - 保持独立判断,不夸大或贬低,平衡利弊分析,指出潜在风险
4. 实用性强 - 分析结论具体明确,建议具有可操作性,关注实际投资价值
5. 及时更新 - 关注最新信息,反映市场变化,动态调整分析,保持信息时效性
请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。"""
# 对话历史(简化版,只保存用户和助手消息)
self.conversation_history = []
def format_reference(self, ref):
"""格式化参考资料"""
if isinstance(ref, str):
return ref
elif isinstance(ref, dict):
parts = []
if ref.get('title'):
parts.append(f"标题:{ref['title']}")
if ref.get('summary'):
parts.append(f"摘要:{ref['summary']}")
if ref.get('url'):
parts.append(f"链接:{ref['url']}")
if ref.get('publish_time'):
parts.append(f"发布时间:{ref['publish_time']}")
return "\n".join(parts)
return str(ref)
def chat(self, user_input: str, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]:
"""与AI进行对话
Args:
user_input: 用户输入的问题
temperature: 控制输出的随机性范围0-2默认0.7
top_p: 控制输出的多样性范围0-1默认0.7
max_tokens: 控制输出的最大长度默认4096
frequency_penalty: 控制重复惩罚范围-2到2默认0.0
Returns:
Dict[str, Any]: 包含 response, reasoning_process, references, tool_usage, tool_usage_details
"""
try:
# 添加用户消息到对话历史
self.conversation_history.append({
"role": "user",
"content": [{"type": "input_text", "text": user_input}]
})
# 构建输入消息
input_messages = [
# 系统提示词
{
"role": "system",
"content": [{"type": "input_text", "text": self.system_message}]
}
]
# 添加最近的对话历史保留最近5轮
for msg in self.conversation_history[-5:]:
input_messages.append(msg)
# 调用API流式输出
response = self.client.responses.create(
model=self.model,
input=input_messages,
tools=self.tools,
stream=True
)
# 收集回复
full_response = ""
reasoning_text = ""
references = []
tool_usage = None
tool_usage_details = None
search_keywords = []
# 状态变量
thinking_started = False
answering_started = False
print("\nAI: ", end="", flush=True)
# 处理事件流
for chunk in response:
chunk_type = getattr(chunk, "type", "")
# 处理AI思考过程
if chunk_type == "response.reasoning_summary_text.delta":
if not thinking_started:
thinking_started = True
delta = getattr(chunk, "delta", "")
if delta:
reasoning_text += delta
# 处理搜索状态
elif "web_search_call" in chunk_type:
if "in_progress" in chunk_type:
logger.debug("开始搜索")
elif "completed" in chunk_type:
logger.debug("搜索完成")
# 处理搜索关键词
elif (chunk_type == "response.output_item.done"
and hasattr(chunk, "item")
and str(getattr(chunk.item, "id", "")).startswith("ws_")):
if hasattr(chunk.item, "action") and hasattr(chunk.item.action, "query"):
search_keyword = chunk.item.action.query
search_keywords.append(search_keyword)
logger.debug(f"搜索关键词: {search_keyword}")
# 处理最终回答
elif chunk_type == "response.output_text.delta":
if not answering_started:
answering_started = True
delta = getattr(chunk, "delta", "")
if delta:
print(delta, end="", flush=True)
full_response += delta
time.sleep(0.01)
# 处理完成事件,提取工具使用统计
elif chunk_type == "response.completed":
if hasattr(chunk, "response") and hasattr(chunk.response, "usage"):
usage = chunk.response.usage
if hasattr(usage, "tool_usage"):
tool_usage = usage.tool_usage
if hasattr(usage, "tool_usage_details"):
tool_usage_details = usage.tool_usage_details
print() # 换行
# 提取推理过程和参考资料
main_response = full_response
reasoning_process = reasoning_text
if "推理过程:" in full_response:
parts = full_response.split("推理过程:")
main_response = parts[0].strip()
if len(parts) > 1:
reasoning_parts = parts[1].split("参考资料:")
reasoning_process = reasoning_parts[0].strip() if reasoning_parts else reasoning_text
# 添加AI回复到对话历史
self.conversation_history.append({
"role": "assistant",
"content": [{"type": "input_text", "text": main_response}]
})
return {
"response": main_response,
"reasoning_process": reasoning_process,
"references": references,
"tool_usage": tool_usage,
"tool_usage_details": tool_usage_details,
"search_keywords": search_keywords
}
except Exception as e:
logger.error(f"对话失败: {str(e)}", exc_info=True)
return {
"response": f"抱歉,处理您的请求时出现错误: {str(e)}",
"reasoning_process": "",
"references": [],
"tool_usage": None,
"tool_usage_details": None,
"search_keywords": []
}
def clear_history(self):
"""清除对话历史"""
self.conversation_history = []
print("对话历史已清除")
def run(self):
"""运行聊天机器人"""
print("欢迎使用AI助手输入 'quit' 退出,输入 'clear' 清除对话历史。")
if self.enable_web_search:
print("该版本支持联网搜索,可以回答实时信息。")
print("-" * 50)
while True:
try:
user_input = input("\n你: ").strip()
if user_input.lower() == 'quit':
print("感谢使用,再见!")
break
if user_input.lower() == 'clear':
self.clear_history()
continue
if not user_input:
continue
self.chat(user_input)
print("-" * 50)
except KeyboardInterrupt:
print("\n感谢使用,再见!")
break
except Exception as e:
logger.error(f"运行错误: {str(e)}")
print(f"发生错误: {str(e)}")
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
bot = ChatBot(enable_web_search=True)
bot.run()
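Besides the interactive run() loop, chat() can be used programmatically; a minimal sketch (a valid API key from scripts.config is required, and the import path mirrors the one used elsewhere in this commit):

from src.stock_tag_analysis.chat_bot import ChatBot

bot = ChatBot(model_type="ds3_2", enable_web_search=True)
result = bot.chat("简要介绍宁德时代的主营业务")
print(result["response"])          # final answer text
print(result["search_keywords"])   # web-search queries issued during the call
bot.clear_history()                # reset context before an unrelated question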


@@ -0,0 +1,323 @@
"""
个股标签关联分析器
使用联网大模型分析个股与标签的关联性并输出结构化结果
"""
import logging
import json
import re
from typing import Dict, Any, Optional, List
from datetime import datetime
# 导入联网大模型工具
import sys
import os
try:
from .chat_bot import ChatBot
except ImportError:
try:
from src.stock_tag_analysis.chat_bot import ChatBot
except ImportError:
# 如果相对导入失败,添加路径并使用绝对导入
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
from chat_bot import ChatBot
# 设置日志记录
logger = logging.getLogger(__name__)
class TagRelationAnalyzer:
"""个股标签关联分析器"""
def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True):
"""初始化分析器
Args:
model_type: 大模型类型,默认为 ds3_2(联网智能体)
enable_web_search: 是否启用联网搜索,默认True
"""
self.chat_bot = ChatBot(model_type=model_type, enable_web_search=enable_web_search)
logger.info(f"TagRelationAnalyzer 初始化完成,模型: {model_type}, 联网搜索: {enable_web_search}")
def _build_analysis_prompt(self, stock_code: str, stock_name: str, tag_name: str) -> str:
"""构建分析提示词
Args:
stock_code: 股票代码
stock_name: 股票名称
tag_name: 标签名称
Returns:
str: 分析提示词
"""
prompt = f"""请分析股票"{stock_name}"(股票代码:{stock_code})与标签"{tag_name}"是否存在关联性。
请通过联网搜索获取最新的信息,包括但不限于:
1. 公司主营业务是否与标签相关
2. 公司产品服务是否涉及标签领域
3. 公司是否在标签相关行业有布局
4. 公司公告新闻中是否提及标签相关内容
5. 公司财务数据是否显示与标签相关的业务占比
请按照以下JSON格式输出分析结果:
{{
    "has_relation": true/false, // 是否存在关联性
    "relation_score": 0-100, // 关联度评分(0-100),数值越高关联度越高
    "relation_type": "直接关联/间接关联/无关联", // 关联类型
    "analysis_summary": "简要分析总结(100字以内)",
    "detailed_analysis": "详细分析内容(包括关联依据、业务占比、市场表现等)",
    "key_evidence": [ // 关键证据列表
        {{
            "evidence_type": "主营业务/产品服务/行业布局/公告新闻/财务数据",
            "description": "证据描述",
            "relevance": "高/中/低"
        }}
    ],
    "business_ratio": "与标签相关的业务占比(如:30%、主要业务、次要业务等,如无相关信息则填写'未知')",
    "market_performance": "市场表现相关描述(如:该标签概念股表现、行业地位等,如无相关信息则填写'未知')",
    "conclusion": "最终结论(50字以内)"
}}
请确保输出的是有效的JSON格式,不要包含任何其他文字说明。"""
return prompt
def _parse_json_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""从响应文本中解析JSON
Args:
response_text: 响应文本
Returns:
Optional[Dict[str, Any]]: 解析后的JSON字典如果解析失败返回None
"""
try:
# 尝试直接解析JSON
# 查找JSON代码块
json_match = re.search(r'```json\s*(\{.*?\})\s*```', response_text, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# 查找大括号包裹的JSON
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
json_str = json_match.group(0)
else:
# 尝试直接解析整个响应
json_str = response_text.strip()
# 清理可能的Markdown标记
json_str = json_str.strip()
if json_str.startswith('```'):
json_str = re.sub(r'^```[a-z]*\n?', '', json_str)
if json_str.endswith('```'):
json_str = re.sub(r'\n?```$', '', json_str)
# 解析JSON
result = json.loads(json_str)
return result
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {str(e)}")
logger.error(f"响应文本: {response_text[:500]}")
return None
except Exception as e:
logger.error(f"解析响应时出错: {str(e)}")
return None
def analyze_tag_relation(
self,
stock_code: str,
stock_name: str,
tag_name: str,
temperature: float = 0.7,
max_tokens: int = 4096
) -> Dict[str, Any]:
"""分析个股与标签的关联性
Args:
stock_code: 股票代码,如:300750.SZ
stock_name: 股票名称,如:宁德时代
tag_name: 标签名称,如:医药、新能源、金融科技等
temperature: 大模型温度参数,默认0.7
max_tokens: 最大token数,默认4096
Returns:
Dict[str, Any]: 分析结果包含以下字段
- success: 是否成功
- stock_code: 股票代码
- stock_name: 股票名称
- tag_name: 标签名称
- has_relation: 是否存在关联性
- relation_score: 关联度评分0-100
- relation_type: 关联类型
- analysis_summary: 简要分析总结
- detailed_analysis: 详细分析内容
- key_evidence: 关键证据列表
- business_ratio: 业务占比
- market_performance: 市场表现
- conclusion: 最终结论
- raw_response: 原始响应
- reasoning_process: 推理过程
- references: 参考资料
- analysis_time: 分析时间
- error: 错误信息如果失败
"""
try:
logger.info(f"开始分析 {stock_name}({stock_code}) 与标签 {tag_name} 的关联性")
# 构建分析提示词
prompt = self._build_analysis_prompt(stock_code, stock_name, tag_name)
# 调用大模型进行分析
result = self.chat_bot.chat(
user_input=prompt,
temperature=temperature,
max_tokens=max_tokens
)
# 解析响应
response_text = result.get("response", "")
reasoning_process = result.get("reasoning_process", "")
references = result.get("references", [])
# 解析JSON结果
parsed_result = self._parse_json_response(response_text)
if parsed_result is None:
# 如果解析失败,返回原始响应
logger.warning("JSON解析失败返回原始响应")
return {
"success": False,
"stock_code": stock_code,
"stock_name": stock_name,
"tag_name": tag_name,
"has_relation": None,
"relation_score": None,
"relation_type": None,
"analysis_summary": None,
"detailed_analysis": response_text,
"key_evidence": [],
"business_ratio": None,
"market_performance": None,
"conclusion": None,
"raw_response": response_text,
"reasoning_process": reasoning_process,
"references": references,
"analysis_time": datetime.now().isoformat(),
"error": "JSON解析失败请检查响应格式"
}
# 构建完整结果
analysis_result = {
"success": True,
"stock_code": stock_code,
"stock_name": stock_name,
"tag_name": tag_name,
"has_relation": parsed_result.get("has_relation", False),
"relation_score": parsed_result.get("relation_score", 0),
"relation_type": parsed_result.get("relation_type", "无关联"),
"analysis_summary": parsed_result.get("analysis_summary", ""),
"detailed_analysis": parsed_result.get("detailed_analysis", ""),
"key_evidence": parsed_result.get("key_evidence", []),
"business_ratio": parsed_result.get("business_ratio", "未知"),
"market_performance": parsed_result.get("market_performance", "未知"),
"conclusion": parsed_result.get("conclusion", ""),
"raw_response": response_text,
"reasoning_process": reasoning_process,
"references": references,
"analysis_time": datetime.now().isoformat(),
"error": None
}
logger.info(f"分析完成: {stock_name}({stock_code}) 与 {tag_name} 的关联度评分为 {analysis_result['relation_score']}")
return analysis_result
except Exception as e:
logger.error(f"分析过程中出错: {str(e)}", exc_info=True)
return {
"success": False,
"stock_code": stock_code,
"stock_name": stock_name,
"tag_name": tag_name,
"has_relation": None,
"relation_score": None,
"relation_type": None,
"analysis_summary": None,
"detailed_analysis": None,
"key_evidence": [],
"business_ratio": None,
"market_performance": None,
"conclusion": None,
"raw_response": None,
"reasoning_process": None,
"references": [],
"analysis_time": datetime.now().isoformat(),
"error": str(e)
}
def batch_analyze_tag_relation(
self,
stock_list: List[Dict[str, str]],
tag_name: str,
temperature: float = 0.7,
max_tokens: int = 4096
) -> List[Dict[str, Any]]:
"""批量分析多只股票与标签的关联性
Args:
stock_list: 股票列表每个元素包含 stock_code stock_name
tag_name: 标签名称
temperature: 大模型温度参数
max_tokens: 最大token数
Returns:
List[Dict[str, Any]]: 分析结果列表
"""
results = []
for stock in stock_list:
stock_code = stock.get("stock_code")
stock_name = stock.get("stock_name")
if not stock_code or not stock_name:
logger.warning(f"跳过无效股票信息: {stock}")
continue
result = self.analyze_tag_relation(
stock_code=stock_code,
stock_name=stock_name,
tag_name=tag_name,
temperature=temperature,
max_tokens=max_tokens
)
results.append(result)
# 清除对话历史,避免上下文干扰
self.chat_bot.clear_history()
return results
if __name__ == "__main__":
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 测试示例
analyzer = TagRelationAnalyzer()
# 测试:分析宁德时代与医药标签的关联性
result = analyzer.analyze_tag_relation(
stock_code="300750.SZ",
stock_name="宁德时代",
tag_name="医药"
)
print("\n" + "="*50)
print("分析结果:")
print("="*50)
print(json.dumps(result, ensure_ascii=False, indent=2))
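A batch sketch for completeness; batch_analyze_tag_relation() clears the conversation history between stocks, so each analysis starts from a clean context. A reachable LLM endpoint and API key are assumed.

analyzer = TagRelationAnalyzer(model_type="ds3_2", enable_web_search=True)
stocks = [
    {"stock_code": "300750.SZ", "stock_name": "宁德时代"},
    {"stock_code": "000547.SZ", "stock_name": "航天发展"},
]
results = analyzer.batch_analyze_tag_relation(stocks, tag_name="商业航天")
for r in results:
    print(r["stock_code"], r.get("relation_score"), r.get("relation_type"))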

File diff suppressed because it is too large.


@@ -0,0 +1,448 @@
"""
标签关联分析数据库操作模块
用于连接 MongoDB 并存储分析结果
"""
import logging
import pymongo
from typing import Dict, Any, Optional, List
from datetime import datetime
# 导入配置
import sys
import os
# 添加项目路径
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_DIR)
try:
from valuation_analysis.config import MONGO_CONFIG222
except ImportError:
try:
from src.valuation_analysis.config import MONGO_CONFIG222
except ImportError:
# 尝试从当前目录导入
sys.path.append(os.path.dirname(ROOT_DIR))
from src.valuation_analysis.config import MONGO_CONFIG222
# 设置日志记录
logger = logging.getLogger(__name__)
class TagRelationDatabase:
"""标签关联分析数据库操作类"""
def __init__(self, collection_name: Optional[str] = None):
"""初始化数据库连接
Args:
collection_name: 集合名称如果不指定则使用配置中的默认集合名
"""
self.mongo_client = None
self.db = None
self.collection_name = collection_name or MONGO_CONFIG222.get('collection', 'tag_relation_analysis')
self.collection = None
# 标签集合和关联集合
self.tags_collection = None
self.tag_stock_relations_collection = None
self.connect_mongodb()
def connect_mongodb(self):
"""连接MongoDB数据库"""
try:
self.mongo_client = pymongo.MongoClient(
host=MONGO_CONFIG222['host'],
port=MONGO_CONFIG222['port'],
username=MONGO_CONFIG222['username'],
password=MONGO_CONFIG222['password'],
authSource='admin' # 认证数据库
)
self.db = self.mongo_client[MONGO_CONFIG222['db']]
self.collection = self.db[self.collection_name]
# 初始化标签集合和关联集合
self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合
self.tag_stock_relations_collection = self.db['tag_stock_relations']
# 测试连接
self.mongo_client.admin.command('ping')
logger.info(f"MongoDB连接成功数据库: {MONGO_CONFIG222['db']}, 集合: {self.collection_name}")
except Exception as e:
logger.error(f"MongoDB连接失败: {str(e)}")
raise
def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool:
"""保存分析结果到MongoDB
Args:
analysis_result: 分析结果字典包含所有分析字段
Returns:
bool: 是否保存成功
"""
try:
if not analysis_result:
logger.warning("分析结果为空,无法保存")
return False
# 添加保存时间
analysis_result['save_time'] = datetime.now().isoformat()
# 使用 stock_code 和 tag_name 作为唯一标识
filter_condition = {
'stock_code': analysis_result.get('stock_code'),
'tag_name': analysis_result.get('tag_name')
}
# 检查是否已存在
existing_record = self.collection.find_one(filter_condition)
if existing_record:
# 更新现有记录
update_result = self.collection.update_one(
filter_condition,
{'$set': analysis_result}
)
logger.info(f"更新分析结果: {analysis_result.get('stock_code')} - {analysis_result.get('tag_name')}")
return update_result.modified_count > 0
else:
# 插入新记录
insert_result = self.collection.insert_one(analysis_result)
logger.info(f"插入新分析结果: {analysis_result.get('stock_code')} - {analysis_result.get('tag_name')}, ID: {insert_result.inserted_id}")
return insert_result.inserted_id is not None
except Exception as e:
logger.error(f"保存分析结果失败: {str(e)}", exc_info=True)
return False
def get_analysis_result(self, stock_code: str, tag_name: str) -> Optional[Dict[str, Any]]:
"""获取指定股票和标签的分析结果
Args:
stock_code: 股票代码
tag_name: 标签名称
Returns:
Optional[Dict[str, Any]]: 分析结果如果不存在返回None
"""
try:
result = self.collection.find_one({
'stock_code': stock_code,
'tag_name': tag_name
})
if result:
# 移除 MongoDB 的 _id 字段(如果需要)
result.pop('_id', None)
return result
except Exception as e:
logger.error(f"获取分析结果失败: {str(e)}")
return None
def get_stock_analyses(self, stock_code: str) -> list:
"""获取指定股票的所有标签分析结果
Args:
stock_code: 股票代码
Returns:
list: 分析结果列表
"""
try:
results = list(self.collection.find({'stock_code': stock_code}))
# 移除 MongoDB 的 _id 字段
for result in results:
result.pop('_id', None)
return results
except Exception as e:
logger.error(f"获取股票分析结果失败: {str(e)}")
return []
def get_tag_analyses(self, tag_name: str) -> list:
"""获取指定标签的所有股票分析结果
Args:
tag_name: 标签名称
Returns:
list: 分析结果列表
"""
try:
results = list(self.collection.find({'tag_name': tag_name}))
# 移除 MongoDB 的 _id 字段
for result in results:
result.pop('_id', None)
return results
except Exception as e:
logger.error(f"获取标签分析结果失败: {str(e)}")
return []
def save_tag(self, tag_data: Dict[str, Any]) -> Optional[str]:
"""保存标签到tags集合
Args:
tag_data: 标签数据字典包含tag_id, tag_name, tag_type等
Returns:
Optional[str]: 标签IDtag_code如果保存失败返回None
"""
try:
# 检查标签是否已存在根据tag_name
existing_tag = self.tags_collection.find_one({'tag_name': tag_data.get('tag_name')})
if existing_tag:
# 更新现有标签
tag_code = existing_tag.get('tag_code')
self.tags_collection.update_one(
{'tag_name': tag_data.get('tag_name')},
{'$set': tag_data}
)
logger.info(f"更新标签: {tag_data.get('tag_name')}, tag_code: {tag_code}")
return tag_code
else:
# 插入新标签
# 如果没有tag_code生成一个
if 'tag_code' not in tag_data or not tag_data.get('tag_code'):
# 生成tag_code使用当前最大ID+1或者使用时间戳
max_tag = self.tags_collection.find_one(sort=[("tag_code", -1)])
if max_tag and max_tag.get('tag_code'):
try:
max_id = int(max_tag['tag_code'])
tag_code = str(max_id + 1)
except (ValueError, TypeError):
tag_code = str(int(datetime.now().timestamp() * 1000))
else:
tag_code = "1"
tag_data['tag_code'] = tag_code
tag_data['create_time'] = datetime.now().isoformat()
result = self.tags_collection.insert_one(tag_data)
logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}")
return tag_data.get('tag_code')
except Exception as e:
logger.error(f"保存标签失败: {str(e)}", exc_info=True)
return None
def get_tag_by_name(self, tag_name: str) -> Optional[Dict[str, Any]]:
"""根据标签名称获取标签信息
Args:
tag_name: 标签名称
Returns:
Optional[Dict[str, Any]]: 标签信息如果不存在返回None
"""
try:
tag = self.tags_collection.find_one({'tag_name': tag_name})
if tag:
tag.pop('_id', None)
return tag
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
return None
def save_tag_stock_relations(self, tag_code: str, stock_codes: List[str], replace_all: bool = False) -> bool:
"""保存标签与股票的关联关系
Args:
tag_code: 标签代码
stock_codes: 股票代码列表
replace_all: 是否替换所有关联关系True=全量替换False=增量添加默认False
Returns:
bool: 是否保存成功
"""
try:
if not stock_codes:
logger.warning("股票代码列表为空")
return False
# 如果replace_all为True先删除该标签的所有旧关联关系
if replace_all:
self.tag_stock_relations_collection.delete_many({'tag_code': tag_code})
# 批量插入新关联关系使用upsert避免重复
relations = []
for stock_code in stock_codes:
# 检查是否已存在
existing = self.tag_stock_relations_collection.find_one({
'tag_code': tag_code,
'stock_code': stock_code
})
if not existing:
relations.append({
'tag_code': tag_code,
'stock_code': stock_code,
'create_time': datetime.now().isoformat()
})
if relations:
self.tag_stock_relations_collection.insert_many(relations)
logger.info(f"保存标签关联关系: tag_code={tag_code}, 新增股票数量={len(relations)}")
return True
except Exception as e:
logger.error(f"保存标签关联关系失败: {str(e)}", exc_info=True)
return False
def get_tag_stocks(self, tag_code: str) -> List[str]:
"""获取标签关联的所有股票代码
Args:
tag_code: 标签代码
Returns:
List[str]: 股票代码列表
"""
try:
relations = list(self.tag_stock_relations_collection.find({'tag_code': tag_code}))
stock_codes = [r.get('stock_code') for r in relations if r.get('stock_code')]
return stock_codes
except Exception as e:
logger.error(f"获取标签关联股票失败: {str(e)}")
return []
def update_tag_status(self, tag_code: str, status: str, progress: Optional[float] = None):
"""更新标签处理状态和进度
Args:
tag_code: 标签代码
status: 处理状态pending/processing/completed/failed
progress: 处理进度0-100可选
"""
try:
update_data = {
'status': status,
'update_time': datetime.now().isoformat()
}
if progress is not None:
update_data['progress'] = progress
self.tags_collection.update_one(
{'tag_code': tag_code},
{'$set': update_data}
)
logger.info(f"更新标签状态: tag_code={tag_code}, status={status}, progress={progress}")
except Exception as e:
logger.error(f"更新标签状态失败: {str(e)}")
def delete_tag(self, tag_code: str, tag_name: str, delete_analysis: bool = False) -> Dict[str, Any]:
"""删除标签及其相关数据
Args:
tag_code: 标签代码
tag_name: 标签名称用于删除分析结果
delete_analysis: 是否删除分析结果默认False
Returns:
Dict[str, Any]: 删除结果
- success: 是否成功
- deleted_tag: 是否删除了标签
- deleted_relations: 删除的关联关系数量
- deleted_analyses: 删除的分析结果数量如果delete_analysis=True
- error: 错误信息
"""
result = {
'success': False,
'deleted_tag': False,
'deleted_relations': 0,
'deleted_analyses': 0,
'error': None
}
try:
# 1. 删除标签与股票的关联关系(中间表)
relations_result = self.tag_stock_relations_collection.delete_many({'tag_code': tag_code})
result['deleted_relations'] = relations_result.deleted_count
logger.info(f"删除标签关联关系: tag_code={tag_code}, 删除数量={relations_result.deleted_count}")
# 2. 删除标签本身
tag_result = self.tags_collection.delete_one({'tag_code': tag_code})
result['deleted_tag'] = tag_result.deleted_count > 0
logger.info(f"删除标签: tag_code={tag_code}, tag_name={tag_name}, 是否删除={result['deleted_tag']}")
# 3. 可选:删除分析结果
if delete_analysis:
analyses_result = self.collection.delete_many({'tag_name': tag_name})
result['deleted_analyses'] = analyses_result.deleted_count
logger.info(f"删除分析结果: tag_name={tag_name}, 删除数量={analyses_result.deleted_count}")
result['success'] = True
return result
except Exception as e:
logger.error(f"删除标签失败: {str(e)}", exc_info=True)
result['error'] = str(e)
return result
def delete_tag_by_name(self, tag_name: str, delete_analysis: bool = False) -> Dict[str, Any]:
"""根据标签名称删除标签及其相关数据
Args:
tag_name: 标签名称
delete_analysis: 是否删除分析结果默认False
Returns:
Dict[str, Any]: 删除结果
- success: 是否成功
- tag_code: 标签代码如果找到
- deleted_tag: 是否删除了标签
- deleted_relations: 删除的关联关系数量
- deleted_analyses: 删除的分析结果数量如果delete_analysis=True
- error: 错误信息
"""
result = {
'success': False,
'tag_code': None,
'deleted_tag': False,
'deleted_relations': 0,
'deleted_analyses': 0,
'error': None
}
try:
# 先查找标签
tag = self.get_tag_by_name(tag_name)
if not tag:
result['error'] = f"标签不存在: {tag_name}"
logger.warning(result['error'])
return result
tag_code = tag.get('tag_code')
result['tag_code'] = tag_code
# 调用根据tag_code删除的方法
delete_result = self.delete_tag(tag_code, tag_name, delete_analysis)
result.update(delete_result)
return result
except Exception as e:
logger.error(f"根据标签名称删除标签失败: {str(e)}", exc_info=True)
result['error'] = str(e)
return result
def close_connection(self):
"""关闭数据库连接"""
try:
if self.mongo_client:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
except Exception as e:
logger.error(f"关闭MongoDB连接失败: {str(e)}")


@@ -0,0 +1,222 @@
"""
标签关联分析服务类
整合分析器和数据库操作提供完整的分析+存储功能
"""
import logging
import sys
import os
from typing import Dict, Any, Optional, List
# 处理相对导入和绝对导入
try:
from .tag_relation_analyzer import TagRelationAnalyzer
from .tag_relation_database import TagRelationDatabase
except ImportError:
# 如果相对导入失败,使用绝对导入
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from stock_tag_analysis.tag_relation_analyzer import TagRelationAnalyzer
from stock_tag_analysis.tag_relation_database import TagRelationDatabase
# 设置日志记录
logger = logging.getLogger(__name__)
class TagRelationService:
"""标签关联分析服务类"""
def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True, collection_name: Optional[str] = None):
"""初始化服务
Args:
model_type: 大模型类型默认为 ds3_2
enable_web_search: 是否启用联网搜索默认True
collection_name: MongoDB集合名称如果不指定则使用默认集合名
"""
# 初始化分析器
self.analyzer = TagRelationAnalyzer(
model_type=model_type,
enable_web_search=enable_web_search
)
# 初始化数据库
self.database = TagRelationDatabase(collection_name=collection_name)
logger.info("TagRelationService 初始化完成")
def analyze_and_save(
self,
stock_code: str,
stock_name: str,
tag_name: str,
temperature: float = 0.7,
max_tokens: int = 4096,
save_to_db: bool = True
) -> Dict[str, Any]:
"""分析个股与标签的关联性并保存到数据库
Args:
stock_code: 股票代码300750.SZ
stock_name: 股票名称宁德时代
tag_name: 标签名称医药新能源金融科技等
temperature: 大模型温度参数默认0.7
max_tokens: 最大token数默认4096
save_to_db: 是否保存到数据库默认True
Returns:
Dict[str, Any]: 分析结果包含所有分析字段和保存状态
"""
try:
# 执行分析
logger.info(f"开始分析 {stock_name}({stock_code}) 与标签 {tag_name} 的关联性")
analysis_result = self.analyzer.analyze_tag_relation(
stock_code=stock_code,
stock_name=stock_name,
tag_name=tag_name,
temperature=temperature,
max_tokens=max_tokens
)
# 保存到数据库
if save_to_db and analysis_result.get('success'):
save_success = self.database.save_analysis_result(analysis_result)
analysis_result['saved_to_db'] = save_success
if save_success:
logger.info(f"分析结果已保存到数据库: {stock_code} - {tag_name}")
else:
logger.warning(f"分析结果保存失败: {stock_code} - {tag_name}")
else:
analysis_result['saved_to_db'] = False
if not save_to_db:
logger.info(f"跳过数据库保存: {stock_code} - {tag_name}")
return analysis_result
except Exception as e:
logger.error(f"分析并保存过程中出错: {str(e)}", exc_info=True)
return {
"success": False,
"stock_code": stock_code,
"stock_name": stock_name,
"tag_name": tag_name,
"error": str(e),
"saved_to_db": False
}
def batch_analyze_and_save(
self,
stock_list: List[Dict[str, str]],
tag_name: str,
temperature: float = 0.7,
max_tokens: int = 4096,
save_to_db: bool = True
) -> List[Dict[str, Any]]:
"""批量分析多只股票与标签的关联性并保存到数据库
Args:
stock_list: 股票列表每个元素包含 stock_code stock_name
tag_name: 标签名称
temperature: 大模型温度参数
max_tokens: 最大token数
save_to_db: 是否保存到数据库默认True
Returns:
List[Dict[str, Any]]: 分析结果列表
"""
results = []
total = len(stock_list)
for idx, stock in enumerate(stock_list, 1):
stock_code = stock.get("stock_code")
stock_name = stock.get("stock_name")
if not stock_code or not stock_name:
logger.warning(f"跳过无效股票信息: {stock}")
continue
logger.info(f"处理进度: {idx}/{total} - {stock_name}({stock_code})")
result = self.analyze_and_save(
stock_code=stock_code,
stock_name=stock_name,
tag_name=tag_name,
temperature=temperature,
max_tokens=max_tokens,
save_to_db=save_to_db
)
results.append(result)
# 统计结果
success_count = sum(1 for r in results if r.get('success'))
saved_count = sum(1 for r in results if r.get('saved_to_db'))
logger.info(f"批量分析完成 - 总数: {len(results)}, 成功: {success_count}, 已保存: {saved_count}")
return results
def get_analysis_result(self, stock_code: str, tag_name: str) -> Optional[Dict[str, Any]]:
"""从数据库获取分析结果
Args:
stock_code: 股票代码
tag_name: 标签名称
Returns:
Optional[Dict[str, Any]]: 分析结果如果不存在返回None
"""
return self.database.get_analysis_result(stock_code, tag_name)
def get_stock_analyses(self, stock_code: str) -> list:
"""获取指定股票的所有标签分析结果
Args:
stock_code: 股票代码
Returns:
list: 分析结果列表
"""
return self.database.get_stock_analyses(stock_code)
def get_tag_analyses(self, tag_name: str) -> list:
"""获取指定标签的所有股票分析结果
Args:
tag_name: 标签名称
Returns:
list: 分析结果列表
"""
return self.database.get_tag_analyses(tag_name)
def close(self):
"""关闭数据库连接"""
self.database.close_connection()
if __name__ == "__main__":
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 测试示例
service = TagRelationService()
# 测试:分析宁德时代与医药标签的关联性并保存
result = service.analyze_and_save(
stock_code="000547.SZ",
stock_name="航天发展",
tag_name="商业航天"
)
print("\n" + "="*50)
print("分析结果:")
print("="*50)
print(f"成功: {result.get('success')}")
print(f"已保存: {result.get('saved_to_db')}")
print(f"关联度评分: {result.get('relation_score')}")
print(f"结论: {result.get('conclusion')}")
# 关闭连接
service.close()
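For batches, the service wraps the same analyze-then-save flow per stock; a short sketch with the same connectivity assumptions as the example above:

service = TagRelationService(model_type="ds3_2", enable_web_search=True)
stocks = [
    {"stock_code": "000547.SZ", "stock_name": "航天发展"},
    {"stock_code": "300750.SZ", "stock_name": "宁德时代"},
]
results = service.batch_analyze_and_save(stocks, tag_name="商业航天", save_to_db=True)
print(sum(1 for r in results if r.get("success")), "of", len(results), "analyses succeeded")
service.close()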


@@ -468,7 +468,7 @@ if __name__ == "__main__":
# 4. 采集指定日期范围的数据
collect_chip_distribution(db_url, tushare_token, mode='full',
                          start_date='2025-11-28', end_date='2025-12-09')
# 5. 调整批量入库大小(默认100只股票一批)
# collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200)


@@ -43,6 +43,27 @@ MONGO_CONFIG2 = {
'collection': 'wind_financial_analysis'
}
# --- MySQL配置(用于个股标签关联分析模块)
DB_CONFIG_TAG_ANALYSIS = {
'host': '192.168.18.199',
'port': 3306,
'user': 'root',
'password': 'Chlry#$.8',
'database': 'db_gp_cj'
}
# 创建数据库连接URL(用于个股标签关联分析模块)
DB_URL_TAG_ANALYSIS = f"mysql+pymysql://{DB_CONFIG_TAG_ANALYSIS['user']}:{DB_CONFIG_TAG_ANALYSIS['password']}@{DB_CONFIG_TAG_ANALYSIS['host']}:{DB_CONFIG_TAG_ANALYSIS['port']}/{DB_CONFIG_TAG_ANALYSIS['database']}"
# MongoDB配置(用于个股标签关联分析模块)
MONGO_CONFIG222 = {
'host': '192.168.16.222',
'port': 27017,
'db': 'stock_predictions',
'username': 'stockvix',
'password': '!stock@vix2',
'collection': 'tag_relation_analysis' # 默认集合名
}
# 项目根目录
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
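A minimal sketch of how these new settings are consumed elsewhere in this commit. Building a SQLAlchemy engine from DB_URL_TAG_ANALYSIS is an assumption based on the mysql+pymysql URL format; the MongoDB path matches TagRelationDatabase above.

from sqlalchemy import create_engine
from src.valuation_analysis.config import DB_URL_TAG_ANALYSIS, MONGO_CONFIG222
from src.stock_tag_analysis.tag_relation_database import TagRelationDatabase

engine = create_engine(DB_URL_TAG_ANALYSIS)   # assumed consumer: stock lists for tag analysis
db = TagRelationDatabase()                    # defaults to MONGO_CONFIG222['collection']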