crypt / webui /technical_indicators.py
heyunfei's picture
Upload 56 files
85653bc verified
raw
history blame
6.07 kB
#!/usr/bin/env python3
"""
技术指标计算模块 - 纯pandas实现
支持常用的技术分析指标计算,无需外部依赖
"""
import pandas as pd
import numpy as np
def calculate_sma(data, window):
    """Simple moving average of ``data`` over a trailing ``window``.

    ``min_periods=1`` means the first ``window - 1`` values are averaged over
    however many observations are available so far instead of being NaN.
    """
    roller = data.rolling(window=window, min_periods=1)
    return roller.mean()
def calculate_ema(data, window):
    """Exponential moving average of ``data`` with span ``window``.

    Uses the recursive form (``adjust=False``): each value is
    ``alpha * x_t + (1 - alpha) * y_{t-1}`` with ``alpha = 2 / (window + 1)``.
    """
    smoother = data.ewm(span=window, adjust=False)
    return smoother.mean()
def calculate_rsi(close, window=14):
    """Relative Strength Index of ``close`` over a trailing ``window``.

    Average gains and losses are simple rolling means (Cutler's variant), not
    Wilder's exponential smoothing. ``min_periods=1`` lets early rows use the
    observations available so far.

    Returns a Series in [0, 100]. The first row is NaN because ``diff()`` has
    no previous value. Windows with no price movement at all (zero average gain
    and zero average loss) are reported as the neutral value 50 rather than
    the NaN that 0/0 would produce.
    """
    delta = close.diff()
    # clip() keeps the leading NaN from diff(), so the rolling means skip it;
    # gains and losses always share the same sample count, so their ratio is
    # the same as when that slot is counted as zero.
    gain = delta.clip(lower=0)
    loss = (-delta).clip(lower=0)
    avg_gain = gain.rolling(window=window, min_periods=1).mean()
    avg_loss = loss.rolling(window=window, min_periods=1).mean()
    # avg_loss == 0 with avg_gain > 0 gives rs = inf and therefore RSI = 100.
    rs = avg_gain / avg_loss
    rsi = 100 - (100 / (1 + rs))
    # A completely flat window gives 0/0 = NaN; treat it as neutral.
    flat = (avg_gain == 0) & (avg_loss == 0)
    return rsi.mask(flat, 50.0)
def calculate_macd(close, fast=12, slow=26, signal=9):
    """MACD (Moving Average Convergence Divergence) of ``close``.

    Returns a ``(macd, signal_line, histogram)`` tuple where ``macd`` is the
    fast EMA minus the slow EMA, ``signal_line`` is an EMA of ``macd`` with
    span ``signal``, and ``histogram`` is ``macd - signal_line``.
    """
    line = calculate_ema(close, fast) - calculate_ema(close, slow)
    signal_line = calculate_ema(line, signal)
    return line, signal_line, line - signal_line
def calculate_bollinger_bands(close, window=20, std_dev=2):
    """Bollinger Bands of ``close``.

    Returns ``(upper, middle, lower)``. ``middle`` is the rolling SMA, and the
    bands sit ``std_dev`` rolling standard deviations above and below it
    (pandas' default sample std, ddof=1, so the first row's bands are NaN).
    """
    middle = calculate_sma(close, window)
    band_width = close.rolling(window=window, min_periods=1).std() * std_dev
    return middle + band_width, middle, middle - band_width
def calculate_stochastic(high, low, close, k_window=14, d_window=3):
    """Stochastic oscillator.

    Returns ``(k_percent, d_percent)``. %K is where ``close`` sits within the
    rolling ``k_window`` low-high range, scaled to 0-100. %D is a
    ``d_window`` rolling mean of %K. Rows whose range is zero get NaN for %K.
    """
    floor = low.rolling(window=k_window, min_periods=1).min()
    ceiling = high.rolling(window=k_window, min_periods=1).max()
    # A zero-width range would divide by zero; NaN marks %K as undefined there.
    span = (ceiling - floor).replace(0, np.nan)
    k_line = (close - floor) / span * 100
    d_line = k_line.rolling(window=d_window, min_periods=1).mean()
    return k_line, d_line
def calculate_atr(high, low, close, window=14):
    """Average True Range over a trailing ``window``.

    The true range of a row is the largest of ``high - low``,
    ``|high - previous close|`` and ``|low - previous close|``; NaNs (the first
    row has no previous close) are skipped when taking that maximum. The ATR is
    a simple rolling mean of the true range.
    """
    prev_close = close.shift(1)
    candidates = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    )
    true_range = candidates.max(axis=1)
    return true_range.rolling(window=window, min_periods=1).mean()
def calculate_williams_r(high, low, close, window=14):
    """Williams %R over a trailing ``window``.

    Returns values in [-100, 0]: 0 when ``close`` equals the rolling high,
    -100 when it equals the rolling low. Rows whose high-low range is zero
    get NaN.
    """
    ceiling = high.rolling(window=window, min_periods=1).max()
    floor = low.rolling(window=window, min_periods=1).min()
    # Guard against division by zero on a zero-width range.
    span = (ceiling - floor).replace(0, np.nan)
    return (ceiling - close) / span * -100
def add_technical_indicators(df, indicators_config=None):
    """
    Return a copy of ``df`` with technical-indicator columns added.

    Args:
        df: DataFrame that must contain 'open', 'high', 'low' and 'close'
            columns. The input itself is not modified.
        indicators_config: Dict choosing which indicators to compute.
            'sma', 'ema', 'rsi' and 'atr' map to iterables of window lengths
            (one column per window, e.g. ``sma_5``). 'macd' and 'bollinger'
            are truthy flags that each add three columns using the default
            parameters. A missing or falsy entry skips that indicator.
            ``None`` selects the small default preset below.

    Returns:
        The new DataFrame, with remaining NaNs back-filled and then
        forward-filled instead of rows being dropped. If any indicator
        computation raises, the error is printed and only the base columns
        are returned: 'open', 'high', 'low', 'close', plus whichever of
        'volume', 'amount' and 'timestamps' are present.

    Raises:
        ValueError: if any of the required OHLC columns is missing.
    """
    if indicators_config is None:
        # Default preset, kept deliberately small to limit the column count.
        indicators_config = {
            'sma': [5, 10, 20],
            'ema': [12, 26],
            'rsi': [14],
            'macd': True,
            'bollinger': True,
            'atr': [14]
        }
    df = df.copy()
    # Make sure the required price columns exist.
    required_cols = ['open', 'high', 'low', 'close']
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain columns: {required_cols}")
    # Remember the caller's columns so that only newly added ones are counted.
    original_cols = set(df.columns)
    try:
        # Simple moving averages
        for period in indicators_config.get('sma') or []:
            df[f'sma_{period}'] = calculate_sma(df['close'], period)
        # Exponential moving averages
        for period in indicators_config.get('ema') or []:
            df[f'ema_{period}'] = calculate_ema(df['close'], period)
        # RSI
        for period in indicators_config.get('rsi') or []:
            df[f'rsi_{period}'] = calculate_rsi(df['close'], period)
        # MACD
        if indicators_config.get('macd'):
            macd, macd_signal, macd_hist = calculate_macd(df['close'])
            df['macd'] = macd
            df['macd_signal'] = macd_signal
            df['macd_hist'] = macd_hist
        # Bollinger Bands
        if indicators_config.get('bollinger'):
            bb_upper, bb_middle, bb_lower = calculate_bollinger_bands(df['close'])
            df['bb_upper'] = bb_upper
            df['bb_middle'] = bb_middle
            df['bb_lower'] = bb_lower
        # ATR
        for period in indicators_config.get('atr') or []:
            df[f'atr_{period}'] = calculate_atr(df['high'], df['low'], df['close'], period)
        # Fill NaNs instead of dropping rows. bfill()/ffill() replace
        # fillna(method=...), which was deprecated in pandas 2.1 and removed
        # in 3.0. The resulting error would be swallowed by the except below,
        # silently discarding every indicator.
        df = df.bfill().ffill()
        # Count only columns this call introduced, not extra caller columns.
        indicator_cols = [col for col in df.columns if col not in original_cols]
        print(f"✅ 技术指标计算完成,添加了 {len(indicator_cols)} 个指标")
    except Exception as e:
        print(f"❌ 技术指标计算失败: {e}")
        # Best-effort fallback: return only the base price/volume columns.
        basic_cols = ['open', 'high', 'low', 'close']
        if 'volume' in df.columns:
            basic_cols.append('volume')
        if 'amount' in df.columns:
            basic_cols.append('amount')
        if 'timestamps' in df.columns:
            basic_cols.append('timestamps')
        return df[basic_cols]
    return df
def get_available_indicators():
    """Return the supported indicator names grouped by category.

    A fresh dict is built on every call, so callers may mutate the result.
    """
    return dict(
        trend=['sma', 'ema', 'macd', 'bollinger'],
        momentum=['rsi'],
        volatility=['atr'],
        volume=[],
    )