File size: 11,926 Bytes
2b67076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import torch
def epxm1_x(x):
    """
    Return (exp(x) - 1) / x elementwise, stable near x = 0.

    Away from zero the quotient is formed from ``torch.special.expm1``. Any
    non-finite entry of that quotient (0/0 at x == 0, or overflow of expm1 for
    large positive x) is set to 0. Wherever |x| < 1e-2 the value is replaced by
    the truncated series 1 + x/2 + x**2/6.

    Parameters:
    x: input tensor

    Returns:
    Tensor with the same shape as ``x``.
    """
    quotient = torch.special.expm1(x) / x
    # NaN and +/-inf both collapse to 0 before the near-zero patch below.
    quotient = torch.nan_to_num(quotient, nan=0.0, posinf=0.0, neginf=0.0)
    series = 1 + x/2. + x**2 / 6.
    return torch.where(x.abs() < 1e-2, series, quotient)
def epxm1mx_x2(x):
    """
    Return (exp(x) - 1 - x) / x**2 elementwise, stable near x = 0.

    Non-finite entries of the direct quotient (0/0 at x == 0, overflow for
    large positive x) are set to 0. Wherever x**2 < 1e-2 (|x| < 0.1) the value
    is replaced by the series 1/2 + x/6 + x**2/24 + x**3/120.

    Parameters:
    x: input tensor

    Returns:
    Tensor with the same shape as ``x``.
    """
    x_sq = x**2
    quotient = (torch.special.expm1(x) - x) / x_sq
    quotient = torch.nan_to_num(quotient, nan=0.0, posinf=0.0, neginf=0.0)
    series = 1/2. + x/6 + x_sq / 24 + x**3 / 120
    return torch.where(x_sq.abs() < 1e-2, series, quotient)

def expm1mxmhx2_x3(x):
    """
    Return (exp(x) - 1 - x - x**2/2) / x**3 elementwise, stable near x = 0.

    Non-finite entries of the direct quotient are set to 0. Wherever
    |x**3| < 1e-2 the value is replaced by the series
    1/6 + x/24 + x**2/120 + x**3/720 + x**4/5040.

    Parameters:
    x: input tensor

    Returns:
    Tensor with the same shape as ``x``.
    """
    x_cu = x**3
    quotient = (torch.special.expm1(x) - x - x**2 / 2) / x_cu
    quotient = torch.nan_to_num(quotient, nan=0.0, posinf=0.0, neginf=0.0)
    series = 1/6 + x/24 + x**2 / 120 + x_cu / 720 + x**4 / 5040
    return torch.where(x_cu.abs() < 1e-2, series, quotient)

def exp_1mcosh_GD(gamma_t, delta):
    """
    Compute e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ), stable for small Γt√|Δ|.

    For Δ < 0, cosh(Γt√Δ) is evaluated as cos(Γt√|Δ|), so the result is real
    for either sign of Δ. Non-finite direct results (e.g. 0/0 at Δ = 0) are set
    to 0. Wherever (Γt)**2 |Δ| < 5e-2 the value is replaced by the series
    (-1/2 - (Γt)**2 Δ/24 - (Γt)**4 Δ**2/720) * e^(-Γt).

    Parameters:
    gamma_t: Γ*t term (tensor; torch.exp is applied to it directly)
    delta: Δ term (tensor; torch.abs is applied to it directly)

    Returns:
    Tensor broadcast from ``gamma_t`` and ``delta``.
    """
    root = torch.sqrt(torch.abs(delta))   # √|Δ|
    arg = gamma_t * root                  # Γt√|Δ|
    decay = torch.exp(-gamma_t)           # e^(-Γt)

    # Δ > 0: e^(-Γt) is absorbed into each exponential of cosh, so the terms
    # are exp(Γt(±√Δ - 1)) instead of a separately formed cosh(Γt√Δ).
    hyperbolic = decay - (torch.exp(gamma_t * (root - 1)) + torch.exp(gamma_t * (-root - 1))) / 2
    # Δ <= 0: cosh of an imaginary argument is a cosine.
    trigonometric = decay * (1 - torch.cos(arg))
    numerator = torch.where(delta > 0, hyperbolic, trigonometric)

    quotient = numerator / (delta * gamma_t**2)
    quotient = torch.where(torch.isfinite(quotient), quotient, torch.zeros_like(quotient))

    series = (-0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2) * decay
    return torch.where(torch.abs(arg**2) < 5e-2, series, quotient)

def exp_sinh_GsqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ), stable for small Γt√|Δ|.

    For Δ <= 0 the expression is evaluated as e^(-Γt) * sin(Γt√|Δ|) / (Γt√|Δ|)
    through ``torch.special.sinc``. At Δ = 0 this gives exactly e^(-Γt).

    Parameters:
    gamma_t: Γ*t term (tensor; torch.exp is applied to it directly)
    delta: Δ term (tensor; torch.abs is applied to it directly)

    Returns:
    Tensor broadcast from ``gamma_t`` and ``delta``.
    """
    decay = torch.exp(-gamma_t)  # e^(-Γt), shared by all branches
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta

    # Δ > 0: e^(-Γt) sinh(Γt√Δ), with e^(-Γt) absorbed into each exponential
    # so that no separately formed sinh(Γt√Δ) can overflow on its own.
    numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
    result_pos = numerator_pos / gamma_t_sqrt_delta
    result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))

    # Small Γt√Δ: sinh(z)/z ≈ 1 + z**2/6 + z**4/120 with z**2 = (Γt)**2 Δ.
    mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
    taylor = (1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2) * decay
    result_pos = torch.where(mask, taylor, result_pos)

    # Δ <= 0: torch.special.sinc(x) = sin(πx)/(πx), so sinc(z/π) = sin(z)/z,
    # which is 1 at z = 0.
    result_neg = decay * torch.special.sinc(gamma_t_sqrt_delta / torch.pi)
    return torch.where(is_positive, result_pos, result_neg)

def exp_cosh(gamma_t, delta):
    """
    Compute e^(-Γt) * cosh(Γt√Δ).

    Uses the identity
        e^(-Γt) cosh(z) = e^(-Γt) - z**2 * [e^(-Γt) (1 - cosh z) / z**2],
    with z**2 = (Γt)**2 Δ. The bracket is ``exp_1mcosh_GD``, so this inherits
    that helper's small-argument series and its cos() form for Δ < 0.

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)

    Returns:
    Tensor broadcast from ``gamma_t`` and ``delta``.
    """
    bracket = exp_1mcosh_GD(gamma_t, delta)
    return torch.exp(-gamma_t) - gamma_t**2 * delta * bracket
def exp_sinh_sqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / √Δ.

    Written as Γt * [e^(-Γt) sinh(Γt√Δ) / (Γt√Δ)], where the bracket is
    ``exp_sinh_GsqrtD``. That helper stays finite at Δ = 0, where the bracket
    equals e^(-Γt).

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)

    Returns:
    Tensor broadcast from ``gamma_t`` and ``delta``.
    """
    return gamma_t * exp_sinh_GsqrtD(gamma_t, delta)



def zeta1(gamma_t, delta):
    """
    Compute ζ₁ = 1 - (1 - E) / a, where
        E = e^(-Γt/2) * (cosh(Γt√Δ/2) + sinh(Γt√Δ/2) / √Δ)
        a = Γt (1 - Δ) / 4.

    ``StochasticHarmonicOscillator.dynamics`` uses (1 - ζ₁) as the factor
    multiplying (C - A y0) t in the mean position.

    Non-finite direct results (e.g. 0/0 when Δ = 1) are set to 0. Wherever
    |a| < 5e-3 the result is replaced by a second-order series in ``a`` whose
    coefficients come from ``epxm1_x``, ``epxm1mx_x2`` and ``expm1mxmhx2_x3``
    evaluated at -Γt.

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)

    Returns:
    Tensor broadcast from ``gamma_t`` and ``delta``.
    """
    # Compute hyperbolic terms and exponential at half of Γt:
    # e^(-Γt/2) cosh(Γt√Δ/2) and e^(-Γt/2) sinh(Γt√Δ/2)/√Δ
    half_gamma_t = gamma_t / 2
    exp_cosh_term = exp_cosh(half_gamma_t, delta)
    exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)

    
    # Main computation: numerator = 1 - E, denominator = a = Γt(1 - Δ)/4
    numerator = 1 - (exp_cosh_term + exp_sinh_term)
    denominator = gamma_t * (1 - delta) / 4
    result = 1 - numerator / denominator
    
    # Handle numerical instability: 0/0 (Δ = 1) or overflow -> 0, patched below
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    
    # Series in a = denominator for small |a| (Δ near 1). The zeroth-order term
    # term1 = epxm1_x(-Γt) = (1 - e^(-Γt)) / Γt is the value at a = 0.
    mask = torch.abs(denominator) < 5e-3
    term1 = epxm1_x(-gamma_t)
    term2 = epxm1mx_x2(-gamma_t)
    term3 = expm1mxmhx2_x3(-gamma_t)
    taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
    result = torch.where(mask, taylor, result)
    
    return result

def exp_cosh_minus_terms(gamma_t, delta):
    """
    Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ))
    
    Non-finite direct results (e.g. 0/0 at Δ = 1) are set to 0. Wherever
    |Γt(1 - Δ)| < 0.1 the result is replaced by a first-order expansion in
    Γt(1 - Δ) whose coefficients are evaluated at Δ = 1.

    Parameters:
    gamma_t: Γ*t term (tensor; torch.exp is applied to it directly)
    delta: Δ term (tensor; delta**0 must produce a tensor of ones)
    
    Returns:
    Result of the computation with numerical stability handling
    """
    exp_term = torch.exp(-gamma_t)
    # Compute individual terms (gamma_t**0 is a tensor of ones, i.e. Δ = 1)
    exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term
    exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)  # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term
    
    #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
    # Main computation
    numerator = exp_cosh_term - exp_cosh_delta_term
    denominator = gamma_t * (1 - delta)
    
    result = numerator / denominator
    
    # Handle numerical instability: 0/0 (Δ = 1) or overflow -> 0, patched below
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    
    # First-order expansion in denominator = Γt(1 - Δ) about Δ = 1
    # (delta**0 is a tensor of ones).
    # NOTE(review): the switch threshold 1e-1 is much looser than the other
    # helpers' thresholds and the truncation error is second order in
    # denominator; confirm the accuracy is adequate for callers.
    mask = (torch.abs(denominator) < 1e-1)
    exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
    taylor = (
       gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0) 
       - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
    )
    result = torch.where(mask, taylor, result)
    
    return result


def zeta2(gamma_t, delta):
    """
    Compute ζ₂ = e^(-Γt/2) * sinh(Γt√Δ/2) / (Γt√Δ/2).

    This is ``exp_sinh_GsqrtD`` evaluated at half of Γt.

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)
    """
    return exp_sinh_GsqrtD(gamma_t / 2, delta)

def sig11(gamma_t, delta):
    """
    Compute 1 - e^(-Γt) + e^(-Γt)(1 - cosh(Γt√Δ))/Δ + e^(-Γt) sinh(Γt√Δ)/√Δ.

    The third term is (Γt)**2 * ``exp_1mcosh_GD`` and the fourth is
    ``exp_sinh_sqrtD``. The formula is the same as
    ``StochasticHarmonicOscillator.sig11``.

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)
    """
    decay_part = 1 - torch.exp(-gamma_t)
    cosh_part = gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)
    sinh_part = exp_sinh_sqrtD(gamma_t, delta)
    return decay_part + cosh_part + sinh_part


def Zcoefs(gamma_t, delta):
    """
    Compute three coefficients and an overall amplitude from ζ₁ and ζ₂.

    With Zeta1 = zeta1(Γt, Δ) and Zeta2 = zeta2(Γt, Δ):
        amplitude = sqrt(1 - Zeta1 + Γt (Δ - 1) (Zeta1 - 1)**2 / 8)
        c1 = sqrt(Γt / 2) * Zeta2
        c2 = c1 * Γt * sqrt(-2 * exp_1mcosh_GD(Γt, Δ) / sig11(Γt, Δ))
        c3 = amplitude * sqrt(max(1 - (c1/amplitude)**2 - (c2/amplitude)**2, 0))
    When the max(..., 0) clamp is inactive, c1**2 + c2**2 + c3**2 = amplitude**2.

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)

    Returns:
    (c1, c2, c3, amplitude)

    NOTE(review): torch.sqrt yields NaN if sq_total < 0; nothing here guards
    against that. Confirm the parameter range guarantees sq_total >= 0.
    """
    Zeta1 = zeta1(gamma_t, delta)
    Zeta2 = zeta2(gamma_t, delta)
    
    sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
    amplitude = torch.sqrt(sq_total)
    # Coefficients are normalized by amplitude here and rescaled on return.
    Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude
    Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta)  / sig11(gamma_t, delta)  ) ** 0.5 
    # Earlier closed-form alternative for Zcoef3, kept for reference:
    #cterm = exp_cosh_minus_terms(gamma_t, delta)
    #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
    #Zcoef3 = 2 * torch.sqrt(  cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
    # Remaining share of the unit norm, clamped at 0 to avoid sqrt of a negative.
    Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )

    return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude

def Zcoefs_asymp(gamma_t, delta):
    """
    Compute (1 - e^(-2a)) / (2a) with a = Γt (1 - Δ) / 4.

    Evaluated as ``epxm1_x(-2a)``, which is finite at a = 0 (value 1).

    Parameters:
    gamma_t: Γ*t term (tensor)
    delta: Δ term (tensor)
    """
    scaled_a = (gamma_t * (1 - delta)) / 4
    return epxm1_x(-2 * scaled_a)

class StochasticHarmonicOscillator:
    """
    Simulates a stochastic harmonic oscillator governed by the equations:
        dy(t) = q(t) dt
        dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt

    Also define v(t) = q(t) / √Γ, which is numerically more stable.
        
    Where:
        y(t) - Position variable
        q(t) - Velocity variable
        Γ - Damping coefficient
        A - Harmonic potential strength
        C - Constant force term
        D - Noise amplitude
        dw(t) - Wiener process (Brownian motion)

    The derived quantity Δ = 1 - 4A/Γ (stored as ``self.Delta``) is the
    discriminant, divided by Γ**2, of the characteristic equation
    λ**2 + Γλ + ΓA = 0 of the noise-free dynamics.
    """
    def __init__(self, Gamma, A, C, D):
        self.Gamma = Gamma
        self.A = A
        self.C = C
        self.D = D
        # Δ > 0: real characteristic roots; Δ < 0: complex roots (the helper
        # functions switch from hyperbolic to trigonometric forms at Δ = 0).
        self.Delta = 1 - 4 * A / Gamma

    def sig11(self, gamma_t, delta):
        """Velocity-variance factor: ``dynamics`` uses Var[v] = D**2 * sig11 / 2."""
        return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)

    def sig22(self, gamma_t, delta):
        """Position-variance factor: ``dynamics`` uses Var[y] = D**2 * t * sig22."""
        return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta) 

    def dynamics(self, y0, v0, t):
        """
        Sample (y(t), v(t)) from the Gaussian transition given (y0, v0).

        Parameters:
            y0 (torch.Tensor): Initial position. Its shape is the batch shape,
                and its device/dtype are used for every intermediate tensor.
            v0 (torch.Tensor or None): Initial velocity v(0) = q(0) / √Γ. If
                None, it is drawn elementwise as N(0, D**2 / 2).
            t (float or torch.Tensor): Elapsed time; must broadcast to y0's shape.
        Returns:
            tuple: (y(t), v(t)), one random draw from the bivariate normal whose
            mean and covariance are computed below.
        """
        # Zero tensor carrying y0's device/dtype; adding it lifts scalar
        # parameters to tensors so the torch helpers can be applied to them.
        dummyzero = y0.new_zeros(1)
        Delta = self.Delta + dummyzero
        Gamma_hat = self.Gamma * t + dummyzero
        A = self.A + dummyzero
        C = self.C + dummyzero
        D = self.D + dummyzero
        Gamma = self.Gamma + dummyzero
        zeta_1 = zeta1(Gamma_hat, Delta)
        zeta_2 = zeta2(Gamma_hat, Delta)
        EE = 1 - Gamma_hat * zeta_2

        if v0 is None:
            v0 = torch.randn_like(y0) * D / 2 ** 0.5

        # Conditional means given (y0, v0)
        term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
        y_mean = term1 + y0
        v_mean = (1 - EE) * (C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0

        # Transition covariance entries (zeta_2 reused instead of recomputed)
        cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
        cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
        cov_yv = (zeta_2 * Gamma_hat * D) ** 2 / 2 / (Gamma ** 0.5)

        # Closed-form Cholesky factor of [[cov_yy, cov_yv], [cov_yv, cov_vv]].
        # cov_yy and the Schur complement cov_vv - cov_yv**2 / cov_yy are
        # clamped at tol so the factor keeps a positive diagonal even if
        # round-off drives them to <= 0.
        batch_shape = y0.shape
        tol = 1e-8
        cov_yy = torch.clamp(cov_yy, min=tol)
        sd_yy = torch.sqrt(cov_yy)
        inv_sd_yy = 1/(sd_yy)

        # Upper-right entry stays 0 from torch.zeros.
        scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        scale_tril[..., 0, 0] = sd_yy
        scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
        scale_tril[..., 1, 1] = torch.clamp(cov_vv - cov_yv**2 / cov_yy, min=tol) ** 0.5

        # Sample correlated (y, v) from the bivariate normal
        mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
        mean[..., 0] = y_mean
        mean[..., 1] = v_mean
        new_yv = torch.distributions.MultivariateNormal(
            loc=mean,
            scale_tril=scale_tril
        ).sample()

        return new_yv[..., 0], new_yv[..., 1]