153 lines
5.3 KiB
Python
153 lines
5.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
import sys
|
|
import os
|
|
import requests
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
# 添加项目根目录到 Python 路径
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
try:
|
|
from src.valuation_analysis.industry_analysis import IndustryAnalyzer
|
|
except ImportError:
|
|
# 兼容在不同环境下执行
|
|
from valuation_analysis.industry_analysis import IndustryAnalyzer
|
|
|
|
# 设置日志
|
|
logger = logging.getLogger("momentum_analysis")
|
|
|
|
class MomentumAnalyzer:
|
|
"""动量分析器类"""
|
|
|
|
def __init__(self):
|
|
"""初始化动量分析器"""
|
|
self.industry_analyzer = IndustryAnalyzer()
|
|
self.momentum_api_url = "http://192.168.18.42:5000/api/dify/getStockMomentumIndex"
|
|
logger.info("动量分析器初始化完成")
|
|
|
|
def get_stocks_by_name(self, name: str, is_concept: bool = False) -> List[str]:
|
|
"""
|
|
根据行业或概念名称获取股票列表。
|
|
返回的股票代码格式为 '600036.SH'
|
|
"""
|
|
if is_concept:
|
|
# 调用获取概念成份股的方法
|
|
raw_codes = self.industry_analyzer.get_concept_stocks(name)
|
|
else:
|
|
# 调用获取行业成份股的方法
|
|
raw_codes = self.industry_analyzer.get_industry_stocks(name)
|
|
|
|
# 统一将 'SH600036' 格式转换为 '600036.SH' 格式
|
|
stock_list = []
|
|
if raw_codes:
|
|
for code in raw_codes:
|
|
if not isinstance(code, str): continue
|
|
|
|
if code.startswith('SH'):
|
|
stock_list.append(f"{code[2:]}.SH")
|
|
elif code.startswith('SZ'):
|
|
stock_list.append(f"{code[2:]}.SZ")
|
|
elif code.startswith('BJ'):
|
|
stock_list.append(f"{code[2:]}.BJ")
|
|
|
|
return stock_list
|
|
|
|
def get_momentum_indicators(self, stock_code: str, industry_codes: List[str]) -> Optional[Dict]:
|
|
"""
|
|
获取单个股票的动量指标数据。
|
|
|
|
Args:
|
|
stock_code: 目标股票代码 (e.g., '600036.SH')
|
|
industry_codes: 相关的股票代码列表 (e.g., ['600036.SH', '600000.SH'])
|
|
|
|
Returns:
|
|
动量指标数据字典或None
|
|
"""
|
|
try:
|
|
payload = {
|
|
# 接口需要无后缀的代码列表
|
|
"code_list": industry_codes,
|
|
"target_code": stock_code
|
|
}
|
|
|
|
response = requests.post(self.momentum_api_url, json=payload, timeout=500)
|
|
|
|
if response.status_code != 200:
|
|
logger.error(f"获取动量指标失败({stock_code}): HTTP {response.status_code}, {response.text}")
|
|
return None
|
|
|
|
data = response.json()
|
|
# 为返回结果补充股票代码和名称
|
|
data['stock_code'] = stock_code
|
|
return data
|
|
|
|
except requests.exceptions.Timeout:
|
|
logger.error(f"获取动量指标超时({stock_code})")
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"获取动量指标异常({stock_code}): {str(e)}")
|
|
return None
|
|
|
|
def analyze_momentum_by_name(self, name: str, is_concept: bool = False) -> Dict:
|
|
"""
|
|
根据行业或概念名称,批量获取其中所有股票的动量指标。
|
|
|
|
Args:
|
|
name: 行业或概念的名称
|
|
is_concept: 是否为概念板块
|
|
|
|
Returns:
|
|
包含所有股票动量数据的字典
|
|
"""
|
|
# 1. 获取板块内所有股票代码
|
|
stock_list = self.get_stocks_by_name(name, is_concept)
|
|
if not stock_list:
|
|
return {'success': False, 'message': f'未找到板块 "{name}" 中的股票'}
|
|
|
|
all_results = []
|
|
|
|
# 2. 依次请求所有股票的动量数据
|
|
for stock_code in stock_list:
|
|
try:
|
|
temp_list = [stock_code]
|
|
result = self.get_momentum_indicators(stock_code, temp_list)
|
|
if result:
|
|
all_results.append(result)
|
|
except Exception as exc:
|
|
logger.error(f"处理股票 {stock_code} 的动量分析时产生异常: {exc}")
|
|
|
|
return {
|
|
'success': True,
|
|
'plate_name': name,
|
|
'is_concept': is_concept,
|
|
'stock_count': len(stock_list),
|
|
'results_count': len(all_results),
|
|
'data': all_results
|
|
}
|
|
|
|
if __name__ == '__main__':
|
|
# 示例用法
|
|
analyzer = MomentumAnalyzer()
|
|
|
|
# 1. 测试行业
|
|
industry_name = "证券"
|
|
industry_results = analyzer.analyze_momentum_by_name(industry_name, is_concept=False)
|
|
print(f"\n行业 '{industry_name}' 动量分析结果 (前5条):")
|
|
if industry_results['success']:
|
|
# 打印部分结果
|
|
for item in industry_results['data'][:5]:
|
|
print(item)
|
|
else:
|
|
print(industry_results['message'])
|
|
|
|
# 2. 测试概念
|
|
concept_name = "芯片"
|
|
concept_results = analyzer.analyze_momentum_by_name(concept_name, is_concept=True)
|
|
print(f"\n概念 '{concept_name}' 动量分析结果 (前5条):")
|
|
if concept_results['success']:
|
|
# 打印部分结果
|
|
for item in concept_results['data'][:5]:
|
|
print(item)
|
|
else:
|
|
print(concept_results['message']) |