diff --git a/src/app.py b/src/app.py index ce4a7de..c1a3eaf 100644 --- a/src/app.py +++ b/src/app.py @@ -123,6 +123,15 @@ fear_greed_manager = FearGreedIndexManager() # 创建指数分析器实例 index_analyzer = IndexAnalyzer() +# 创建标签关联分析API实例 +try: + from src.stock_tag_analysis.tag_relation_api import TagRelationAPI + tag_relation_api = TagRelationAPI() + logger.info("标签关联分析API初始化成功") +except Exception as e: + logger.error(f"标签关联分析API初始化失败: {str(e)}") + tag_relation_api = None + # 获取项目根目录 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) REPORTS_DIR = os.path.join(ROOT_DIR, 'src', 'reports') @@ -3887,6 +3896,240 @@ def run_tech_fundamental_strategy_batch(): return jsonify({"status": "error", "message": str(e)}), 500 +@app.route('/api/tag/process', methods=['POST']) +def process_tag_relation(): + """处理标签关联分析接口 + + 请求体参数: + - tag_name: 标签名称(必填) + - tag_type: 标签类型(可选,默认为None,会自动判断) + - model_type: 大模型类型(可选,默认为ds3_2) + - enable_web_search: 是否启用联网搜索(可选,默认为True) + - temperature: 大模型温度参数(可选,默认为0.7) + - max_tokens: 最大token数(可选,默认为4096) + + 返回: + { + "status": "success", + "data": { + "success": true, + "tag_code": "123", + "tag_type": "AI标签", + "source": "ai", + "stock_count": 0, + "message": "AI分析任务已加入队列,共5000只股票待处理", + "error": null, + "status": "processing" + } + } + """ + try: + # 检查API是否初始化成功 + if tag_relation_api is None: + return jsonify({ + "status": "error", + "message": "标签关联分析API未初始化" + }), 500 + + # 获取参数 + data = request.get_json() if request.is_json else request.form + tag_name = data.get('tag_name') + tag_type = data.get('tag_type') + model_type = data.get('model_type', 'ds3_2') + enable_web_search = data.get('enable_web_search', True) + temperature = float(data.get('temperature', 0.7)) + max_tokens = int(data.get('max_tokens', 4096)) + + # 验证必要参数 + if not tag_name: + return jsonify({ + "status": "error", + "message": "缺少必要参数: tag_name" + }), 400 + + # 调用处理方法 + result = tag_relation_api.process_tag( + tag_name=tag_name, + tag_type=tag_type, + model_type=model_type, + enable_web_search=enable_web_search, + temperature=temperature, + max_tokens=max_tokens + ) + + # 返回结果 + if result.get('success'): + return jsonify({ + "status": "success", + "data": result + }) + else: + return jsonify({ + "status": "error", + "message": result.get('message', '处理失败'), + "data": result + }), 500 + + except Exception as e: + logger.error(f"处理标签关联分析失败: {str(e)}", exc_info=True) + return jsonify({ + "status": "error", + "message": f"处理标签关联分析失败: {str(e)}" + }), 500 + + +@app.route('/api/tag/progress', methods=['GET']) +def get_tag_progress(): + """获取标签处理进度接口 + + 参数: + - tag_name: 标签名称(必填) + + 返回: + { + "status": "success", + "data": { + "tag_code": "123", + "tag_name": "商业航天", + "status": "processing", + "progress": 45.5, + "total_stocks": 5000, + "processed": 2275, + "remaining": 2725, + "queue_length": 5000 + } + } + """ + try: + # 检查API是否初始化成功 + if tag_relation_api is None: + return jsonify({ + "status": "error", + "message": "标签关联分析API未初始化" + }), 500 + + # 获取参数 + tag_name = request.args.get('tag_name') + + # 验证参数 + if not tag_name: + return jsonify({ + "status": "error", + "message": "缺少必要参数: tag_name" + }), 400 + + # 调用获取进度方法 + result = tag_relation_api.get_tag_progress(tag_name=tag_name) + + # 检查是否有错误 + if 'error' in result: + return jsonify({ + "status": "error", + "message": result.get('error', '获取进度失败') + }), 404 + + # 返回结果 + return jsonify({ + "status": "success", + "data": result + }) + + except Exception as e: + logger.error(f"获取标签进度失败: {str(e)}", exc_info=True) + return jsonify({ + "status": "error", + "message": f"获取标签进度失败: {str(e)}" + }), 500 + + +@app.route('/api/tag/delete', methods=['POST', 'DELETE']) +def delete_tag_relation(): + """删除标签及其相关数据接口 + + 请求参数: + - tag_name: 标签名称(必填) + - delete_analysis: 是否删除分析结果(可选,默认为False,只删除标签和关联关系) + + 返回: + { + "status": "success", + "data": { + "success": true, + "tag_code": "123", + "deleted_tag": true, + "deleted_relations": 50, + "deleted_analyses": 0, + "deleted_redis_keys": 2, + "message": "删除成功: 标签=true, 关联关系=50条, Redis键=2个", + "error": null + } + } + """ + try: + # 检查API是否初始化成功 + if tag_relation_api is None: + return jsonify({ + "status": "error", + "message": "标签关联分析API未初始化" + }), 500 + + # 获取参数 + if request.is_json: + data = request.get_json() + else: + data = request.form + + tag_name = data.get('tag_name') + delete_analysis = data.get('delete_analysis', False) + + # 处理delete_analysis参数(可能是字符串"true"/"false") + if isinstance(delete_analysis, str): + delete_analysis = delete_analysis.lower() in ('true', '1', 'yes') + else: + delete_analysis = bool(delete_analysis) + + # 验证必要参数 + if not tag_name: + return jsonify({ + "status": "error", + "message": "缺少必要参数: tag_name" + }), 400 + + # 调用删除方法 + result = tag_relation_api.delete_tag_by_name( + tag_name=tag_name, + delete_analysis=delete_analysis + ) + + # 返回结果 + if result.get('success'): + return jsonify({ + "status": "success", + "data": result + }) + else: + # 如果是标签不存在,返回404 + if "不存在" in result.get('message', ''): + return jsonify({ + "status": "error", + "message": result.get('message', '标签不存在'), + "data": result + }), 404 + else: + return jsonify({ + "status": "error", + "message": result.get('message', '删除失败'), + "data": result + }), 500 + + except Exception as e: + logger.error(f"删除标签失败: {str(e)}", exc_info=True) + return jsonify({ + "status": "error", + "message": f"删除标签失败: {str(e)}" + }), 500 + + if __name__ == '__main__': # 启动Web服务器 diff --git a/src/scripts/config.py b/src/scripts/config.py index c11aa72..a8091e8 100644 --- a/src/scripts/config.py +++ b/src/scripts/config.py @@ -63,7 +63,8 @@ MODEL_CONFIGS = { "models": { "offline_model": "ep-20250326090920-v7wns", "online_bot": "bot-20250325102825-h9kpq", - "doubao": "doubao-1-5-pro-32k-250115" + "doubao": "doubao-1-5-pro-32k-250115", + "ds3_2": "ep-20260104113811-m2h2f" } }, # 谷歌Gemini diff --git a/src/stock_tag_analysis/chat_bot.py b/src/stock_tag_analysis/chat_bot.py new file mode 100644 index 0000000..435401d --- /dev/null +++ b/src/stock_tag_analysis/chat_bot.py @@ -0,0 +1,276 @@ +import logging +import os +import sys +import time +from typing import Dict, Any +from datetime import datetime + +from openai import OpenAI + +# 设置日志记录 +logger = logging.getLogger(__name__) + +# 获取项目根目录 +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +# 导入配置 +sys.path.append(os.path.dirname(ROOT_DIR)) +sys.path.append(ROOT_DIR) + +try: + from scripts.config import get_random_api_key, get_model +except ImportError: + from src.scripts.config import get_random_api_key, get_model + + +class ChatBot: + def __init__(self, model_type: str = "online_bot", enable_web_search: bool = True): + """初始化聊天机器人 + + Args: + model_type: 模型类型,默认为 online_bot + enable_web_search: 是否启用联网搜索,默认True + """ + self.api_key = get_random_api_key() + self.model = get_model(model_type) + self.enable_web_search = enable_web_search + + logger.info(f"初始化ChatBot,模型: {self.model}, 联网搜索: {self.enable_web_search}") + + # 初始化客户端 + self.client = OpenAI( + base_url="https://ark.cn-beijing.volces.com/api/v3", + api_key=self.api_key + ) + + # 联网搜索工具配置 + self.tools = [{ + "type": "web_search", + "limit": 10, + "sources": ["douyin", "moji", "toutiao"], + "user_location": { + "type": "approximate", + "country": "中国", + "region": "北京", + "city": "北京" + } + }] if enable_web_search else None + + # 系统提示语 + self.system_message = """你是一个专业的股票分析助手,擅长进行深入的基本面分析。你的分析应该: + +1. 专业严谨 - 使用准确的专业术语,引用可靠的数据来源,分析逻辑清晰,结论有理有据 +2. 全面细致 - 深入分析问题的各个方面,关注细节和关键信息,考虑多个影响因素 +3. 客观中立 - 保持独立判断,不夸大或贬低,平衡利弊分析,指出潜在风险 +4. 实用性强 - 分析结论具体明确,建议具有可操作性,关注实际投资价值 +5. 及时更新 - 关注最新信息,反映市场变化,动态调整分析,保持信息时效性 + +请根据用户的具体需求,提供专业、深入的分析。如果遇到不确定的信息,请明确说明。""" + + # 对话历史(简化版,只保存用户和助手消息) + self.conversation_history = [] + + def format_reference(self, ref): + """格式化参考资料""" + if isinstance(ref, str): + return ref + elif isinstance(ref, dict): + parts = [] + if ref.get('title'): + parts.append(f"标题:{ref['title']}") + if ref.get('summary'): + parts.append(f"摘要:{ref['summary']}") + if ref.get('url'): + parts.append(f"链接:{ref['url']}") + if ref.get('publish_time'): + parts.append(f"发布时间:{ref['publish_time']}") + return "\n".join(parts) + return str(ref) + + def chat(self, user_input: str, temperature: float = 0.7, top_p: float = 0.7, max_tokens: int = 4096, frequency_penalty: float = 0.0) -> Dict[str, Any]: + """与AI进行对话 + + Args: + user_input: 用户输入的问题 + temperature: 控制输出的随机性,范围0-2,默认0.7 + top_p: 控制输出的多样性,范围0-1,默认0.7 + max_tokens: 控制输出的最大长度,默认4096 + frequency_penalty: 控制重复惩罚,范围-2到2,默认0.0 + + Returns: + Dict[str, Any]: 包含 response, reasoning_process, references, tool_usage, tool_usage_details + """ + try: + # 添加用户消息到对话历史 + self.conversation_history.append({ + "role": "user", + "content": [{"type": "input_text", "text": user_input}] + }) + + # 构建输入消息 + input_messages = [ + # 系统提示词 + { + "role": "system", + "content": [{"type": "input_text", "text": self.system_message}] + } + ] + + # 添加最近的对话历史(保留最近5轮) + for msg in self.conversation_history[-5:]: + input_messages.append(msg) + + # 调用API(流式输出) + response = self.client.responses.create( + model=self.model, + input=input_messages, + tools=self.tools, + stream=True + ) + + # 收集回复 + full_response = "" + reasoning_text = "" + references = [] + tool_usage = None + tool_usage_details = None + search_keywords = [] + + # 状态变量 + thinking_started = False + answering_started = False + + print("\nAI: ", end="", flush=True) + + # 处理事件流 + for chunk in response: + chunk_type = getattr(chunk, "type", "") + + # 处理AI思考过程 + if chunk_type == "response.reasoning_summary_text.delta": + if not thinking_started: + thinking_started = True + delta = getattr(chunk, "delta", "") + if delta: + reasoning_text += delta + + # 处理搜索状态 + elif "web_search_call" in chunk_type: + if "in_progress" in chunk_type: + logger.debug("开始搜索") + elif "completed" in chunk_type: + logger.debug("搜索完成") + + # 处理搜索关键词 + elif (chunk_type == "response.output_item.done" + and hasattr(chunk, "item") + and str(getattr(chunk.item, "id", "")).startswith("ws_")): + if hasattr(chunk.item, "action") and hasattr(chunk.item.action, "query"): + search_keyword = chunk.item.action.query + search_keywords.append(search_keyword) + logger.debug(f"搜索关键词: {search_keyword}") + + # 处理最终回答 + elif chunk_type == "response.output_text.delta": + if not answering_started: + answering_started = True + delta = getattr(chunk, "delta", "") + if delta: + print(delta, end="", flush=True) + full_response += delta + time.sleep(0.01) + + # 处理完成事件,提取工具使用统计 + elif chunk_type == "response.completed": + if hasattr(chunk, "response") and hasattr(chunk.response, "usage"): + usage = chunk.response.usage + if hasattr(usage, "tool_usage"): + tool_usage = usage.tool_usage + if hasattr(usage, "tool_usage_details"): + tool_usage_details = usage.tool_usage_details + + print() # 换行 + + # 提取推理过程和参考资料 + main_response = full_response + reasoning_process = reasoning_text + + if "推理过程:" in full_response: + parts = full_response.split("推理过程:") + main_response = parts[0].strip() + if len(parts) > 1: + reasoning_parts = parts[1].split("参考资料:") + reasoning_process = reasoning_parts[0].strip() if reasoning_parts else reasoning_text + + # 添加AI回复到对话历史 + self.conversation_history.append({ + "role": "assistant", + "content": [{"type": "input_text", "text": main_response}] + }) + + return { + "response": main_response, + "reasoning_process": reasoning_process, + "references": references, + "tool_usage": tool_usage, + "tool_usage_details": tool_usage_details, + "search_keywords": search_keywords + } + + except Exception as e: + logger.error(f"对话失败: {str(e)}", exc_info=True) + return { + "response": f"抱歉,处理您的请求时出现错误: {str(e)}", + "reasoning_process": "", + "references": [], + "tool_usage": None, + "tool_usage_details": None, + "search_keywords": [] + } + + def clear_history(self): + """清除对话历史""" + self.conversation_history = [] + print("对话历史已清除") + + def run(self): + """运行聊天机器人""" + print("欢迎使用AI助手!输入 'quit' 退出,输入 'clear' 清除对话历史。") + if self.enable_web_search: + print("该版本支持联网搜索,可以回答实时信息。") + print("-" * 50) + + while True: + try: + user_input = input("\n你: ").strip() + + if user_input.lower() == 'quit': + print("感谢使用,再见!") + break + + if user_input.lower() == 'clear': + self.clear_history() + continue + + if not user_input: + continue + + self.chat(user_input) + print("-" * 50) + + except KeyboardInterrupt: + print("\n感谢使用,再见!") + break + except Exception as e: + logger.error(f"运行错误: {str(e)}") + print(f"发生错误: {str(e)}") + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + bot = ChatBot(enable_web_search=True) + bot.run() diff --git a/src/stock_tag_analysis/tag_relation_analyzer.py b/src/stock_tag_analysis/tag_relation_analyzer.py new file mode 100644 index 0000000..cb742f4 --- /dev/null +++ b/src/stock_tag_analysis/tag_relation_analyzer.py @@ -0,0 +1,323 @@ +""" +个股标签关联分析器 +使用联网大模型分析个股与标签的关联性,并输出结构化结果 +""" + +import logging +import json +import re +from typing import Dict, Any, Optional, List +from datetime import datetime + +# 导入联网大模型工具 +import sys +import os + +try: + from .chat_bot import ChatBot +except ImportError: + try: + from src.stock_tag_analysis.chat_bot import ChatBot + except ImportError: + # 如果相对导入失败,添加路径并使用绝对导入 + current_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, current_dir) + from chat_bot import ChatBot + +# 设置日志记录 +logger = logging.getLogger(__name__) + + +class TagRelationAnalyzer: + """个股标签关联分析器""" + + def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True): + """初始化分析器 + + Args: + model_type: 大模型类型,默认为 ds3_2(联网智能体) + enable_web_search: 是否启用联网搜索,默认True + """ + self.chat_bot = ChatBot(model_type=model_type, enable_web_search=enable_web_search) + logger.info(f"TagRelationAnalyzer 初始化完成,模型: {model_type}, 联网搜索: {enable_web_search}") + + def _build_analysis_prompt(self, stock_code: str, stock_name: str, tag_name: str) -> str: + """构建分析提示词 + + Args: + stock_code: 股票代码 + stock_name: 股票名称 + tag_name: 标签名称 + + Returns: + str: 分析提示词 + """ + prompt = f"""请分析股票"{stock_name}"(股票代码:{stock_code})与标签"{tag_name}"是否存在关联性。 + +请通过联网搜索获取最新的信息,包括但不限于: +1. 公司主营业务是否与标签相关 +2. 公司产品、服务是否涉及标签领域 +3. 公司是否在标签相关行业有布局 +4. 公司公告、新闻中是否提及标签相关内容 +5. 公司财务数据是否显示与标签相关的业务占比 + +请按照以下JSON格式输出分析结果: +{{ + "has_relation": true/false, // 是否存在关联性 + "relation_score": 0-100, // 关联度评分(0-100,数值越高关联度越高) + "relation_type": "直接关联/间接关联/无关联", // 关联类型 + "analysis_summary": "简要分析总结(100字以内)", + "detailed_analysis": "详细分析内容(包括关联依据、业务占比、市场表现等)", + "key_evidence": [ // 关键证据列表 + {{ + "evidence_type": "主营业务/产品服务/行业布局/公告新闻/财务数据", + "description": "证据描述", + "relevance": "高/中/低" + }} + ], + "business_ratio": "与标签相关的业务占比(如:30%、主要业务、次要业务等,如无相关信息则填写'未知')", + "market_performance": "市场表现相关描述(如:该标签概念股表现、行业地位等,如无相关信息则填写'未知')", + "conclusion": "最终结论(50字以内)" +}} + +请确保输出的是有效的JSON格式,不要包含任何其他文字说明。""" + return prompt + + def _parse_json_response(self, response_text: str) -> Optional[Dict[str, Any]]: + """从响应文本中解析JSON + + Args: + response_text: 响应文本 + + Returns: + Optional[Dict[str, Any]]: 解析后的JSON字典,如果解析失败返回None + """ + try: + # 尝试直接解析JSON + # 查找JSON代码块 + json_match = re.search(r'```json\s*(\{.*?\})\s*```', response_text, re.DOTALL) + if json_match: + json_str = json_match.group(1) + else: + # 查找大括号包裹的JSON + json_match = re.search(r'\{.*\}', response_text, re.DOTALL) + if json_match: + json_str = json_match.group(0) + else: + # 尝试直接解析整个响应 + json_str = response_text.strip() + + # 清理可能的Markdown标记 + json_str = json_str.strip() + if json_str.startswith('```'): + json_str = re.sub(r'^```[a-z]*\n?', '', json_str) + if json_str.endswith('```'): + json_str = re.sub(r'\n?```$', '', json_str) + + # 解析JSON + result = json.loads(json_str) + return result + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {str(e)}") + logger.error(f"响应文本: {response_text[:500]}") + return None + except Exception as e: + logger.error(f"解析响应时出错: {str(e)}") + return None + + def analyze_tag_relation( + self, + stock_code: str, + stock_name: str, + tag_name: str, + temperature: float = 0.7, + max_tokens: int = 4096 + ) -> Dict[str, Any]: + """分析个股与标签的关联性 + + Args: + stock_code: 股票代码(如:300750.SZ) + stock_name: 股票名称(如:宁德时代) + tag_name: 标签名称(如:医药、新能源、金融科技等) + temperature: 大模型温度参数,默认0.7 + max_tokens: 最大token数,默认4096 + + Returns: + Dict[str, Any]: 分析结果,包含以下字段: + - success: 是否成功 + - stock_code: 股票代码 + - stock_name: 股票名称 + - tag_name: 标签名称 + - has_relation: 是否存在关联性 + - relation_score: 关联度评分(0-100) + - relation_type: 关联类型 + - analysis_summary: 简要分析总结 + - detailed_analysis: 详细分析内容 + - key_evidence: 关键证据列表 + - business_ratio: 业务占比 + - market_performance: 市场表现 + - conclusion: 最终结论 + - raw_response: 原始响应 + - reasoning_process: 推理过程 + - references: 参考资料 + - analysis_time: 分析时间 + - error: 错误信息(如果失败) + """ + try: + logger.info(f"开始分析 {stock_name}({stock_code}) 与标签 {tag_name} 的关联性") + + # 构建分析提示词 + prompt = self._build_analysis_prompt(stock_code, stock_name, tag_name) + + # 调用大模型进行分析 + result = self.chat_bot.chat( + user_input=prompt, + temperature=temperature, + max_tokens=max_tokens + ) + + # 解析响应 + response_text = result.get("response", "") + reasoning_process = result.get("reasoning_process", "") + references = result.get("references", []) + + # 解析JSON结果 + parsed_result = self._parse_json_response(response_text) + + if parsed_result is None: + # 如果解析失败,返回原始响应 + logger.warning("JSON解析失败,返回原始响应") + return { + "success": False, + "stock_code": stock_code, + "stock_name": stock_name, + "tag_name": tag_name, + "has_relation": None, + "relation_score": None, + "relation_type": None, + "analysis_summary": None, + "detailed_analysis": response_text, + "key_evidence": [], + "business_ratio": None, + "market_performance": None, + "conclusion": None, + "raw_response": response_text, + "reasoning_process": reasoning_process, + "references": references, + "analysis_time": datetime.now().isoformat(), + "error": "JSON解析失败,请检查响应格式" + } + + # 构建完整结果 + analysis_result = { + "success": True, + "stock_code": stock_code, + "stock_name": stock_name, + "tag_name": tag_name, + "has_relation": parsed_result.get("has_relation", False), + "relation_score": parsed_result.get("relation_score", 0), + "relation_type": parsed_result.get("relation_type", "无关联"), + "analysis_summary": parsed_result.get("analysis_summary", ""), + "detailed_analysis": parsed_result.get("detailed_analysis", ""), + "key_evidence": parsed_result.get("key_evidence", []), + "business_ratio": parsed_result.get("business_ratio", "未知"), + "market_performance": parsed_result.get("market_performance", "未知"), + "conclusion": parsed_result.get("conclusion", ""), + "raw_response": response_text, + "reasoning_process": reasoning_process, + "references": references, + "analysis_time": datetime.now().isoformat(), + "error": None + } + + logger.info(f"分析完成: {stock_name}({stock_code}) 与 {tag_name} 的关联度评分为 {analysis_result['relation_score']}") + + return analysis_result + + except Exception as e: + logger.error(f"分析过程中出错: {str(e)}", exc_info=True) + return { + "success": False, + "stock_code": stock_code, + "stock_name": stock_name, + "tag_name": tag_name, + "has_relation": None, + "relation_score": None, + "relation_type": None, + "analysis_summary": None, + "detailed_analysis": None, + "key_evidence": [], + "business_ratio": None, + "market_performance": None, + "conclusion": None, + "raw_response": None, + "reasoning_process": None, + "references": [], + "analysis_time": datetime.now().isoformat(), + "error": str(e) + } + + def batch_analyze_tag_relation( + self, + stock_list: List[Dict[str, str]], + tag_name: str, + temperature: float = 0.7, + max_tokens: int = 4096 + ) -> List[Dict[str, Any]]: + """批量分析多只股票与标签的关联性 + + Args: + stock_list: 股票列表,每个元素包含 stock_code 和 stock_name + tag_name: 标签名称 + temperature: 大模型温度参数 + max_tokens: 最大token数 + + Returns: + List[Dict[str, Any]]: 分析结果列表 + """ + results = [] + for stock in stock_list: + stock_code = stock.get("stock_code") + stock_name = stock.get("stock_name") + + if not stock_code or not stock_name: + logger.warning(f"跳过无效股票信息: {stock}") + continue + + result = self.analyze_tag_relation( + stock_code=stock_code, + stock_name=stock_name, + tag_name=tag_name, + temperature=temperature, + max_tokens=max_tokens + ) + results.append(result) + + # 清除对话历史,避免上下文干扰 + self.chat_bot.clear_history() + + return results + + +if __name__ == "__main__": + # 设置日志 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # 测试示例 + analyzer = TagRelationAnalyzer() + + # 测试:分析宁德时代与医药标签的关联性 + result = analyzer.analyze_tag_relation( + stock_code="300750.SZ", + stock_name="宁德时代", + tag_name="医药" + ) + + print("\n" + "="*50) + print("分析结果:") + print("="*50) + print(json.dumps(result, ensure_ascii=False, indent=2)) + diff --git a/src/stock_tag_analysis/tag_relation_api.py b/src/stock_tag_analysis/tag_relation_api.py new file mode 100644 index 0000000..d941cd7 --- /dev/null +++ b/src/stock_tag_analysis/tag_relation_api.py @@ -0,0 +1,1002 @@ +""" +标签关联分析API接口 +提供Web端调用入口 +""" + +import logging +import sys +import os +import json +import threading +import time +from typing import Dict, Any, List, Optional +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker +import redis + +# 导入配置 +try: + from valuation_analysis.config import DB_CONFIG_TAG_ANALYSIS, MONGO_CONFIG222, REDIS_CONFIG +except ImportError: + try: + from src.valuation_analysis.config import DB_CONFIG_TAG_ANALYSIS, MONGO_CONFIG222, REDIS_CONFIG + except ImportError: + # 添加项目路径 + ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + sys.path.append(ROOT_DIR) + from src.valuation_analysis.config import DB_CONFIG_TAG_ANALYSIS, MONGO_CONFIG222, REDIS_CONFIG + +# 导入服务类 +try: + from .tag_relation_service import TagRelationService +except ImportError: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from stock_tag_analysis.tag_relation_service import TagRelationService + +# 设置日志记录 +logger = logging.getLogger(__name__) + + +class TagRelationAPI: + """标签关联分析API类""" + + def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True): + """初始化API + + Args: + model_type: 大模型类型,默认为 ds3_2 + enable_web_search: 是否启用联网搜索,默认True + """ + # 初始化服务 + self.service = TagRelationService( + model_type=model_type, + enable_web_search=enable_web_search + ) + + # 初始化MySQL连接 + self.mysql_engine = None + self.connect_mysql() + + # 初始化Redis连接(用于任务队列) + self.redis_client = None + self.connect_redis() + + # 全局队列key + self.global_queue_key = "tag_analysis_queue:global" + + # 线程控制 + self.global_processing_thread = None + self.stop_processing = False + self.thread_lock = threading.Lock() + + # 启动全局处理线程 + self._start_global_processing_thread() + + logger.info("TagRelationAPI 初始化完成,全局处理线程已启动") + + def connect_mysql(self): + """连接MySQL数据库""" + try: + db_url = f"mysql+pymysql://{DB_CONFIG_TAG_ANALYSIS['user']}:{DB_CONFIG_TAG_ANALYSIS['password']}@{DB_CONFIG_TAG_ANALYSIS['host']}:{DB_CONFIG_TAG_ANALYSIS['port']}/{DB_CONFIG_TAG_ANALYSIS['database']}" + + self.mysql_engine = create_engine( + db_url, + pool_size=5, + max_overflow=10, + pool_recycle=3600, + echo=False + ) + + # 测试连接 + with self.mysql_engine.connect() as conn: + conn.execute(text("SELECT 1")) + + logger.info(f"MySQL数据库连接成功: {DB_CONFIG_TAG_ANALYSIS['host']}:{DB_CONFIG_TAG_ANALYSIS['port']}/{DB_CONFIG_TAG_ANALYSIS['database']}") + + except Exception as e: + logger.error(f"MySQL数据库连接失败: {str(e)}") + raise + + def connect_redis(self): + """连接Redis数据库""" + try: + self.redis_client = redis.Redis( + host=REDIS_CONFIG['host'], + port=REDIS_CONFIG['port'], + password=REDIS_CONFIG.get('password'), + db=REDIS_CONFIG.get('db', 13), + socket_timeout=REDIS_CONFIG.get('socket_timeout', 5), + decode_responses=True + ) + + # 测试连接 + self.redis_client.ping() + logger.info(f"Redis连接成功: {REDIS_CONFIG['host']}:{REDIS_CONFIG['port']}") + + except Exception as e: + logger.error(f"Redis连接失败: {str(e)}") + raise + + def analyze_tag_for_all_stocks( + self, + tag_name: str, + model_type: str = "ds3_2", + enable_web_search: bool = True, + temperature: float = 0.7, + max_tokens: int = 4096 + ) -> Dict[str, Any]: + """分析所有股票与指定标签的关联性并保存到MongoDB + + Args: + tag_name: 标签名称 + model_type: 大模型类型,默认ds3_2 + enable_web_search: 是否启用联网搜索,默认True + temperature: 大模型温度参数,默认0.7 + max_tokens: 最大token数,默认4096 + + Returns: + Dict[str, Any]: 包含处理结果的字典 + - success: 是否成功 + - total_stocks: 总股票数 + - processed_stocks: 已处理股票数 + - success_count: 成功分析数 + - failed_count: 失败数 + - saved_count: 已保存数 + - results: 详细结果列表 + - error: 错误信息 + """ + try: + logger.info(f"开始分析所有股票与标签 {tag_name} 的关联性") + + # 查询所有股票(排除黑名单) + stock_list = self._get_stock_list() + + if not stock_list: + logger.warning("未获取到股票列表") + return { + "success": False, + "total_stocks": 0, + "processed_stocks": 0, + "success_count": 0, + "failed_count": 0, + "saved_count": 0, + "results": [], + "error": "未获取到股票列表" + } + + logger.info(f"获取到 {len(stock_list)} 只股票,开始批量分析") + + # 批量分析并保存 + results = self.service.batch_analyze_and_save( + stock_list=stock_list, + tag_name=tag_name, + temperature=temperature, + max_tokens=max_tokens, + save_to_db=True + ) + + # 统计结果 + success_count = sum(1 for r in results if r.get('success')) + failed_count = sum(1 for r in results if not r.get('success')) + saved_count = sum(1 for r in results if r.get('saved_to_db')) + + logger.info(f"分析完成 - 总数: {len(results)}, 成功: {success_count}, 失败: {failed_count}, 已保存: {saved_count}") + + return { + "success": True, + "total_stocks": len(stock_list), + "processed_stocks": len(results), + "success_count": success_count, + "failed_count": failed_count, + "saved_count": saved_count, + "results": results, + "error": None + } + + except Exception as e: + logger.error(f"分析过程中出错: {str(e)}", exc_info=True) + return { + "success": False, + "total_stocks": 0, + "processed_stocks": 0, + "success_count": 0, + "failed_count": 0, + "saved_count": 0, + "results": [], + "error": str(e) + } + + def _get_stock_list(self) -> List[Dict[str, str]]: + """从MySQL数据库获取股票列表(排除黑名单、ST股票、BSE交易所) + + Returns: + List[Dict[str, str]]: 股票列表,每个元素包含 stock_code 和 stock_name + """ + try: + # SQL查询:获取gp_code_all表中的股票,排除黑名单、ST股票、BSE交易所 + sql = """ + SELECT + g.gp_code_two AS stock_code, + g.gp_name AS stock_name + FROM gp_code_all g + LEFT JOIN gp_blacklist b ON g.gp_code = b.stock_code + WHERE b.stock_code IS NULL + AND g.gp_code_two IS NOT NULL + AND g.gp_name IS NOT NULL + AND g.gp_name NOT LIKE '%ST%' + AND (g.exchange IS NULL OR g.exchange != 'BSE') + ORDER BY g.gp_code_two + """ + + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql)) + rows = result.fetchall() + + stock_list = [] + for row in rows: + stock_list.append({ + "stock_code": row[0], + "stock_name": row[1] + }) + + logger.info(f"从数据库获取到 {len(stock_list)} 只股票(已排除黑名单)") + return stock_list + + except Exception as e: + logger.error(f"获取股票列表失败: {str(e)}", exc_info=True) + return [] + + def get_stock_list_count(self) -> int: + """获取股票总数(排除黑名单、ST股票、BSE交易所) + + Returns: + int: 股票总数 + """ + try: + sql = """ + SELECT COUNT(*) + FROM gp_code_all g + LEFT JOIN gp_blacklist b ON g.gp_code = b.stock_code + WHERE b.stock_code IS NULL + AND g.gp_code_two IS NOT NULL + AND g.gp_name IS NOT NULL + AND g.gp_name NOT LIKE '%ST%' + AND (g.exchange IS NULL OR g.exchange != 'BSE') + """ + + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql)) + count = result.scalar() + return count or 0 + + except Exception as e: + logger.error(f"获取股票总数失败: {str(e)}") + return 0 + + def _filter_valid_stocks(self, stock_codes: List[str]) -> List[str]: + """过滤有效的股票代码(排除黑名单、ST股票、BSE交易所) + + Args: + stock_codes: 股票代码列表(gp_code格式) + + Returns: + List[str]: 过滤后的股票代码列表(gp_code_two格式) + """ + try: + if not stock_codes: + return [] + + # 使用IN子句,构建参数化查询 + placeholders = ','.join([f':code{i}' for i in range(len(stock_codes))]) + sql = f""" + SELECT DISTINCT g.gp_code_two + FROM gp_code_all g + LEFT JOIN gp_blacklist b ON g.gp_code = b.stock_code + WHERE g.gp_code IN ({placeholders}) + AND b.stock_code IS NULL + AND g.gp_code_two IS NOT NULL + AND g.gp_name IS NOT NULL + AND g.gp_name NOT LIKE '%ST%' + AND (g.exchange IS NULL OR g.exchange != 'BSE') + """ + + # 构建参数字典 + params = {f'code{i}': code for i, code in enumerate(stock_codes)} + + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql), params) + rows = result.fetchall() + valid_stocks = [row[0] for row in rows if row[0]] + + logger.info(f"过滤股票: 输入{len(stock_codes)}只,有效{len(valid_stocks)}只") + return valid_stocks + + except Exception as e: + logger.error(f"过滤股票失败: {str(e)}", exc_info=True) + return [] + + def _query_hybk_stocks(self, tag_name: str) -> List[str]: + """从gp_hybk表查询标签关联的股票代码 + + Args: + tag_name: 标签名称(行业名称) + + Returns: + List[str]: 股票代码列表(gp_code格式) + """ + try: + sql = """ + SELECT DISTINCT gp_code + FROM gp_hybk + WHERE bk_name = :tag_name + AND gp_code IS NOT NULL + """ + + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql), {'tag_name': tag_name}) + rows = result.fetchall() + stock_codes = [row[0] for row in rows if row[0]] + + logger.info(f"从gp_hybk查询到标签'{tag_name}'关联{len(stock_codes)}只股票") + return stock_codes + + except Exception as e: + logger.error(f"查询gp_hybk失败: {str(e)}") + return [] + + def _query_gnbk_stocks(self, tag_name: str) -> List[str]: + """从gp_gnbk表查询标签关联的股票代码 + + Args: + tag_name: 标签名称(概念名称) + + Returns: + List[str]: 股票代码列表(gp_code格式) + """ + try: + sql = """ + SELECT DISTINCT gp_code + FROM gp_gnbk + WHERE bk_name = :tag_name + AND gp_code IS NOT NULL + """ + + with self.mysql_engine.connect() as conn: + result = conn.execute(text(sql), {'tag_name': tag_name}) + rows = result.fetchall() + stock_codes = [row[0] for row in rows if row[0]] + + logger.info(f"从gp_gnbk查询到标签'{tag_name}'关联{len(stock_codes)}只股票") + return stock_codes + + except Exception as e: + logger.error(f"查询gp_gnbk失败: {str(e)}") + return [] + + def process_tag( + self, + tag_name: str, + tag_type: Optional[str] = None, + model_type: str = "ds3_2", + enable_web_search: bool = True, + temperature: float = 0.7, + max_tokens: int = 4096 + ) -> Dict[str, Any]: + """处理标签:匹配数据库或AI分析 + + 流程: + 1. 先在gp_hybk表匹配(行业标签) + 2. 如果匹配到,保存标签和关联关系 + 3. 如果没匹配到,去gp_gnbk表匹配(概念标签) + 4. 如果匹配到,保存标签和关联关系 + 5. 如果都没匹配到,调用AI分析(AI标签) + + Args: + tag_name: 标签名称 + tag_type: 标签类型(如果已知,可选:行业标签/概念标签/自定义标签/AI标签) + model_type: 大模型类型,默认ds3_2 + enable_web_search: 是否启用联网搜索,默认True + temperature: 大模型温度参数,默认0.7 + max_tokens: 最大token数,默认4096 + + Returns: + Dict[str, Any]: 处理结果 + - success: 是否成功 + - tag_code: 标签代码 + - tag_type: 标签类型 + - source: 数据来源(hybk/gnbk/ai) + - stock_count: 关联股票数量 + - message: 处理消息 + - error: 错误信息 + """ + try: + logger.info(f"开始处理标签: {tag_name}") + + # 检查标签是否已存在 + existing_tag = self.service.database.get_tag_by_name(tag_name) + if existing_tag: + tag_code = existing_tag.get('tag_code') + logger.info(f"标签已存在: {tag_name}, tag_code: {tag_code}") + # 获取关联股票数量 + stock_codes = self.service.database.get_tag_stocks(tag_code) + return { + "success": True, + "tag_code": tag_code, + "tag_type": existing_tag.get('tag_type'), + "source": existing_tag.get('source', 'unknown'), + "stock_count": len(stock_codes), + "message": "标签已存在", + "error": None + } + + # 1. 先在gp_hybk表匹配(行业标签) + hybk_stocks = self._query_hybk_stocks(tag_name) + if hybk_stocks: + logger.info(f"在gp_hybk表中找到标签'{tag_name}',关联{len(hybk_stocks)}只股票") + + # 过滤有效股票 + valid_stocks = self._filter_valid_stocks(hybk_stocks) + + # 保存标签 + tag_data = { + 'tag_name': tag_name, + 'tag_type': tag_type or '行业标签', + 'status': 'completed', + 'progress': 100.0, + 'source': 'hybk' + } + tag_code = self.service.database.save_tag(tag_data) + + if tag_code: + # 保存关联关系(全量替换) + self.service.database.save_tag_stock_relations(tag_code, valid_stocks, replace_all=True) + + return { + "success": True, + "tag_code": tag_code, + "tag_type": tag_data['tag_type'], + "source": "hybk", + "stock_count": len(valid_stocks), + "message": f"从行业板块表匹配成功,关联{len(valid_stocks)}只股票", + "error": None + } + + # 2. 在gp_gnbk表匹配(概念标签) + gnbk_stocks = self._query_gnbk_stocks(tag_name) + if gnbk_stocks: + logger.info(f"在gp_gnbk表中找到标签'{tag_name}',关联{len(gnbk_stocks)}只股票") + + # 过滤有效股票 + valid_stocks = self._filter_valid_stocks(gnbk_stocks) + + # 保存标签 + tag_data = { + 'tag_name': tag_name, + 'tag_type': tag_type or '概念标签', + 'status': 'completed', + 'progress': 100.0, + 'source': 'gnbk' + } + tag_code = self.service.database.save_tag(tag_data) + + if tag_code: + # 保存关联关系(全量替换) + self.service.database.save_tag_stock_relations(tag_code, valid_stocks, replace_all=True) + + return { + "success": True, + "tag_code": tag_code, + "tag_type": tag_data['tag_type'], + "source": "gnbk", + "stock_count": len(valid_stocks), + "message": f"从概念板块表匹配成功,关联{len(valid_stocks)}只股票", + "error": None + } + + # 3. 都没匹配到,使用AI分析(异步队列处理) + logger.info(f"在数据库表中未找到标签'{tag_name}',开始AI分析(异步队列)") + + # 创建标签记录(状态为processing) + tag_data = { + 'tag_name': tag_name, + 'tag_type': tag_type or 'AI标签', + 'status': 'processing', + 'progress': 0.0, + 'source': 'ai' + } + tag_code = self.service.database.save_tag(tag_data) + + if not tag_code: + return { + "success": False, + "tag_code": None, + "tag_type": tag_data['tag_type'], + "source": "ai", + "stock_count": 0, + "message": "创建标签记录失败", + "error": "无法创建标签记录" + } + + # 获取所有需要处理的股票列表 + stock_list = self._get_stock_list() + total_stocks = len(stock_list) + + if not stock_list: + self.service.database.update_tag_status(tag_code, 'failed', 0.0) + return { + "success": False, + "tag_code": tag_code, + "tag_type": tag_data['tag_type'], + "source": "ai", + "stock_count": 0, + "message": "未获取到股票列表", + "error": "股票列表为空" + } + + # 将任务添加到全局Redis队列 + for stock in stock_list: + task = { + 'tag_code': tag_code, + 'tag_name': tag_name, + 'stock_code': stock['stock_code'], + 'stock_name': stock['stock_name'], + 'model_type': model_type, + 'enable_web_search': enable_web_search, + 'temperature': temperature, + 'max_tokens': max_tokens + } + self.redis_client.lpush(self.global_queue_key, json.dumps(task, ensure_ascii=False)) + + # 记录总任务数(用于进度计算) + self.redis_client.set(f"tag_analysis_total:{tag_code}", total_stocks) + # 记录已处理数(初始为0) + self.redis_client.set(f"tag_analysis_processed:{tag_code}", 0) + + return { + "success": True, + "tag_code": tag_code, + "tag_type": tag_data['tag_type'], + "source": "ai", + "stock_count": 0, # 初始为0,处理完成后会更新 + "message": f"AI分析任务已加入队列,共{total_stocks}只股票待处理", + "error": None, + "status": "processing" + } + + except Exception as e: + logger.error(f"处理标签失败: {str(e)}", exc_info=True) + return { + "success": False, + "tag_code": None, + "tag_type": None, + "source": None, + "stock_count": 0, + "message": "处理标签时发生异常", + "error": str(e) + } + + def _start_global_processing_thread(self): + """启动全局处理线程""" + if self.global_processing_thread and self.global_processing_thread.is_alive(): + logger.warning("全局处理线程已存在") + return + + self.stop_processing = False + self.global_processing_thread = threading.Thread( + target=self._process_global_queue, + daemon=True + ) + self.global_processing_thread.start() + logger.info("全局处理线程已启动") + + def _process_global_queue(self): + """处理全局Redis队列中的任务(持续运行)""" + logger.info("全局队列处理线程开始运行") + + # 定期检查队列和MongoDB状态的间隔(秒) + check_interval = 10 + last_check_time = 0 + + while not self.stop_processing: + try: + current_time = time.time() + + # 定期检查队列是否为空,如果为空则更新MongoDB中processing状态的标签 + if current_time - last_check_time >= check_interval: + last_check_time = current_time + + # 检查队列是否为空 + queue_length = self.redis_client.llen(self.global_queue_key) + + if queue_length == 0: + # 队列为空,查找MongoDB中所有processing状态的标签,标记为completed + processing_tags = list(self.service.database.tags_collection.find({'status': 'processing'})) + + if processing_tags: + logger.info(f"队列为空,发现 {len(processing_tags)} 个processing状态的标签,标记为完成") + + for tag in processing_tags: + tag_code = tag.get('tag_code') + tag_name = tag.get('tag_name') + + if not tag_code or not tag_name: + continue + + try: + # 标记为完成 + self.service.database.update_tag_status(tag_code, 'completed', 100.0) + + # 清理Redis键(如果存在) + total_key = f"tag_analysis_total:{tag_code}" + processed_key = f"tag_analysis_processed:{tag_code}" + self.redis_client.delete(total_key) + self.redis_client.delete(processed_key) + + logger.info(f"标签 {tag_name}({tag_code}) 队列为空,标记为完成") + except Exception as e: + logger.error(f"更新标签 {tag_name}({tag_code}) 状态失败: {str(e)}") + continue + + # 从队列右侧取出任务(FIFO),阻塞等待,超时1秒 + task_json = self.redis_client.brpop(self.global_queue_key, timeout=1) + + if not task_json: + # 超时,继续循环 + continue + + # task_json是元组 (key, value) + task_data = json.loads(task_json[1]) + tag_code = task_data.get('tag_code') + tag_name = task_data.get('tag_name') + stock_code = task_data.get('stock_code') + stock_name = task_data.get('stock_name') + + if not all([tag_code, tag_name, stock_code, stock_name]): + logger.warning(f"任务数据不完整: {task_data}") + continue + + try: + # 检查是否已经处理过(避免重复处理) + existing_result = self.service.database.get_analysis_result(stock_code, tag_name) + + if existing_result: + # 已存在,跳过分析,直接使用已有结果 + logger.info(f"股票 {stock_code} 与标签 {tag_name} 已存在分析结果,跳过处理") + relation_score = existing_result.get('relation_score', 0) + has_relation = existing_result.get('has_relation', False) + relation_type = existing_result.get('relation_type', '') + else: + # 执行分析 + analysis_result = self.service.analyzer.analyze_tag_relation( + stock_code=stock_code, + stock_name=stock_name, + tag_name=tag_name, + temperature=task_data.get('temperature', 0.7), + max_tokens=task_data.get('max_tokens', 4096) + ) + + # 保存分析结果到 tag_relation_analysis 集合 + if analysis_result.get('success'): + self.service.database.save_analysis_result(analysis_result) + relation_score = analysis_result.get('relation_score', 0) + has_relation = analysis_result.get('has_relation', False) + relation_type = analysis_result.get('relation_type', '') + else: + # 分析失败,跳过关联关系更新 + relation_score = 0 + has_relation = False + relation_type = '' + + # 如果关联度评分>=50 且 关联类型为"直接关联",立即更新关联关系 + if relation_score >= 50 and has_relation and relation_type == "直接关联": + self.service.database.save_tag_stock_relations( + tag_code, + [stock_code] + ) + + # 更新已处理数 + processed_key = f"tag_analysis_processed:{tag_code}" + processed_count = self.redis_client.incr(processed_key) + + # 获取总任务数 + total_key = f"tag_analysis_total:{tag_code}" + total_stocks = int(self.redis_client.get(total_key) or 0) + + # 计算并更新进度 + if total_stocks > 0: + progress = (processed_count / total_stocks) * 100 + # 通过tag_name查询标签并更新进度 + tag_info = self.service.database.get_tag_by_name(tag_name) + if tag_info: + tag_code_from_db = tag_info.get('tag_code') + if tag_code_from_db == tag_code: + self.service.database.update_tag_status(tag_code, 'processing', progress) + + # 每处理10只股票或达到总数时,检查队列中是否还有相同标签的任务 + if processed_count % 10 == 0 or processed_count >= total_stocks: + self._check_and_update_tag_status(tag_code, tag_name) + + except Exception as e: + logger.error(f"处理任务失败: {str(e)}", exc_info=True) + # 即使失败也要更新已处理数,避免卡住 + processed_key = f"tag_analysis_processed:{tag_code}" + self.redis_client.incr(processed_key) + continue + + except Exception as e: + logger.error(f"全局队列处理异常: {str(e)}", exc_info=True) + time.sleep(1) # 出错后等待1秒再继续 + + logger.info("全局队列处理线程已停止") + + def _check_and_update_tag_status(self, tag_code: str, tag_name: str): + """检查队列中是否还有相同标签的任务,如果没有则标记完成 + + Args: + tag_code: 标签代码 + tag_name: 标签名称 + """ + try: + total_key = f"tag_analysis_total:{tag_code}" + processed_key = f"tag_analysis_processed:{tag_code}" + + total_stocks = int(self.redis_client.get(total_key) or 0) + processed_count = int(self.redis_client.get(processed_key) or 0) + + # 如果已处理数达到总数,检查队列中是否还有该标签的任务 + if processed_count >= total_stocks: + # 检查队列中是否还有该标签的任务 + queue_length = self.redis_client.llen(self.global_queue_key) + + if queue_length == 0: + # 队列为空,标记完成 + self.service.database.update_tag_status(tag_code, 'completed', 100.0) + # 清理Redis键 + self.redis_client.delete(total_key) + self.redis_client.delete(processed_key) + logger.info(f"标签 {tag_name}({tag_code}) 处理完成,队列已空") + else: + # 检查队列中是否还有该标签的任务 + # 使用更准确的方式:遍历整个队列(由于队列可能很大,我们限制检查范围) + has_same_tag = False + # 检查队列中的所有任务(最多检查1000个,避免性能问题) + check_count = min(queue_length, 1000) + for i in range(check_count): + task_json = self.redis_client.lindex(self.global_queue_key, i) + if task_json: + try: + task_data = json.loads(task_json) + if task_data.get('tag_code') == tag_code: + has_same_tag = True + break + except: + continue + + if not has_same_tag: + # 队列中没有该标签的任务了,标记完成 + self.service.database.update_tag_status(tag_code, 'completed', 100.0) + # 清理Redis键 + self.redis_client.delete(total_key) + self.redis_client.delete(processed_key) + logger.info(f"标签 {tag_name}({tag_code}) 处理完成,队列中无该标签任务") + except Exception as e: + logger.error(f"检查标签状态失败: {str(e)}") + + def get_tag_progress(self, tag_code: str = None, tag_name: str = None) -> Dict[str, Any]: + """获取标签处理进度 + + Args: + tag_code: 标签代码(优先使用) + tag_name: 标签名称(如果tag_code为空则使用) + + Returns: + Dict[str, Any]: 进度信息 + """ + try: + # 根据tag_code或tag_name查询标签 + if tag_code: + tag = self.service.database.tags_collection.find_one({'tag_code': tag_code}) + elif tag_name: + tag = self.service.database.get_tag_by_name(tag_name) + else: + return {"error": "必须提供tag_code或tag_name"} + + if not tag: + return {"error": "标签不存在"} + + tag_code = tag.get('tag_code') + total_key = f"tag_analysis_total:{tag_code}" + processed_key = f"tag_analysis_processed:{tag_code}" + + total_stocks = int(self.redis_client.get(total_key) or 0) + processed_count = int(self.redis_client.get(processed_key) or 0) + + # 检查队列中该标签的剩余任务数(估算) + queue_length = self.redis_client.llen(self.global_queue_key) + remaining = max(0, total_stocks - processed_count) if total_stocks > 0 else 0 + + progress = tag.get('progress', 0.0) + status = tag.get('status', 'unknown') + + return { + "tag_code": tag_code, + "tag_name": tag.get('tag_name'), + "status": status, + "progress": progress, + "total_stocks": total_stocks, + "processed": processed_count, + "remaining": remaining, + "queue_length": queue_length # 全局队列长度 + } + except Exception as e: + logger.error(f"获取标签进度失败: {str(e)}") + return {"error": str(e)} + + def delete_tag_by_name( + self, + tag_name: str, + delete_analysis: bool = False + ) -> Dict[str, Any]: + """根据标签名称删除标签及其相关数据 + + Args: + tag_name: 标签名称 + delete_analysis: 是否删除分析结果,默认False(只删除标签和关联关系) + + Returns: + Dict[str, Any]: 删除结果 + - success: 是否成功 + - tag_code: 标签代码 + - deleted_tag: 是否删除了标签 + - deleted_relations: 删除的关联关系数量 + - deleted_analyses: 删除的分析结果数量(如果delete_analysis=True) + - deleted_redis_keys: 删除的Redis键数量 + - message: 处理消息 + - error: 错误信息 + """ + try: + logger.info(f"开始删除标签: {tag_name}, delete_analysis={delete_analysis}") + + # 先查找标签 + tag = self.service.database.get_tag_by_name(tag_name) + if not tag: + logger.warning(f"标签不存在: {tag_name}") + return { + "success": False, + "tag_code": None, + "deleted_tag": False, + "deleted_relations": 0, + "deleted_analyses": 0, + "deleted_redis_keys": 0, + "message": f"标签不存在: {tag_name}", + "error": "标签不存在" + } + + tag_code = tag.get('tag_code') + + # 1. 删除MongoDB中的标签和相关数据 + delete_result = self.service.database.delete_tag_by_name( + tag_name=tag_name, + delete_analysis=delete_analysis + ) + + if not delete_result.get('success'): + logger.error(f"删除标签失败: {delete_result.get('error')}") + return { + "success": False, + "tag_code": tag_code, + "deleted_tag": False, + "deleted_relations": 0, + "deleted_analyses": 0, + "deleted_redis_keys": 0, + "message": f"删除标签失败: {delete_result.get('error')}", + "error": delete_result.get('error') + } + + # 2. 删除Redis中的相关键 + deleted_redis_keys = 0 + try: + total_key = f"tag_analysis_total:{tag_code}" + processed_key = f"tag_analysis_processed:{tag_code}" + + # 删除进度相关键 + if self.redis_client.exists(total_key): + self.redis_client.delete(total_key) + deleted_redis_keys += 1 + + if self.redis_client.exists(processed_key): + self.redis_client.delete(processed_key) + deleted_redis_keys += 1 + + # 从队列中删除该标签的任务(可选:如果队列中有该标签的任务) + # 注意:这里只标记删除,实际任务可能在处理中,这里不做强制删除 + logger.info(f"删除Redis键: tag_code={tag_code}, 删除数量={deleted_redis_keys}") + + except Exception as e: + logger.warning(f"删除Redis键失败: {str(e)}") + # Redis删除失败不影响整体结果 + + # 统计删除结果 + deleted_tag = delete_result.get('deleted_tag', False) + deleted_relations = delete_result.get('deleted_relations', 0) + deleted_analyses = delete_result.get('deleted_analyses', 0) + + message = f"删除成功: 标签={deleted_tag}, 关联关系={deleted_relations}条" + if delete_analysis: + message += f", 分析结果={deleted_analyses}条" + message += f", Redis键={deleted_redis_keys}个" + + logger.info(f"标签删除完成: {tag_name}, {message}") + + return { + "success": True, + "tag_code": tag_code, + "deleted_tag": deleted_tag, + "deleted_relations": deleted_relations, + "deleted_analyses": deleted_analyses, + "deleted_redis_keys": deleted_redis_keys, + "message": message, + "error": None + } + + except Exception as e: + logger.error(f"删除标签失败: {str(e)}", exc_info=True) + return { + "success": False, + "tag_code": None, + "deleted_tag": False, + "deleted_relations": 0, + "deleted_analyses": 0, + "deleted_redis_keys": 0, + "message": f"删除标签时发生异常: {str(e)}", + "error": str(e) + } + + def close(self): + """关闭所有连接""" + try: + # 停止全局处理线程 + self.stop_processing = True + if self.global_processing_thread and self.global_processing_thread.is_alive(): + logger.info("等待全局处理线程结束...") + self.global_processing_thread.join(timeout=5) + + if self.mysql_engine: + self.mysql_engine.dispose() + logger.info("MySQL连接已关闭") + + if self.redis_client: + self.redis_client.close() + logger.info("Redis连接已关闭") + + self.service.close() + except Exception as e: + logger.error(f"关闭连接失败: {str(e)}") + + +if __name__ == "__main__": + # 设置日志 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # 测试示例 + api = TagRelationAPI() + + # 测试:分析所有股票与"商业航天"标签的关联性 + result = api.analyze_tag_for_all_stocks(tag_name="商业航天") + + print("\n" + "="*50) + print("分析结果统计:") + print("="*50) + print(f"成功: {result.get('success')}") + print(f"总股票数: {result.get('total_stocks')}") + print(f"已处理: {result.get('processed_stocks')}") + print(f"成功分析: {result.get('success_count')}") + print(f"失败: {result.get('failed_count')}") + print(f"已保存: {result.get('saved_count')}") + + # 关闭连接 + api.close() + diff --git a/src/stock_tag_analysis/tag_relation_database.py b/src/stock_tag_analysis/tag_relation_database.py new file mode 100644 index 0000000..64fe06a --- /dev/null +++ b/src/stock_tag_analysis/tag_relation_database.py @@ -0,0 +1,448 @@ +""" +标签关联分析数据库操作模块 +用于连接 MongoDB 并存储分析结果 +""" + +import logging +import pymongo +from typing import Dict, Any, Optional, List +from datetime import datetime + +# 导入配置 +import sys +import os + +# 添加项目路径 +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_DIR) + +try: + from valuation_analysis.config import MONGO_CONFIG222 +except ImportError: + try: + from src.valuation_analysis.config import MONGO_CONFIG222 + except ImportError: + # 尝试从当前目录导入 + sys.path.append(os.path.dirname(ROOT_DIR)) + from src.valuation_analysis.config import MONGO_CONFIG222 + +# 设置日志记录 +logger = logging.getLogger(__name__) + + +class TagRelationDatabase: + """标签关联分析数据库操作类""" + + def __init__(self, collection_name: Optional[str] = None): + """初始化数据库连接 + + Args: + collection_name: 集合名称,如果不指定则使用配置中的默认集合名 + """ + self.mongo_client = None + self.db = None + self.collection_name = collection_name or MONGO_CONFIG222.get('collection', 'tag_relation_analysis') + self.collection = None + + # 标签集合和关联集合 + self.tags_collection = None + self.tag_stock_relations_collection = None + + self.connect_mongodb() + + def connect_mongodb(self): + """连接MongoDB数据库""" + try: + self.mongo_client = pymongo.MongoClient( + host=MONGO_CONFIG222['host'], + port=MONGO_CONFIG222['port'], + username=MONGO_CONFIG222['username'], + password=MONGO_CONFIG222['password'], + authSource='admin' # 认证数据库 + ) + self.db = self.mongo_client[MONGO_CONFIG222['db']] + self.collection = self.db[self.collection_name] + + # 初始化标签集合和关联集合 + self.tags_collection = self.db['stock_pool_tags'] # 股票池标签集合 + self.tag_stock_relations_collection = self.db['tag_stock_relations'] + + # 测试连接 + self.mongo_client.admin.command('ping') + logger.info(f"MongoDB连接成功,数据库: {MONGO_CONFIG222['db']}, 集合: {self.collection_name}") + + except Exception as e: + logger.error(f"MongoDB连接失败: {str(e)}") + raise + + def save_analysis_result(self, analysis_result: Dict[str, Any]) -> bool: + """保存分析结果到MongoDB + + Args: + analysis_result: 分析结果字典,包含所有分析字段 + + Returns: + bool: 是否保存成功 + """ + try: + if not analysis_result: + logger.warning("分析结果为空,无法保存") + return False + + # 添加保存时间 + analysis_result['save_time'] = datetime.now().isoformat() + + # 使用 stock_code 和 tag_name 作为唯一标识 + filter_condition = { + 'stock_code': analysis_result.get('stock_code'), + 'tag_name': analysis_result.get('tag_name') + } + + # 检查是否已存在 + existing_record = self.collection.find_one(filter_condition) + + if existing_record: + # 更新现有记录 + update_result = self.collection.update_one( + filter_condition, + {'$set': analysis_result} + ) + logger.info(f"更新分析结果: {analysis_result.get('stock_code')} - {analysis_result.get('tag_name')}") + return update_result.modified_count > 0 + else: + # 插入新记录 + insert_result = self.collection.insert_one(analysis_result) + logger.info(f"插入新分析结果: {analysis_result.get('stock_code')} - {analysis_result.get('tag_name')}, ID: {insert_result.inserted_id}") + return insert_result.inserted_id is not None + + except Exception as e: + logger.error(f"保存分析结果失败: {str(e)}", exc_info=True) + return False + + def get_analysis_result(self, stock_code: str, tag_name: str) -> Optional[Dict[str, Any]]: + """获取指定股票和标签的分析结果 + + Args: + stock_code: 股票代码 + tag_name: 标签名称 + + Returns: + Optional[Dict[str, Any]]: 分析结果,如果不存在返回None + """ + try: + result = self.collection.find_one({ + 'stock_code': stock_code, + 'tag_name': tag_name + }) + + if result: + # 移除 MongoDB 的 _id 字段(如果需要) + result.pop('_id', None) + + return result + + except Exception as e: + logger.error(f"获取分析结果失败: {str(e)}") + return None + + def get_stock_analyses(self, stock_code: str) -> list: + """获取指定股票的所有标签分析结果 + + Args: + stock_code: 股票代码 + + Returns: + list: 分析结果列表 + """ + try: + results = list(self.collection.find({'stock_code': stock_code})) + + # 移除 MongoDB 的 _id 字段 + for result in results: + result.pop('_id', None) + + return results + + except Exception as e: + logger.error(f"获取股票分析结果失败: {str(e)}") + return [] + + def get_tag_analyses(self, tag_name: str) -> list: + """获取指定标签的所有股票分析结果 + + Args: + tag_name: 标签名称 + + Returns: + list: 分析结果列表 + """ + try: + results = list(self.collection.find({'tag_name': tag_name})) + + # 移除 MongoDB 的 _id 字段 + for result in results: + result.pop('_id', None) + + return results + + except Exception as e: + logger.error(f"获取标签分析结果失败: {str(e)}") + return [] + + def save_tag(self, tag_data: Dict[str, Any]) -> Optional[str]: + """保存标签到tags集合 + + Args: + tag_data: 标签数据字典,包含tag_id, tag_name, tag_type等 + + Returns: + Optional[str]: 标签ID(tag_code),如果保存失败返回None + """ + try: + # 检查标签是否已存在(根据tag_name) + existing_tag = self.tags_collection.find_one({'tag_name': tag_data.get('tag_name')}) + + if existing_tag: + # 更新现有标签 + tag_code = existing_tag.get('tag_code') + self.tags_collection.update_one( + {'tag_name': tag_data.get('tag_name')}, + {'$set': tag_data} + ) + logger.info(f"更新标签: {tag_data.get('tag_name')}, tag_code: {tag_code}") + return tag_code + else: + # 插入新标签 + # 如果没有tag_code,生成一个 + if 'tag_code' not in tag_data or not tag_data.get('tag_code'): + # 生成tag_code:使用当前最大ID+1,或者使用时间戳 + max_tag = self.tags_collection.find_one(sort=[("tag_code", -1)]) + if max_tag and max_tag.get('tag_code'): + try: + max_id = int(max_tag['tag_code']) + tag_code = str(max_id + 1) + except: + tag_code = str(int(datetime.now().timestamp() * 1000)) + else: + tag_code = "1" + tag_data['tag_code'] = tag_code + + tag_data['create_time'] = datetime.now().isoformat() + result = self.tags_collection.insert_one(tag_data) + logger.info(f"插入新标签: {tag_data.get('tag_name')}, tag_code: {tag_data.get('tag_code')}") + return tag_data.get('tag_code') + + except Exception as e: + logger.error(f"保存标签失败: {str(e)}", exc_info=True) + return None + + def get_tag_by_name(self, tag_name: str) -> Optional[Dict[str, Any]]: + """根据标签名称获取标签信息 + + Args: + tag_name: 标签名称 + + Returns: + Optional[Dict[str, Any]]: 标签信息,如果不存在返回None + """ + try: + tag = self.tags_collection.find_one({'tag_name': tag_name}) + if tag: + tag.pop('_id', None) + return tag + except Exception as e: + logger.error(f"获取标签失败: {str(e)}") + return None + + def save_tag_stock_relations(self, tag_code: str, stock_codes: List[str], replace_all: bool = False) -> bool: + """保存标签与股票的关联关系 + + Args: + tag_code: 标签代码 + stock_codes: 股票代码列表 + replace_all: 是否替换所有关联关系(True=全量替换,False=增量添加),默认False + + Returns: + bool: 是否保存成功 + """ + try: + if not stock_codes: + logger.warning("股票代码列表为空") + return False + + # 如果replace_all为True,先删除该标签的所有旧关联关系 + if replace_all: + self.tag_stock_relations_collection.delete_many({'tag_code': tag_code}) + + # 批量插入新关联关系(使用upsert避免重复) + relations = [] + for stock_code in stock_codes: + # 检查是否已存在 + existing = self.tag_stock_relations_collection.find_one({ + 'tag_code': tag_code, + 'stock_code': stock_code + }) + + if not existing: + relations.append({ + 'tag_code': tag_code, + 'stock_code': stock_code, + 'create_time': datetime.now().isoformat() + }) + + if relations: + self.tag_stock_relations_collection.insert_many(relations) + logger.info(f"保存标签关联关系: tag_code={tag_code}, 新增股票数量={len(relations)}") + + return True + + except Exception as e: + logger.error(f"保存标签关联关系失败: {str(e)}", exc_info=True) + return False + + def get_tag_stocks(self, tag_code: str) -> List[str]: + """获取标签关联的所有股票代码 + + Args: + tag_code: 标签代码 + + Returns: + List[str]: 股票代码列表 + """ + try: + relations = list(self.tag_stock_relations_collection.find({'tag_code': tag_code})) + stock_codes = [r.get('stock_code') for r in relations if r.get('stock_code')] + return stock_codes + except Exception as e: + logger.error(f"获取标签关联股票失败: {str(e)}") + return [] + + def update_tag_status(self, tag_code: str, status: str, progress: Optional[float] = None): + """更新标签处理状态和进度 + + Args: + tag_code: 标签代码 + status: 处理状态(pending/processing/completed/failed) + progress: 处理进度(0-100),可选 + """ + try: + update_data = { + 'status': status, + 'update_time': datetime.now().isoformat() + } + if progress is not None: + update_data['progress'] = progress + + self.tags_collection.update_one( + {'tag_code': tag_code}, + {'$set': update_data} + ) + logger.info(f"更新标签状态: tag_code={tag_code}, status={status}, progress={progress}") + except Exception as e: + logger.error(f"更新标签状态失败: {str(e)}") + + def delete_tag(self, tag_code: str, tag_name: str, delete_analysis: bool = False) -> Dict[str, Any]: + """删除标签及其相关数据 + + Args: + tag_code: 标签代码 + tag_name: 标签名称(用于删除分析结果) + delete_analysis: 是否删除分析结果,默认False + + Returns: + Dict[str, Any]: 删除结果 + - success: 是否成功 + - deleted_tag: 是否删除了标签 + - deleted_relations: 删除的关联关系数量 + - deleted_analyses: 删除的分析结果数量(如果delete_analysis=True) + - error: 错误信息 + """ + result = { + 'success': False, + 'deleted_tag': False, + 'deleted_relations': 0, + 'deleted_analyses': 0, + 'error': None + } + + try: + # 1. 删除标签与股票的关联关系(中间表) + relations_result = self.tag_stock_relations_collection.delete_many({'tag_code': tag_code}) + result['deleted_relations'] = relations_result.deleted_count + logger.info(f"删除标签关联关系: tag_code={tag_code}, 删除数量={relations_result.deleted_count}") + + # 2. 删除标签本身 + tag_result = self.tags_collection.delete_one({'tag_code': tag_code}) + result['deleted_tag'] = tag_result.deleted_count > 0 + logger.info(f"删除标签: tag_code={tag_code}, tag_name={tag_name}, 是否删除={result['deleted_tag']}") + + # 3. 可选:删除分析结果 + if delete_analysis: + analyses_result = self.collection.delete_many({'tag_name': tag_name}) + result['deleted_analyses'] = analyses_result.deleted_count + logger.info(f"删除分析结果: tag_name={tag_name}, 删除数量={analyses_result.deleted_count}") + + result['success'] = True + return result + + except Exception as e: + logger.error(f"删除标签失败: {str(e)}", exc_info=True) + result['error'] = str(e) + return result + + def delete_tag_by_name(self, tag_name: str, delete_analysis: bool = False) -> Dict[str, Any]: + """根据标签名称删除标签及其相关数据 + + Args: + tag_name: 标签名称 + delete_analysis: 是否删除分析结果,默认False + + Returns: + Dict[str, Any]: 删除结果 + - success: 是否成功 + - tag_code: 标签代码(如果找到) + - deleted_tag: 是否删除了标签 + - deleted_relations: 删除的关联关系数量 + - deleted_analyses: 删除的分析结果数量(如果delete_analysis=True) + - error: 错误信息 + """ + result = { + 'success': False, + 'tag_code': None, + 'deleted_tag': False, + 'deleted_relations': 0, + 'deleted_analyses': 0, + 'error': None + } + + try: + # 先查找标签 + tag = self.get_tag_by_name(tag_name) + if not tag: + result['error'] = f"标签不存在: {tag_name}" + logger.warning(result['error']) + return result + + tag_code = tag.get('tag_code') + result['tag_code'] = tag_code + + # 调用根据tag_code删除的方法 + delete_result = self.delete_tag(tag_code, tag_name, delete_analysis) + + result.update(delete_result) + return result + + except Exception as e: + logger.error(f"根据标签名称删除标签失败: {str(e)}", exc_info=True) + result['error'] = str(e) + return result + + def close_connection(self): + """关闭数据库连接""" + try: + if self.mongo_client: + self.mongo_client.close() + logger.info("MongoDB连接已关闭") + except Exception as e: + logger.error(f"关闭MongoDB连接失败: {str(e)}") + diff --git a/src/stock_tag_analysis/tag_relation_service.py b/src/stock_tag_analysis/tag_relation_service.py new file mode 100644 index 0000000..5c97579 --- /dev/null +++ b/src/stock_tag_analysis/tag_relation_service.py @@ -0,0 +1,222 @@ +""" +标签关联分析服务类 +整合分析器和数据库操作,提供完整的分析+存储功能 +""" + +import logging +import sys +import os +from typing import Dict, Any, Optional, List + +# 处理相对导入和绝对导入 +try: + from .tag_relation_analyzer import TagRelationAnalyzer + from .tag_relation_database import TagRelationDatabase +except ImportError: + # 如果相对导入失败,使用绝对导入 + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from stock_tag_analysis.tag_relation_analyzer import TagRelationAnalyzer + from stock_tag_analysis.tag_relation_database import TagRelationDatabase + +# 设置日志记录 +logger = logging.getLogger(__name__) + + +class TagRelationService: + """标签关联分析服务类""" + + def __init__(self, model_type: str = "ds3_2", enable_web_search: bool = True, collection_name: Optional[str] = None): + """初始化服务 + + Args: + model_type: 大模型类型,默认为 ds3_2 + enable_web_search: 是否启用联网搜索,默认True + collection_name: MongoDB集合名称,如果不指定则使用默认集合名 + """ + # 初始化分析器 + self.analyzer = TagRelationAnalyzer( + model_type=model_type, + enable_web_search=enable_web_search + ) + + # 初始化数据库 + self.database = TagRelationDatabase(collection_name=collection_name) + + logger.info("TagRelationService 初始化完成") + + def analyze_and_save( + self, + stock_code: str, + stock_name: str, + tag_name: str, + temperature: float = 0.7, + max_tokens: int = 4096, + save_to_db: bool = True + ) -> Dict[str, Any]: + """分析个股与标签的关联性并保存到数据库 + + Args: + stock_code: 股票代码(如:300750.SZ) + stock_name: 股票名称(如:宁德时代) + tag_name: 标签名称(如:医药、新能源、金融科技等) + temperature: 大模型温度参数,默认0.7 + max_tokens: 最大token数,默认4096 + save_to_db: 是否保存到数据库,默认True + + Returns: + Dict[str, Any]: 分析结果,包含所有分析字段和保存状态 + """ + try: + # 执行分析 + logger.info(f"开始分析 {stock_name}({stock_code}) 与标签 {tag_name} 的关联性") + analysis_result = self.analyzer.analyze_tag_relation( + stock_code=stock_code, + stock_name=stock_name, + tag_name=tag_name, + temperature=temperature, + max_tokens=max_tokens + ) + + # 保存到数据库 + if save_to_db and analysis_result.get('success'): + save_success = self.database.save_analysis_result(analysis_result) + analysis_result['saved_to_db'] = save_success + if save_success: + logger.info(f"分析结果已保存到数据库: {stock_code} - {tag_name}") + else: + logger.warning(f"分析结果保存失败: {stock_code} - {tag_name}") + else: + analysis_result['saved_to_db'] = False + if not save_to_db: + logger.info(f"跳过数据库保存: {stock_code} - {tag_name}") + + return analysis_result + + except Exception as e: + logger.error(f"分析并保存过程中出错: {str(e)}", exc_info=True) + return { + "success": False, + "stock_code": stock_code, + "stock_name": stock_name, + "tag_name": tag_name, + "error": str(e), + "saved_to_db": False + } + + def batch_analyze_and_save( + self, + stock_list: List[Dict[str, str]], + tag_name: str, + temperature: float = 0.7, + max_tokens: int = 4096, + save_to_db: bool = True + ) -> List[Dict[str, Any]]: + """批量分析多只股票与标签的关联性并保存到数据库 + + Args: + stock_list: 股票列表,每个元素包含 stock_code 和 stock_name + tag_name: 标签名称 + temperature: 大模型温度参数 + max_tokens: 最大token数 + save_to_db: 是否保存到数据库,默认True + + Returns: + List[Dict[str, Any]]: 分析结果列表 + """ + results = [] + total = len(stock_list) + + for idx, stock in enumerate(stock_list, 1): + stock_code = stock.get("stock_code") + stock_name = stock.get("stock_name") + + if not stock_code or not stock_name: + logger.warning(f"跳过无效股票信息: {stock}") + continue + + logger.info(f"处理进度: {idx}/{total} - {stock_name}({stock_code})") + + result = self.analyze_and_save( + stock_code=stock_code, + stock_name=stock_name, + tag_name=tag_name, + temperature=temperature, + max_tokens=max_tokens, + save_to_db=save_to_db + ) + results.append(result) + + # 统计结果 + success_count = sum(1 for r in results if r.get('success')) + saved_count = sum(1 for r in results if r.get('saved_to_db')) + logger.info(f"批量分析完成 - 总数: {len(results)}, 成功: {success_count}, 已保存: {saved_count}") + + return results + + def get_analysis_result(self, stock_code: str, tag_name: str) -> Optional[Dict[str, Any]]: + """从数据库获取分析结果 + + Args: + stock_code: 股票代码 + tag_name: 标签名称 + + Returns: + Optional[Dict[str, Any]]: 分析结果,如果不存在返回None + """ + return self.database.get_analysis_result(stock_code, tag_name) + + def get_stock_analyses(self, stock_code: str) -> list: + """获取指定股票的所有标签分析结果 + + Args: + stock_code: 股票代码 + + Returns: + list: 分析结果列表 + """ + return self.database.get_stock_analyses(stock_code) + + def get_tag_analyses(self, tag_name: str) -> list: + """获取指定标签的所有股票分析结果 + + Args: + tag_name: 标签名称 + + Returns: + list: 分析结果列表 + """ + return self.database.get_tag_analyses(tag_name) + + def close(self): + """关闭数据库连接""" + self.database.close_connection() + + +if __name__ == "__main__": + # 设置日志 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # 测试示例 + service = TagRelationService() + + # 测试:分析宁德时代与医药标签的关联性并保存 + result = service.analyze_and_save( + stock_code="000547.SZ", + stock_name="航天发展", + tag_name="商业航天" + ) + + print("\n" + "="*50) + print("分析结果:") + print("="*50) + print(f"成功: {result.get('success')}") + print(f"已保存: {result.get('saved_to_db')}") + print(f"关联度评分: {result.get('relation_score')}") + print(f"结论: {result.get('conclusion')}") + + # 关闭连接 + service.close() + diff --git a/src/tushare_scripts/chip_distribution_collector.py b/src/tushare_scripts/chip_distribution_collector.py index a4ea8ea..7e32e6d 100644 --- a/src/tushare_scripts/chip_distribution_collector.py +++ b/src/tushare_scripts/chip_distribution_collector.py @@ -468,7 +468,7 @@ if __name__ == "__main__": # 4. 采集指定日期范围的数据 collect_chip_distribution(db_url, tushare_token, mode='full', - start_date='2021-01-03', end_date='2021-09-30') + start_date='2025-11-28', end_date='2025-12-09') # 5. 调整批量入库大小(默认100只股票一批) # collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200) diff --git a/src/valuation_analysis/config.py b/src/valuation_analysis/config.py index f83e5e4..3325560 100644 --- a/src/valuation_analysis/config.py +++ b/src/valuation_analysis/config.py @@ -43,6 +43,27 @@ MONGO_CONFIG2 = { 'collection': 'wind_financial_analysis' } +# --- MySQL配置(用于个股标签关联分析模块) +DB_CONFIG_TAG_ANALYSIS = { + 'host': '192.168.18.199', + 'port': 3306, + 'user': 'root', + 'password': 'Chlry#$.8', + 'database': 'db_gp_cj' +} +# 创建数据库连接URL(用于个股标签关联分析模块) +DB_URL_TAG_ANALYSIS = f"mysql+pymysql://{DB_CONFIG_TAG_ANALYSIS['user']}:{DB_CONFIG_TAG_ANALYSIS['password']}@{DB_CONFIG_TAG_ANALYSIS['host']}:{DB_CONFIG_TAG_ANALYSIS['port']}/{DB_CONFIG_TAG_ANALYSIS['database']}" + +# MongoDB配置(用于个股标签关联分析模块) +MONGO_CONFIG222 = { + 'host': '192.168.16.222', + 'port': 27017, + 'db': 'stock_predictions', + 'username': 'stockvix', + 'password': '!stock@vix2', + 'collection': 'tag_relation_analysis' # 默认集合名 +} + # 项目根目录 ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))