|
|
|
|
|
""" |
|
|
技术指标计算模块 - 纯pandas实现 |
|
|
支持常用的技术分析指标计算,无需外部依赖 |
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def calculate_sma(data, window):
    """Simple Moving Average (SMA).

    Args:
        data: pandas Series/DataFrame of values to average.
        window: number of observations in the rolling window.

    Returns:
        Rolling mean of ``data``. Because ``min_periods=1``, the leading
        entries average whatever observations are available so far instead
        of being NaN.
    """
    roller = data.rolling(window=window, min_periods=1)
    return roller.mean()
|
|
|
|
|
|
|
|
def calculate_ema(data, window):
    """Exponential Moving Average (EMA).

    Args:
        data: pandas Series/DataFrame of values to smooth.
        window: EMA span; pandas derives the smoothing factor as
            alpha = 2 / (window + 1).

    Returns:
        The recursively smoothed series (``adjust=False``), whose first value
        equals the first input value.
    """
    smoother = data.ewm(span=window, adjust=False)
    return smoother.mean()
|
|
|
|
|
|
|
|
def calculate_rsi(close, window=14):
    """Relative Strength Index (RSI).

    Gains and losses are averaged with a simple rolling mean over ``window``
    price changes (``min_periods=1``). This is not Wilder's exponential
    smoothing.

    Args:
        close: pandas Series of closing prices.
        window: number of price changes in each averaging window.

    Returns:
        Series of RSI values in [0, 100]. The first entry is NaN because there
        is no previous price. A window with gains but no losses gives 100. A
        window with losses but no gains gives 0. A window with no price movement
        at all gives 50 (neutral) instead of the NaN that the 0/0 ratio would
        otherwise produce.
    """
    delta = close.diff()

    # clip/abs keep the leading NaN from diff() as NaN, so the rolling mean
    # skips it rather than counting it as a zero change; abs() also avoids
    # negative zeros in the loss series.
    gain = delta.clip(lower=0)
    loss = delta.clip(upper=0).abs()

    avg_gain = gain.rolling(window=window, min_periods=1).mean()
    avg_loss = loss.rolling(window=window, min_periods=1).mean()

    # avg_loss == 0 with avg_gain > 0 gives rs = inf, hence RSI = 100.
    rs = avg_gain / avg_loss
    rsi = 100 - (100 / (1 + rs))

    # 0 / 0 (flat prices over the whole window) would be NaN; report neutral.
    no_movement = (avg_gain == 0) & (avg_loss == 0)
    return rsi.mask(no_movement, 50.0)
|
|
|
|
|
|
|
|
def calculate_macd(close, fast=12, slow=26, signal=9):
    """MACD (Moving Average Convergence Divergence).

    Args:
        close: pandas Series of closing prices.
        fast: span of the fast EMA.
        slow: span of the slow EMA.
        signal: span of the EMA applied to the MACD line.

    Returns:
        Tuple ``(macd, macd_signal, macd_hist)``:
        ``macd`` is the fast EMA minus the slow EMA of ``close``,
        ``macd_signal`` is the EMA of ``macd`` over ``signal`` periods,
        and ``macd_hist`` is ``macd - macd_signal``.
    """
    macd_line = calculate_ema(close, fast) - calculate_ema(close, slow)
    signal_line = calculate_ema(macd_line, signal)
    return macd_line, signal_line, macd_line - signal_line
|
|
|
|
|
|
|
|
def calculate_bollinger_bands(close, window=20, std_dev=2):
    """Bollinger Bands.

    Args:
        close: pandas Series of closing prices.
        window: rolling window for both the middle band and the deviation.
        std_dev: number of standard deviations between the middle band and
            each outer band.

    Returns:
        Tuple ``(upper, middle, lower)``. The middle band is the SMA of
        ``close``. The deviation is pandas' default rolling ``std()`` (sample,
        ddof=1), which is NaN on the first row where only one observation
        exists, so both outer bands are NaN there too.
    """
    middle = calculate_sma(close, window)
    band_width = close.rolling(window=window, min_periods=1).std() * std_dev
    return middle + band_width, middle, middle - band_width
|
|
|
|
|
|
|
|
def calculate_stochastic(high, low, close, k_window=14, d_window=3):
    """Stochastic Oscillator (%K and %D).

    Args:
        high, low, close: pandas Series of bar highs, lows and closes.
        k_window: look-back window for the highest high / lowest low.
        d_window: rolling-mean window used to smooth %K into %D.

    Returns:
        Tuple ``(k_percent, d_percent)``. %K is where ``close`` sits within the
        look-back high/low range, scaled to 0-100. It is NaN where that range
        has zero width.
    """
    window_low = low.rolling(window=k_window, min_periods=1).min()
    window_high = high.rolling(window=k_window, min_periods=1).max()

    # A zero-width range would divide by zero; mark it undefined instead.
    span = (window_high - window_low).replace(0, np.nan)

    k_line = 100 * ((close - window_low) / span)
    d_line = k_line.rolling(window=d_window, min_periods=1).mean()
    return k_line, d_line
|
|
|
|
|
|
|
|
def calculate_atr(high, low, close, window=14):
    """Average True Range (ATR).

    The true range of each bar is the largest of ``high - low``,
    ``|high - previous close|`` and ``|low - previous close|``. On the first bar
    there is no previous close, so the two gap terms are NaN and the row-wise
    ``max`` (which skips NaN) falls back to ``high - low``.

    Args:
        high, low, close: pandas Series of bar highs, lows and closes.
        window: rolling-mean window applied to the true range.

    Returns:
        Simple rolling mean of the true range (``min_periods=1``).
    """
    prev_close = close.shift(1)
    candidates = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    )
    true_range = candidates.max(axis=1)
    return true_range.rolling(window=window, min_periods=1).mean()
|
|
|
|
|
|
|
|
def calculate_williams_r(high, low, close, window=14):
    """Williams %R.

    Args:
        high, low, close: pandas Series of bar highs, lows and closes.
        window: look-back window for the highest high / lowest low.

    Returns:
        Series in [-100, 0] giving how far ``close`` sits below the look-back
        high, relative to the high/low range. It is NaN where that range has
        zero width.
    """
    top = high.rolling(window=window, min_periods=1).max()
    bottom = low.rolling(window=window, min_periods=1).min()

    # Guard the division: an empty range is undefined rather than infinite.
    width = (top - bottom).replace(0, np.nan)

    return -100 * ((top - close) / width)
|
|
|
|
|
|
|
|
def add_technical_indicators(df, indicators_config=None):
    """
    Add technical-indicator columns to a copy of an OHLC(V) DataFrame.

    Args:
        df: DataFrame containing at least 'open', 'high', 'low' and 'close'
            columns. It is copied and never modified in place.
        indicators_config: dict selecting the indicators to compute:
            'sma', 'ema', 'rsi', 'atr' -> iterable of window lengths, producing
            columns such as ``sma_5`` or ``atr_14``;
            'macd' -> truthy to add ``macd``, ``macd_signal``, ``macd_hist``;
            'bollinger' -> truthy to add ``bb_upper``, ``bb_middle``,
            ``bb_lower``.
            When None, SMA(5, 10, 20), EMA(12, 26), RSI(14), MACD, Bollinger
            Bands and ATR(14) are computed.

    Returns:
        The copied DataFrame with the indicator columns added and NaNs in every
        column filled backward and then forward. If computing any indicator
        raises, the error is printed and only the basic columns (OHLC, plus
        'volume', 'amount', 'timestamps' when present) are returned.

    Raises:
        ValueError: if any of 'open', 'high', 'low', 'close' is missing.
    """
    if indicators_config is None:
        indicators_config = {
            'sma': [5, 10, 20],
            'ema': [12, 26],
            'rsi': [14],
            'macd': True,
            'bollinger': True,
            'atr': [14],
        }

    df = df.copy()

    required_cols = ['open', 'high', 'low', 'close']
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain columns: {required_cols}")

    # Input columns that are passed through and never counted as indicators.
    basic_cols = ['open', 'high', 'low', 'close', 'volume', 'amount', 'timestamps']

    try:
        close = df['close']

        for period in indicators_config.get('sma', ()):
            df[f'sma_{period}'] = calculate_sma(close, period)

        for period in indicators_config.get('ema', ()):
            df[f'ema_{period}'] = calculate_ema(close, period)

        for period in indicators_config.get('rsi', ()):
            df[f'rsi_{period}'] = calculate_rsi(close, period)

        if indicators_config.get('macd'):
            macd, macd_signal, macd_hist = calculate_macd(close)
            df['macd'] = macd
            df['macd_signal'] = macd_signal
            df['macd_hist'] = macd_hist

        if indicators_config.get('bollinger'):
            bb_upper, bb_middle, bb_lower = calculate_bollinger_bands(close)
            df['bb_upper'] = bb_upper
            df['bb_middle'] = bb_middle
            df['bb_lower'] = bb_lower

        for period in indicators_config.get('atr', ()):
            df[f'atr_{period}'] = calculate_atr(df['high'], df['low'], close, period)

        # fillna(method=...) has been deprecated since pandas 2.1 and is slated
        # for removal. Once removed, it would raise here and push every call into
        # the except-branch below, silently dropping all indicators.
        # bfill()/ffill() are the supported equivalents.
        # NOTE(review): back-filling copies later values into earlier rows
        # (look-ahead) and also fills NaNs in the input columns; kept for
        # backward compatibility.
        df = df.bfill().ffill()

        indicator_cols = [col for col in df.columns if col not in basic_cols]
        print(f"✅ 技术指标计算完成,添加了 {len(indicator_cols)} 个指标")

    except Exception as e:
        # Best-effort: report the failure and fall back to the input columns.
        print(f"❌ 技术指标计算失败: {e}")
        return df[[col for col in basic_cols if col in df.columns]]

    return df
|
|
|
|
|
|
|
|
def get_available_indicators():
    """Return the supported technical indicator names, grouped by category.

    A new dict (with new lists) is built on every call, so callers may mutate
    the result without affecting later calls.
    """
    return dict(
        trend=['sma', 'ema', 'macd', 'bollinger'],
        momentum=['rsi'],
        volatility=['atr'],
        volume=[],
    )
|
|
|
|
|
|
|
|
|
|
|
|