commit b3332eb920
parent b4c8b6d1b4

src/app.py (76 lines changed)
@@ -2,6 +2,7 @@
import sys
import os
from datetime import datetime, timedelta
from typing import List
import pandas as pd
import uuid
import json

@@ -74,6 +75,9 @@ from src.tushare_scripts.chip_distribution_collector import collect_chip_distrib
from src.tushare_scripts.stock_factor_collector import collect_stock_factor
from src.tushare_scripts.stock_factor_pro_collector import collect_stock_factor_pro

# Import the tech-themed fundamental factor stock selection strategy
from src.quantitative_analysis.tech_fundamental_factor_strategy_v3 import TechFundamentalFactorStrategy

# Set up logging
logging.basicConfig(
    level=logging.INFO,
@@ -460,7 +464,7 @@ def run_chip_distribution_collection():
     # mode='full'
     # )
     collect_chip_distribution(db_url=db_url, tushare_token=TUSHARE_TOKEN, mode='full',
-                              start_date='2022-01-03', end_date='2022-12-31')
+                              start_date='2021-01-03', end_date='2021-09-30')
     # collect_chip_distribution(
     #     db_url=db_url,
     #     tushare_token=TUSHARE_TOKEN,
@@ -3780,6 +3784,76 @@ def analyze_stock_overlap():
    }), 500


@app.route('/scheduler/techFundamentalStrategy/batch', methods=['GET', 'POST'])
def run_tech_fundamental_strategy_batch():
    """Run the tech-themed fundamental factor strategy for each date in a range and save results to the database"""
    try:
        # Read parameters from the query string or the JSON body
        start_date = request.args.get('start_date') or (request.get_json() if request.is_json else {}).get('start_date')
        end_date = request.args.get('end_date') or (request.get_json() if request.is_json else {}).get('end_date')

        if not start_date or not end_date:
            return jsonify({"status": "error", "message": "Missing parameters: start_date and end_date"}), 400

        # Validate the dates
        start_dt = datetime.strptime(start_date, '%Y-%m-%d')
        end_dt = datetime.strptime(end_date, '%Y-%m-%d')

        # Database connection (used to check for trading-day data);
        # text is imported here as well since it is used in the query below
        from sqlalchemy import create_engine, text
        db_url = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
        engine = create_engine(db_url)

        # Iterate over the date range
        success_count = 0
        failed_count = 0
        skipped_count = 0
        current_date = start_dt

        while current_date <= end_dt:
            trade_date = current_date.strftime('%Y-%m-%d')
            trade_datetime = f"{trade_date} 00:00:00"

            # Check whether there is market data for this date
            check_query = text("""
                SELECT COUNT(0) as cnt
                FROM gp_day_data
                WHERE `timestamp` = :trade_datetime
            """)

            with engine.connect() as conn:
                result = conn.execute(check_query, {"trade_datetime": trade_datetime}).fetchone()
                count = result[0] if result else 0

            # Only run the strategy when data exists for the date
            if count > 0:
                try:
                    strategy = TechFundamentalFactorStrategy(target_date=trade_date)
                    strategy.run_strategy()
                    strategy.close_connections()
                    success_count += 1
                except Exception as e:
                    failed_count += 1
                    logger.error(f"Processing failed for date {trade_date}: {str(e)}")
            else:
                skipped_count += 1

            current_date += timedelta(days=1)

        engine.dispose()

        return jsonify({
            "status": "success",
            "success_count": success_count,
            "failed_count": failed_count,
            "skipped_count": skipped_count
        })

    except Exception as e:
        logger.error(f"Batch run failed: {str(e)}")
        return jsonify({"status": "error", "message": str(e)}), 500


if __name__ == '__main__':

    # Start the web server
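A quick way to exercise the new batch endpoint; a minimal sketch, in which the host and port are assumptions rather than part of this commit:

import requests

# Hypothetical base URL; adjust to the actual deployment.
resp = requests.post(
    "http://127.0.0.1:5000/scheduler/techFundamentalStrategy/batch",
    json={"start_date": "2025-01-01", "end_date": "2025-01-31"},
)
print(resp.json())  # expected keys: status, success_count, failed_count, skipped_count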
@@ -55,16 +55,32 @@ class AverageDistanceFactor:
             print(f"Failed to fetch the stock list: {e}")
             return []
 
-    def get_stock_data(self, symbols, days=20):
-        """Fetch historical data for the given stocks"""
+    def get_stock_data(self, symbols, days=20, end_date=None):
+        """
+        Fetch historical data for the given stocks
+
+        Args:
+            symbols: list of stock codes
+            days: number of days to fetch
+            end_date: end date (datetime object or string); if None, today is used
+        """
         if not symbols:
             return pd.DataFrame()
 
-        # Compute the start date
-        end_date = datetime.now()
-        start_date = end_date - timedelta(days=days * 2)  # fetch extra data to cover holidays
+        # Compute the end date (daily data uses the target date itself)
+        if end_date is None:
+            end_date = datetime.now()
+        elif isinstance(end_date, str):
+            end_date = datetime.strptime(end_date, '%Y-%m-%d')
+
+        # Build the SQL query
+        # Daily (EOD) data should include the target date itself.
+        # For example, if the target date is 2025-11-01, use data from 2025-11-01 and earlier.
+        query_end_date = end_date
+
+        # Compute the start date (fetch extra days to cover holidays)
+        start_date = query_end_date - timedelta(days=days * 2)
+
+        # Build the SQL query (use <= to include the target date itself)
         symbols_str = "', '".join(symbols)
         query = f"""
             SELECT symbol, timestamp, volume, open, high, low, close,

@@ -72,28 +88,36 @@ class AverageDistanceFactor:
             FROM gp_day_data
             WHERE symbol IN ('{symbols_str}')
             AND timestamp >= '{start_date.strftime('%Y-%m-%d')}'
             AND timestamp <= '{end_date.strftime('%Y-%m-%d')}'
             ORDER BY symbol, timestamp DESC
         """
 
         try:
             df = pd.read_sql(query, self.engine)
             print(f"Fetched {len(df)} rows of historical data")
             return df
         except Exception as e:
             print(f"Failed to fetch historical data: {e}")
             return pd.DataFrame()
 
     def calculate_technical_indicators(self, df, days=20):
-        """Compute technical indicators"""
+        """
+        Compute technical indicators
+
+        Args:
+            df: DataFrame of stock history (should already be filtered by target date)
+            days: number of days used for the indicators (the N days up to the target date)
+        """
         result_data = []
 
         for symbol in df['symbol'].unique():
             stock_data = df[df['symbol'] == symbol].copy()
             # Sort by timestamp ascending to guarantee chronological order
             stock_data = stock_data.sort_values('timestamp')
 
-            # Keep only the most recent N days of data
+            # Keep only the N trading days up to the target date (inclusive);
+            # tail() takes the last N rows, i.e. the N trading days closest to
+            # the target date, including the target date itself
             stock_data = stock_data.tail(days)
 
             # Check there is enough data (at least `days` trading days required)
             if len(stock_data) < days:
                 continue  # not enough data, skip
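A minimal usage sketch of the new end_date parameter; the constructor call and symbol values are illustrative assumptions:

factor = AverageDistanceFactor()  # assumed default construction
# Data up to and including a historical target date:
df_hist = factor.get_stock_data(['SH600000', 'SZ000001'], days=20, end_date='2025-11-01')
# Omitting end_date preserves the old behaviour (end date = today):
df_now = factor.get_stock_data(['SH600000', 'SZ000001'], days=20)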
@@ -8,7 +8,11 @@ import sys
 import os
 
 # Add the project root directory to the path
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# __file__ is the current file's path, e.g.: /app/src/quantitative_analysis/company_lifecycle_factor.py
+# We need the project root directory: /app
+current_file_dir = os.path.dirname(os.path.abspath(__file__))  # /app/src/quantitative_analysis
+src_dir = os.path.dirname(current_file_dir)  # /app/src
+project_root = os.path.dirname(src_dir)  # /app (project root)
 sys.path.append(project_root)
 
 # Import configuration

@@ -16,7 +20,7 @@ try:
     from valuation_analysis.config import MONGO_CONFIG2
 except ImportError:
     import importlib.util
-    config_path = os.path.join(project_root, 'valuation_analysis', 'config.py')
+    config_path = os.path.join(src_dir, 'valuation_analysis', 'config.py')
     spec = importlib.util.spec_from_file_location("config", config_path)
     config_module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(config_module)

@@ -27,7 +31,12 @@ try:
     from tools.stock_code_formatter import StockCodeFormatter
 except ImportError:
     import importlib.util
-    formatter_path = os.path.join(os.path.dirname(project_root), 'tools', 'stock_code_formatter.py')
+    # project_root is already the project root, so join the tools directory directly
+    formatter_path = os.path.join(project_root, 'tools', 'stock_code_formatter.py')
+
+    if not os.path.exists(formatter_path):
+        raise ImportError(f"Cannot find stock_code_formatter.py, path: {formatter_path}")
+
     spec = importlib.util.spec_from_file_location("stock_code_formatter", formatter_path)
     formatter_module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(formatter_module)

@@ -120,7 +129,7 @@ class CompanyLifecycleFactor:
         annual_data = self.collection.find_one(query)
 
         if annual_data:
-            logger.info(f"Found annual report data: {stock_code} (normalized: {normalized_code}) - {report_date}")
+            # logger.info(f"Found annual report data: {stock_code} (normalized: {normalized_code}) - {report_date}")
             return annual_data
         else:
             logger.warning(f"No annual report data found: {stock_code} (normalized: {normalized_code}) - {report_date}")
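For reference, the three dirname() calls above are equivalent to this pathlib form; a sketch using the example layout from the comments:

from pathlib import Path

current_file = Path("/app/src/quantitative_analysis/company_lifecycle_factor.py")
src_dir = current_file.parents[1]       # /app/src
project_root = current_file.parents[2]  # /app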
@@ -464,11 +464,11 @@ if __name__ == "__main__":
     # collect_chip_distribution(db_url, tushare_token, mode='full')
 
     # 3. Collect data for a specific date
-    collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-24')
+    # collect_chip_distribution(db_url, tushare_token, mode='daily', date='2025-11-25')
 
     # 4. Collect data for a specific date range
-    # collect_chip_distribution(db_url, tushare_token, mode='full',
-    #                           start_date='2021-11-01', end_date='2021-11-30')
+    collect_chip_distribution(db_url, tushare_token, mode='full',
+                              start_date='2021-01-03', end_date='2021-09-30')
 
     # 5. Adjust the batch insert size (default: 100 stocks per batch)
     # collect_chip_distribution(db_url, tushare_token, mode='daily', batch_size=200)
@@ -0,0 +1,242 @@
# coding:utf-8
"""
Tonghuashun (THS) concept board constituent collector
Purpose: fetch THS concept board constituent data from Tushare and store it in the database
API docs: https://tushare.pro/document/2?doc_id=261
Note: each run does a full overwrite; no daily updates are needed
"""

import os
import sys
from datetime import datetime
from typing import Optional

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

# Add the project root directory to the path so the config can be imported
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN


class THSConceptMemberCollector:
    """THS concept board constituent collector"""

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_concept_member"):
        """
        Args:
            db_url: database connection URL
            tushare_token: Tushare token
            table_name: target table name, defaults to ths_concept_member
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("THS concept board constituent collector")
        print(f"Target table: {self.table_name}")
        print("=" * 60)

    @staticmethod
    def convert_tushare_code_to_db(ts_code: str) -> str:
        """
        Convert a Tushare code (600000.SH) to the database code format (SH600000)
        """
        if not ts_code or "." not in ts_code:
            return ts_code
        base, market = ts_code.split(".")
        return f"{market}{base}"

    def fetch_concept_index_list(self) -> pd.DataFrame:
        """
        Fetch the list of all THS concept boards
        API: ths_index
        Docs: https://tushare.pro/document/2?doc_id=259

        Returns:
            DataFrame of concept boards
        """
        try:
            print("Fetching THS concept board list...")
            df = self.pro.ths_index(
                exchange='A',
                type='N',  # N = concept boards
            )
            if df.empty:
                print("No concept board data returned")
                return pd.DataFrame()
            print(f"Fetched {len(df)} concept boards")
            return df
        except Exception as exc:
            print(f"Failed to fetch the concept board list: {exc}")
            return pd.DataFrame()

    def fetch_concept_members(self, ts_code: str) -> pd.DataFrame:
        """
        Fetch the constituents of a given concept board
        API: ths_member
        Docs: https://tushare.pro/document/2?doc_id=261

        Args:
            ts_code: concept board code, e.g. '885800.TI'

        Returns:
            DataFrame of constituents
        """
        try:
            df = self.pro.ths_member(ts_code=ts_code)
            return df
        except Exception as exc:
            print(f"Failed to fetch constituents of concept board {ts_code}: {exc}")
            return pd.DataFrame()

    def transform_data(self, member_df: pd.DataFrame, concept_name: str = "") -> pd.DataFrame:
        """
        Transform the data returned by Tushare into the database format

        Args:
            member_df: constituent data (contains ts_code, con_code, con_name, etc.)
            concept_name: concept board name (taken from the board list)

        Returns:
            transformed DataFrame
        """
        if member_df.empty:
            return pd.DataFrame()

        # Check required fields
        if "ts_code" not in member_df.columns:
            print(f"Warning: concept board code field (ts_code) not found; available fields: {member_df.columns.tolist()}")
            return pd.DataFrame()
        if "con_code" not in member_df.columns:
            print(f"Warning: stock code field (con_code) not found; available fields: {member_df.columns.tolist()}")
            return pd.DataFrame()

        result = pd.DataFrame()
        # ts_code is the concept board code
        result["concept_code"] = member_df["ts_code"]
        result["concept_name"] = concept_name
        # con_code is the stock code
        result["stock_ts_code"] = member_df["con_code"]
        result["stock_symbol"] = member_df["con_code"].apply(self.convert_tushare_code_to_db)
        # con_name is the stock name
        result["stock_name"] = member_df["con_name"] if "con_name" in member_df.columns else ""
        # is_new flags the latest constituents (the API returns it, but weight,
        # in_date and out_date carry no data yet, so they are not stored)
        result["is_new"] = member_df["is_new"] if "is_new" in member_df.columns else None
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """
        Write the data to the database
        """
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(self) -> None:
        """
        Run a full overwrite collection:
        - truncate the target table
        - fetch the list of all concept boards
        - iterate over each board and fetch its constituents
        - write everything to the database
        """
        print("=" * 60)
        print("Starting full overwrite collection (THS concept board constituents)")
        print("=" * 60)

        try:
            # Truncate the table
            with self.engine.begin() as conn:
                conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                print(f"{self.table_name} truncated")

            # Fetch the list of all concept boards
            concept_list_df = self.fetch_concept_index_list()
            if concept_list_df.empty:
                print("No concept board list available; aborting collection")
                return

            total_records = 0
            success_count = 0
            failed_count = 0

            # Iterate over each concept board and fetch its constituents
            for idx, row in concept_list_df.iterrows():
                concept_code = row["ts_code"]
                concept_name = row["name"]

                print(f"\n[{idx + 1}/{len(concept_list_df)}] Collecting: {concept_name} ({concept_code})")

                try:
                    # Fetch the constituents of this concept board
                    member_df = self.fetch_concept_members(concept_code)
                    if member_df.empty:
                        print(f"  Concept board {concept_name} has no constituent data")
                        failed_count += 1
                        continue

                    # Transform the data (passing the concept board name)
                    result_df = self.transform_data(member_df, concept_name)
                    if result_df.empty:
                        print(f"  Data transformation failed for concept board {concept_name}")
                        failed_count += 1
                        continue

                    # Save the data
                    self.save_dataframe(result_df)
                    total_records += len(result_df)
                    success_count += 1
                    print(f"  Collected {len(result_df)} constituents")

                except Exception as exc:
                    print(f"  Failed to collect {concept_name} ({concept_code}): {exc}")
                    failed_count += 1
                    continue

            print("\n" + "=" * 60)
            print("Full overwrite collection finished")
            print(f"Total concept boards: {len(concept_list_df)}")
            print(f"Collected successfully: {success_count}")
            print(f"Failed: {failed_count}")
            print(f"Total constituent records: {total_records}")
            print("=" * 60)
        except Exception as exc:
            print(f"Full collection failed: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()


def collect_ths_concept_member(
    db_url: str,
    tushare_token: str,
):
    """
    Collection entry point - full overwrite collection
    """
    collector = THSConceptMemberCollector(db_url, tushare_token)
    collector.run_full_collection()


if __name__ == "__main__":
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN

    # Run the full overwrite collection
    collect_ths_concept_member(DB_URL, TOKEN)
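The collector truncates and appends to ths_concept_member, so the table must already exist. A hypothetical minimal schema matching the columns written by transform_data; the real table's types and indexes may differ:

from sqlalchemy import create_engine, text

DDL = """
CREATE TABLE IF NOT EXISTS ths_concept_member (
    id            BIGINT AUTO_INCREMENT PRIMARY KEY,
    concept_code  VARCHAR(16),
    concept_name  VARCHAR(64),
    stock_ts_code VARCHAR(16),
    stock_symbol  VARCHAR(16),
    stock_name    VARCHAR(64),
    is_new        VARCHAR(4),
    created_at    DATETIME,
    updated_at    DATETIME
)
"""

engine = create_engine(DB_URL)  # DB_URL as defined in the __main__ block above
with engine.begin() as conn:
    conn.execute(text(DDL))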
@@ -0,0 +1,225 @@
# coding:utf-8
"""
Tonghuashun (THS) concept and industry index collector
Purpose: fetch THS concept and industry index data from Tushare and store it in the database
API docs: https://tushare.pro/document/2?doc_id=259
Note: each run does a full overwrite; no daily updates are needed
"""

import os
import sys
from datetime import datetime
from typing import Optional

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

# Add the project root directory to the path so the config can be imported
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN


class THSIndexCollector:
    """THS concept and industry index collector"""

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "ths_index"):
        """
        Args:
            db_url: database connection URL
            tushare_token: Tushare token
            table_name: target table name, defaults to ths_index
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("THS concept and industry index collector")
        print(f"Target table: {self.table_name}")
        print("=" * 60)

    def fetch_index_list(
        self,
        exchange: Optional[str] = None,
        index_type: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Fetch the list of THS concept and industry indexes
        API: ths_index
        Docs: https://tushare.pro/document/2?doc_id=259

        Args:
            exchange: market type, A = A-shares, HK = Hong Kong, US = United States; None means all
            index_type: index type, N = concept, I = industry, R = region, S = THS special,
                        ST = THS style, TH = THS theme, BB = THS broad-based; None means all

        Returns:
            DataFrame of indexes
        """
        try:
            params = {}
            if exchange:
                params['exchange'] = exchange
            if index_type:
                params['type'] = index_type

            print("Fetching THS index list...")
            if params:
                print(f"Filters: {params}")

            df = self.pro.ths_index(**params)
            if df.empty:
                print("No index data returned")
                return pd.DataFrame()
            print(f"Fetched {len(df)} indexes")
            return df
        except Exception as exc:
            print(f"Failed to fetch the index list: {exc}")
            return pd.DataFrame()

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Transform the data returned by Tushare into the database format

        Args:
            df: index data DataFrame

        Returns:
            transformed DataFrame
        """
        if df.empty:
            return pd.DataFrame()

        result = pd.DataFrame()
        result["ts_code"] = df["ts_code"]
        result["name"] = df["name"]
        result["count"] = df.get("count", None)
        result["exchange"] = df.get("exchange", None)
        result["list_date"] = pd.to_datetime(df["list_date"], format="%Y%m%d", errors='coerce') if "list_date" in df.columns else None
        result["type"] = df.get("type", None)
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """
        Write the data to the database
        """
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(
        self,
        exchange: Optional[str] = None,
        index_type: Optional[str] = None,
    ) -> None:
        """
        Run a full overwrite collection:
        - clear the target table
        - fetch all index data
        - write everything to the database

        Args:
            exchange: market type; None means all
            index_type: index type; None means all
        """
        print("=" * 60)
        print("Starting full overwrite collection (THS concept and industry indexes)")
        print("=" * 60)

        try:
            # Clear the table
            with self.engine.begin() as conn:
                # If filters are given, delete only the matching rows; otherwise clear everything
                if exchange or index_type:
                    delete_conditions = []
                    params = {}
                    if exchange:
                        delete_conditions.append("exchange = :exchange")
                        params["exchange"] = exchange
                    if index_type:
                        delete_conditions.append("type = :type")
                        params["type"] = index_type
                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
                    conn.execute(delete_sql, params)
                    print(f"Deleted matching old rows from {self.table_name}")
                else:
                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                    print(f"{self.table_name} truncated")

            # Fetch the index list
            index_df = self.fetch_index_list(exchange=exchange, index_type=index_type)
            if index_df.empty:
                print("No index data available; aborting collection")
                return

            # Transform the data
            result_df = self.transform_data(index_df)
            if result_df.empty:
                print("Data transformation failed")
                return

            # Save the data
            self.save_dataframe(result_df)
            print(f"Wrote {len(result_df)} records")

            # Per-type statistics
            if "type" in result_df.columns:
                type_stats = result_df.groupby("type").size()
                print("\nPer-type statistics:")
                for idx, count in type_stats.items():
                    print(f"  {idx}: {count}")

            print("\n" + "=" * 60)
            print("Full overwrite collection finished")
            print("=" * 60)
        except Exception as exc:
            print(f"Full collection failed: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()


def collect_ths_index(
    db_url: str,
    tushare_token: str,
    exchange: Optional[str] = None,
    index_type: Optional[str] = None,
):
    """
    Collection entry point - full overwrite collection

    Args:
        db_url: database connection URL
        tushare_token: Tushare token
        exchange: market type, A = A-shares, HK = Hong Kong, US = United States; None means all
        index_type: index type, N = concept, I = industry, R = region, S = THS special,
                    ST = THS style, TH = THS theme, BB = THS broad-based; None means all
    """
    collector = THSIndexCollector(db_url, tushare_token)
    collector.run_full_collection(exchange=exchange, index_type=index_type)


if __name__ == "__main__":
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN

    # Run the full overwrite collection (all index types)
    collect_ths_index(DB_URL, TOKEN)

    # To fetch only a specific type, call it like this:
    # collect_ths_index(DB_URL, TOKEN, exchange='A', index_type='N')  # A-share concept indexes only
@@ -0,0 +1,236 @@
# coding:utf-8
"""
Trading calendar collector
Purpose: fetch trading calendar data from Tushare and store it in the database
API docs: https://tushare.pro/document/2?doc_id=26
Note: each run does a full overwrite; no daily updates are needed
"""

import os
import sys
from datetime import datetime
from typing import Optional

import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

# Add the project root directory to the path so the config can be imported
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(PROJECT_ROOT)

from src.scripts.config import TUSHARE_TOKEN


class TradeCalCollector:
    """Trading calendar collector"""

    def __init__(self, db_url: str, tushare_token: str, table_name: str = "trade_cal"):
        """
        Args:
            db_url: database connection URL
            tushare_token: Tushare token
            table_name: target table name, defaults to trade_cal
        """
        self.engine = create_engine(
            db_url,
            pool_size=5,
            max_overflow=10,
            pool_recycle=3600,
        )
        self.table_name = table_name

        ts.set_token(tushare_token)
        self.pro = ts.pro_api()

        print("=" * 60)
        print("Trading calendar collector")
        print(f"Target table: {self.table_name}")
        print("=" * 60)

    def fetch_trade_cal(
        self,
        exchange: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Fetch trading calendar data
        API: trade_cal
        Docs: https://tushare.pro/document/2?doc_id=26

        Args:
            exchange: exchange code, SSE = Shanghai, SZSE = Shenzhen, BSE = Beijing; None means all
            start_date: start date in YYYYMMDD format; None means no lower bound
            end_date: end date in YYYYMMDD format; None means no upper bound

        Returns:
            trading calendar DataFrame
        """
        try:
            params = {}
            if exchange:
                params['exchange'] = exchange
            if start_date:
                params['start_date'] = start_date
            if end_date:
                params['end_date'] = end_date

            print("Fetching trading calendar data...")
            if params:
                print(f"Filters: {params}")

            df = self.pro.trade_cal(**params)
            if df.empty:
                print("No trading calendar data returned")
                return pd.DataFrame()
            print(f"Fetched {len(df)} trading calendar records")
            return df
        except Exception as exc:
            print(f"Failed to fetch trading calendar data: {exc}")
            return pd.DataFrame()

    def transform_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Transform the data returned by Tushare into the database format

        Args:
            df: trading calendar DataFrame

        Returns:
            transformed DataFrame
        """
        if df.empty:
            return pd.DataFrame()

        result = pd.DataFrame()
        result["exchange"] = df["exchange"]
        result["cal_date"] = pd.to_datetime(df["cal_date"], format="%Y%m%d", errors='coerce')
        result["is_open"] = df.get("is_open", None)
        result["pretrade_date"] = pd.to_datetime(df["pretrade_date"], format="%Y%m%d", errors='coerce') if "pretrade_date" in df.columns else None
        result["created_at"] = datetime.now()
        result["updated_at"] = datetime.now()
        return result

    def save_dataframe(self, df: pd.DataFrame) -> None:
        """
        Write the data to the database
        """
        if df.empty:
            return
        df.to_sql(self.table_name, self.engine, if_exists="append", index=False)

    def run_full_collection(
        self,
        exchange: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> None:
        """
        Run a full overwrite collection:
        - clear the target table
        - fetch the trading calendar data
        - write everything to the database

        Args:
            exchange: exchange code; None means all
            start_date: start date in YYYYMMDD format; None means no lower bound
            end_date: end date in YYYYMMDD format; None means no upper bound
        """
        print("=" * 60)
        print("Starting full overwrite collection (trading calendar)")
        print("=" * 60)

        try:
            # Clear the table
            with self.engine.begin() as conn:
                # If filters are given, delete only the matching rows; otherwise clear everything
                if exchange or start_date or end_date:
                    delete_conditions = []
                    params = {}
                    if exchange:
                        delete_conditions.append("exchange = :exchange")
                        params["exchange"] = exchange
                    if start_date:
                        delete_conditions.append("cal_date >= :start_date")
                        params["start_date"] = datetime.strptime(start_date, "%Y%m%d").date()
                    if end_date:
                        delete_conditions.append("cal_date <= :end_date")
                        params["end_date"] = datetime.strptime(end_date, "%Y%m%d").date()
                    delete_sql = text(f"DELETE FROM {self.table_name} WHERE {' AND '.join(delete_conditions)}")
                    conn.execute(delete_sql, params)
                    print(f"Deleted matching old rows from {self.table_name}")
                else:
                    conn.execute(text(f"TRUNCATE TABLE {self.table_name}"))
                    print(f"{self.table_name} truncated")

            # Fetch the trading calendar data
            cal_df = self.fetch_trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
            if cal_df.empty:
                print("No trading calendar data available; aborting collection")
                return

            # Transform the data
            result_df = self.transform_data(cal_df)
            if result_df.empty:
                print("Data transformation failed")
                return

            # Save the data
            self.save_dataframe(result_df)
            print(f"Wrote {len(result_df)} records")

            # Per-exchange statistics
            if "exchange" in result_df.columns:
                print("\nPer-exchange statistics:")
                for exchange_name in result_df["exchange"].unique():
                    exchange_data = result_df[result_df["exchange"] == exchange_name]
                    total_days = len(exchange_data)
                    trade_days = int(exchange_data["is_open"].sum()) if "is_open" in exchange_data.columns and exchange_data["is_open"].notna().any() else 0
                    min_date = exchange_data["cal_date"].min()
                    max_date = exchange_data["cal_date"].max()
                    print(f"  {exchange_name}: total days={total_days}, trading days={trade_days}, date range={min_date} ~ {max_date}")

            print("\n" + "=" * 60)
            print("Full overwrite collection finished")
            print("=" * 60)
        except Exception as exc:
            print(f"Full collection failed: {exc}")
            import traceback
            traceback.print_exc()
        finally:
            self.engine.dispose()


def collect_trade_cal(
    db_url: str,
    tushare_token: str,
    exchange: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
):
    """
    Collection entry point - full overwrite collection

    Args:
        db_url: database connection URL
        tushare_token: Tushare token
        exchange: exchange code, SSE = Shanghai, SZSE = Shenzhen, BSE = Beijing; None means all
        start_date: start date in YYYYMMDD format; None means no lower bound
        end_date: end date in YYYYMMDD format; None means no upper bound
    """
    collector = TradeCalCollector(db_url, tushare_token)
    collector.run_full_collection(exchange=exchange, start_date=start_date, end_date=end_date)


if __name__ == "__main__":
    # DB_URL = "mysql+pymysql://root:Chlry#$.8@192.168.18.199:3306/db_gp_cj"
    DB_URL = "mysql+pymysql://fac_pattern:Chlry$%.8_app@192.168.16.153:3307/my_quant_db"
    TOKEN = TUSHARE_TOKEN

    # Run the full overwrite collection (all exchanges, default date range)
    collect_trade_cal(DB_URL, TOKEN)

    # To fetch only a specific exchange, call it like this:
    # collect_trade_cal(DB_URL, TOKEN, exchange='SSE')
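With trade_cal populated, trading days can be looked up directly, for example to drive batch jobs like the new app.py endpoint. A minimal sketch, assuming the column values are stored exactly as written by transform_data above:

from sqlalchemy import create_engine, text

engine = create_engine(DB_URL)  # DB_URL as defined in the __main__ block above
with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT cal_date FROM trade_cal "
            "WHERE exchange = 'SSE' AND is_open = 1 "
            "AND cal_date BETWEEN :s AND :e "
            "ORDER BY cal_date"
        ),
        {"s": "2025-01-01", "e": "2025-01-31"},
    ).fetchall()
trading_days = [r[0] for r in rows]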