File size: 6,069 Bytes
85653bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python3
"""
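Technical indicator calculation module - implemented with pandas/numpy only.
Computes common technical-analysis indicators without a dedicated technical-analysis library.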
"""

import pandas as pd
import numpy as np


def calculate_sma(data, window):
    """Return the simple moving average of *data* over *window* periods.

    Uses ``min_periods=1``, so the first ``window - 1`` entries average
    whatever observations are available instead of being NaN.
    """
    windowed = data.rolling(window=window, min_periods=1)
    return windowed.mean()


def calculate_ema(data, window):
    """Return the exponential moving average of *data* with span *window*.

    Uses the recursive form (``adjust=False``): the first value is the first
    observation, then ``ema_t = alpha * x_t + (1 - alpha) * ema_{t-1}`` with
    ``alpha = 2 / (window + 1)``.
    """
    smoother = data.ewm(span=window, adjust=False)
    return smoother.mean()


def calculate_rsi(close, window=14):
    """Relative Strength Index of *close* over *window* periods.

    Average gain and average loss are plain rolling means
    (``min_periods=1``), not Wilder's exponential smoothing. The first value
    is NaN: there is no price change yet, so the ratio is 0/0. The same holds
    for any window with no movement at all. A window with gains but no losses
    yields 100.
    """
    change = close.diff()
    ups = change.where(change > 0, 0)
    downs = -change.where(change < 0, 0)

    def _rolling_mean(series):
        return series.rolling(window=window, min_periods=1).mean()

    relative_strength = _rolling_mean(ups) / _rolling_mean(downs)
    return 100 - (100 / (1 + relative_strength))


def calculate_macd(close, fast=12, slow=26, signal=9):
    """Moving Average Convergence Divergence of *close*.

    Returns:
        Tuple ``(macd, macd_signal, macd_hist)`` where ``macd`` is
        EMA(fast) - EMA(slow) of *close*, ``macd_signal`` is the EMA of
        ``macd`` with span *signal*, and ``macd_hist`` is their difference.
    """
    macd_line = calculate_ema(close, fast) - calculate_ema(close, slow)
    trigger_line = calculate_ema(macd_line, signal)
    return macd_line, trigger_line, macd_line - trigger_line


def calculate_bollinger_bands(close, window=20, std_dev=2):
    """Bollinger Bands of *close*.

    The middle band is the rolling mean over *window*. The outer bands sit
    *std_dev* rolling standard deviations (pandas default, ``ddof=1``)
    above and below it. With ``min_periods=1`` the first row's std is NaN
    (a single observation), so both outer bands are NaN there.

    Returns:
        Tuple ``(upper, middle, lower)``.
    """
    middle = calculate_sma(close, window)
    band_width = close.rolling(window=window, min_periods=1).std() * std_dev
    return middle + band_width, middle, middle - band_width


def calculate_stochastic(high, low, close, k_window=14, d_window=3):
    """Stochastic Oscillator.

    %K is where *close* sits within the *k_window* high/low range, scaled
    to 0-100. %D is the *d_window* rolling mean of %K.

    Returns:
        Tuple ``(k_percent, d_percent)``. Rows whose range is zero get NaN.
    """
    floor = low.rolling(window=k_window, min_periods=1).min()
    ceiling = high.rolling(window=k_window, min_periods=1).max()

    # A flat window has zero range; turn it into NaN instead of dividing by zero.
    spread = (ceiling - floor).replace(0, np.nan)

    k_line = 100 * ((close - floor) / spread)
    d_line = k_line.rolling(window=d_window, min_periods=1).mean()
    return k_line, d_line


def calculate_atr(high, low, close, window=14):
    """Average True Range.

    The true range of each row is the largest of ``high - low``,
    ``|high - previous close|`` and ``|low - previous close|``. On the first
    row the previous close is missing, so only ``high - low`` counts. The
    ATR is the simple rolling mean of the true range (``min_periods=1``).
    """
    prev_close = close.shift(1)
    candidates = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    )
    true_range = candidates.max(axis=1)
    return true_range.rolling(window=window, min_periods=1).mean()


def calculate_williams_r(high, low, close, window=14):
    """Williams %R.

    Measures how far *close* is below the highest high of the *window*
    range, scaled to -100..0. Rows whose high/low range is zero get NaN.
    """
    ceiling = high.rolling(window=window, min_periods=1).max()
    floor = low.rolling(window=window, min_periods=1).min()

    # Zero range -> NaN so the division below never divides by zero.
    spread = (ceiling - floor).replace(0, np.nan)

    return -100 * ((ceiling - close) / spread)


def add_technical_indicators(df, indicators_config=None):
    """
    Add technical-indicator columns to a DataFrame of OHLC(V) data.

    Args:
        df: DataFrame containing at least the 'open', 'high', 'low' and
            'close' columns. The input is not modified; a copy is returned.
        indicators_config: dict selecting the indicators to compute.
            'sma', 'ema', 'rsi' and 'atr' map to iterables of window lengths
            (one column per window, e.g. 'sma_20'). 'macd' and 'bollinger'
            are truthy flags that enable MACD and Bollinger Bands with their
            default parameters. When None, a small default preset is used.

    Returns:
        A DataFrame with the indicator columns appended and NaNs filled.
        If anything in the computation raises, the error is printed and only
        the basic price/volume columns present in the data are returned.

    Raises:
        ValueError: if any of the required OHLC columns is missing.
    """
    if indicators_config is None:
        # Default preset - kept small to limit the number of indicators
        indicators_config = {
            'sma': [5, 10, 20],
            'ema': [12, 26],
            'rsi': [14],
            'macd': True,
            'bollinger': True,
            'atr': [14]
        }

    df = df.copy()
    # Remember the caller's columns so only the ones added here are counted.
    original_cols = set(df.columns)

    # Make sure the required columns exist
    required_cols = ['open', 'high', 'low', 'close']
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain columns: {required_cols}")

    try:
        # Simple moving averages
        if 'sma' in indicators_config:
            for period in indicators_config['sma']:
                df[f'sma_{period}'] = calculate_sma(df['close'], period)

        # Exponential moving averages
        if 'ema' in indicators_config:
            for period in indicators_config['ema']:
                df[f'ema_{period}'] = calculate_ema(df['close'], period)

        # RSI
        if 'rsi' in indicators_config:
            for period in indicators_config['rsi']:
                df[f'rsi_{period}'] = calculate_rsi(df['close'], period)

        # MACD
        if indicators_config.get('macd'):
            macd, macd_signal, macd_hist = calculate_macd(df['close'])
            df['macd'] = macd
            df['macd_signal'] = macd_signal
            df['macd_hist'] = macd_hist

        # Bollinger Bands
        if indicators_config.get('bollinger'):
            bb_upper, bb_middle, bb_lower = calculate_bollinger_bands(df['close'])
            df['bb_upper'] = bb_upper
            df['bb_middle'] = bb_middle
            df['bb_lower'] = bb_lower

        # ATR
        if 'atr' in indicators_config:
            for period in indicators_config['atr']:
                df[f'atr_{period}'] = calculate_atr(df['high'], df['low'], df['close'], period)

        # Fill NaNs instead of dropping rows. Forward-fill first, so interior
        # gaps reuse the last known (past) value. The backward fill then only
        # reaches the leading warm-up rows that have no earlier value.
        # Back-filling first would copy future values into interior gaps
        # (look-ahead leakage). ffill()/bfill() also replace the deprecated
        # fillna(method=...) form.
        df = df.ffill().bfill()

        # Count only the columns this call added. Pre-existing extra input
        # columns (e.g. a date column) are not indicators.
        indicator_cols = [col for col in df.columns if col not in original_cols]

        print(f"✅ 技术指标计算完成,添加了 {len(indicator_cols)} 个指标")

    except Exception as e:
        print(f"❌ 技术指标计算失败: {e}")
        # On failure, fall back to the basic price/volume columns
        basic_cols = ['open', 'high', 'low', 'close']
        for optional_col in ('volume', 'amount', 'timestamps'):
            if optional_col in df.columns:
                basic_cols.append(optional_col)
        return df[basic_cols]

    return df


def get_available_indicators():
    """Return the indicator keys understood by the config, grouped by category.

    A new dict is built on every call, so callers may mutate the result
    freely.
    """
    return dict(
        trend=['sma', 'ema', 'macd', 'bollinger'],
        momentum=['rsi'],
        volatility=['atr'],
        volume=[],
    )