stock_fundamentals/src/quantitative_analysis/momentum_analysis.py

153 lines
5.3 KiB
Python

# -*- coding: utf-8 -*-
import sys
import os
import requests
import logging
from typing import Dict, List, Optional
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from src.valuation_analysis.industry_analysis import IndustryAnalyzer
except ImportError:
# 兼容在不同环境下执行
from valuation_analysis.industry_analysis import IndustryAnalyzer
# 设置日志
logger = logging.getLogger("momentum_analysis")
class MomentumAnalyzer:
"""动量分析器类"""
def __init__(self):
"""初始化动量分析器"""
self.industry_analyzer = IndustryAnalyzer()
self.momentum_api_url = "http://192.168.18.42:5000/api/dify/getStockMomentumIndex"
logger.info("动量分析器初始化完成")
def get_stocks_by_name(self, name: str, is_concept: bool = False) -> List[str]:
"""
根据行业或概念名称获取股票列表。
返回的股票代码格式为 '600036.SH'
"""
if is_concept:
# 调用获取概念成份股的方法
raw_codes = self.industry_analyzer.get_concept_stocks(name)
else:
# 调用获取行业成份股的方法
raw_codes = self.industry_analyzer.get_industry_stocks(name)
# 统一将 'SH600036' 格式转换为 '600036.SH' 格式
stock_list = []
if raw_codes:
for code in raw_codes:
if not isinstance(code, str): continue
if code.startswith('SH'):
stock_list.append(f"{code[2:]}.SH")
elif code.startswith('SZ'):
stock_list.append(f"{code[2:]}.SZ")
elif code.startswith('BJ'):
stock_list.append(f"{code[2:]}.BJ")
return stock_list
def get_momentum_indicators(self, stock_code: str, industry_codes: List[str]) -> Optional[Dict]:
"""
获取单个股票的动量指标数据。
Args:
stock_code: 目标股票代码 (e.g., '600036.SH')
industry_codes: 相关的股票代码列表 (e.g., ['600036.SH', '600000.SH'])
Returns:
动量指标数据字典或None
"""
try:
payload = {
# 接口需要无后缀的代码列表
"code_list": industry_codes,
"target_code": stock_code
}
response = requests.post(self.momentum_api_url, json=payload, timeout=500)
if response.status_code != 200:
logger.error(f"获取动量指标失败({stock_code}): HTTP {response.status_code}, {response.text}")
return None
data = response.json()
# 为返回结果补充股票代码和名称
data['stock_code'] = stock_code
return data
except requests.exceptions.Timeout:
logger.error(f"获取动量指标超时({stock_code})")
return None
except Exception as e:
logger.error(f"获取动量指标异常({stock_code}): {str(e)}")
return None
def analyze_momentum_by_name(self, name: str, is_concept: bool = False) -> Dict:
"""
根据行业或概念名称,批量获取其中所有股票的动量指标。
Args:
name: 行业或概念的名称
is_concept: 是否为概念板块
Returns:
包含所有股票动量数据的字典
"""
# 1. 获取板块内所有股票代码
stock_list = self.get_stocks_by_name(name, is_concept)
if not stock_list:
return {'success': False, 'message': f'未找到板块 "{name}" 中的股票'}
all_results = []
# 2. 依次请求所有股票的动量数据
for stock_code in stock_list:
try:
temp_list = [stock_code]
result = self.get_momentum_indicators(stock_code, temp_list)
if result:
all_results.append(result)
except Exception as exc:
logger.error(f"处理股票 {stock_code} 的动量分析时产生异常: {exc}")
return {
'success': True,
'plate_name': name,
'is_concept': is_concept,
'stock_count': len(stock_list),
'results_count': len(all_results),
'data': all_results
}
if __name__ == '__main__':
# 示例用法
analyzer = MomentumAnalyzer()
# 1. 测试行业
industry_name = "证券"
industry_results = analyzer.analyze_momentum_by_name(industry_name, is_concept=False)
print(f"\n行业 '{industry_name}' 动量分析结果 (前5条):")
if industry_results['success']:
# 打印部分结果
for item in industry_results['data'][:5]:
print(item)
else:
print(industry_results['message'])
# 2. 测试概念
concept_name = "芯片"
concept_results = analyzer.analyze_momentum_by_name(concept_name, is_concept=True)
print(f"\n概念 '{concept_name}' 动量分析结果 (前5条):")
if concept_results['success']:
# 打印部分结果
for item in concept_results['data'][:5]:
print(item)
else:
print(concept_results['message'])