124 lines
4.2 KiB
Python
124 lines
4.2 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import logging
|
||
from datetime import datetime, timedelta
|
||
from sqlalchemy import create_engine, text
|
||
import pandas as pd
|
||
|
||
from src.valuation_analysis.config import DB_URL
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class IndexAnalyzer:
|
||
"""指数数据分析工具类"""
|
||
|
||
def __init__(self, db_url=None):
|
||
# 初始化数据库连接
|
||
self.db_url = db_url or DB_URL
|
||
self.engine = create_engine(self.db_url)
|
||
|
||
def get_indices_list(self):
|
||
"""
|
||
获取可用指数列表
|
||
|
||
Returns:
|
||
list: 包含指数信息的列表 [{"id": id, "name": name, "code": code}, ...]
|
||
"""
|
||
try:
|
||
with self.engine.connect() as conn:
|
||
query = text("""
|
||
SELECT id, gp_name as name, gp_code as code
|
||
FROM gp_code_zs
|
||
ORDER BY gp_name
|
||
""")
|
||
result = conn.execute(query).fetchall()
|
||
|
||
indices = []
|
||
for row in result:
|
||
indices.append({
|
||
"id": row[0],
|
||
"name": row[1],
|
||
"code": row[2]
|
||
})
|
||
|
||
logger.info(f"获取到 {len(indices)} 个指数")
|
||
return indices
|
||
except Exception as e:
|
||
logger.error(f"获取指数列表失败: {str(e)}")
|
||
return []
|
||
|
||
def get_index_data(self, index_code, start_date=None, end_date=None):
|
||
"""
|
||
获取指数历史数据
|
||
|
||
Args:
|
||
index_code: 指数代码
|
||
start_date: 开始日期 (可选,默认为1年前)
|
||
end_date: 结束日期 (可选,默认为今天)
|
||
|
||
Returns:
|
||
dict: 包含指数数据的字典 {"code": code, "dates": [...], "values": [...]}
|
||
"""
|
||
try:
|
||
# 处理日期参数
|
||
if end_date is None:
|
||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||
|
||
if start_date is None:
|
||
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d')
|
||
|
||
with self.engine.connect() as conn:
|
||
query = text("""
|
||
SELECT timestamp, close
|
||
FROM gp_day_data
|
||
WHERE symbol = :symbol
|
||
AND timestamp BETWEEN :start_date AND :end_date
|
||
ORDER BY timestamp
|
||
""")
|
||
|
||
result = conn.execute(query, {
|
||
"symbol": index_code,
|
||
"start_date": start_date,
|
||
"end_date": end_date
|
||
}).fetchall()
|
||
|
||
dates = []
|
||
values = []
|
||
|
||
for row in result:
|
||
dates.append(row[0].strftime('%Y-%m-%d'))
|
||
# close可能是字符串类型,转换为浮点数
|
||
values.append(float(row[1]) if row[1] else None)
|
||
|
||
logger.info(f"获取指数 {index_code} 数据: {len(dates)} 条记录")
|
||
return {
|
||
"code": index_code,
|
||
"dates": dates,
|
||
"values": values
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"获取指数 {index_code} 数据失败: {str(e)}")
|
||
return {
|
||
"code": index_code,
|
||
"dates": [],
|
||
"values": []
|
||
}
|
||
|
||
# 测试代码
|
||
if __name__ == "__main__":
|
||
analyzer = IndexAnalyzer()
|
||
|
||
# 测试获取指数列表
|
||
indices = analyzer.get_indices_list()
|
||
print(f"指数列表: {indices[:5]}...")
|
||
|
||
# 测试获取指数数据
|
||
if indices:
|
||
# 测试第一个指数的数据
|
||
first_index = indices[0]
|
||
index_data = analyzer.get_index_data(first_index['code'])
|
||
print(f"指数 {first_index['name']} 数据:")
|
||
print(f"日期数量: {len(index_data['dates'])}")
|
||
if index_data['dates']:
|
||
print(f"第一个日期: {index_data['dates'][0]}, 值: {index_data['values'][0]}") |