1
1
import numpy as np
2
+ import sympy as sp
3
+ from devito import Dimension , Constant , TimeFunction , Eq , solve , Operator
2
4
#import matplotlib.pyplot as plt
3
5
import scitools .std as plt
4
6
@@ -11,27 +13,33 @@ def solver(I, V, m, b, s, F, dt, T, damping='linear'):
11
13
'quadratic', f(u')=b*u'*abs(u').
12
14
F(t) and s(u) are Python functions.
13
15
"""
14
- dt = float (dt ); b = float (b ); m = float (m ) # avoid integer div.
16
+ dt = float (dt )
17
+ b = float (b )
18
+ m = float (m )
15
19
Nt = int (round (T / dt ))
16
- u = np .zeros (Nt + 1 )
17
- t = np .linspace (0 , Nt * dt , Nt + 1 )
20
+ t = Dimension ('t' , spacing = Constant ('h_t' ))
21
+
22
+ u = TimeFunction (name = 'u' , dimensions = (t ,),
23
+ shape = (Nt + 1 ,), space_order = 2 )
24
+
25
+ u .data [0 ] = I
18
26
19
- u [0 ] = I
20
27
if damping == 'linear' :
21
- u [1 ] = u [0 ] + dt * V + dt ** 2 / (2 * m )* (- b * V - s (u [0 ]) + F (t [0 ]))
28
+ # dtc for central difference (default for time is forward, 1st order)
29
+ eqn = m * u .dt2 + b * u .dtc + s (u ) - F (u )
30
+ stencil = Eq (u .forward , solve (eqn , u .forward ))
22
31
elif damping == 'quadratic' :
23
- u [1 ] = u [0 ] + dt * V + \
24
- dt ** 2 / (2 * m )* (- b * V * abs (V ) - s (u [0 ]) + F (t [0 ]))
32
+ # fd_order set as backward derivative used is 1st order
33
+ eqn = m * u .dt2 + b * u .dt * sp .Abs (u .dtl (fd_order = 1 )) + s (u ) - F (u )
34
+ stencil = Eq (u .forward , solve (eqn , u .forward ))
35
+ # First timestep needs to have the backward timestep substituted
36
+ stencil_init = stencil .subs (u .backward , u .forward - 2 * t .spacing * V )
37
+ op_init = Operator (stencil_init , name = 'first_timestep' )
38
+ op = Operator (stencil , name = 'main_loop' )
39
+ op_init .apply (h_t = dt , t_M = 1 )
40
+ op .apply (h_t = dt , t_m = 1 , t_M = Nt - 1 )
25
41
26
- for n in range (1 , Nt ):
27
- if damping == 'linear' :
28
- u [n + 1 ] = (2 * m * u [n ] + (b * dt / 2 - m )* u [n - 1 ] +
29
- dt ** 2 * (F (t [n ]) - s (u [n ])))/ (m + b * dt / 2 )
30
- elif damping == 'quadratic' :
31
- u [n + 1 ] = (2 * m * u [n ] - m * u [n - 1 ] + b * u [n ]* abs (u [n ] - u [n - 1 ])
32
- + dt ** 2 * (F (t [n ]) - s (u [n ])))/ \
33
- (m + b * abs (u [n ] - u [n - 1 ]))
34
- return u , t
42
+ return u .data , np .linspace (0 , Nt * dt , Nt + 1 )
35
43
36
44
def visualize (u , t , title = '' , filename = 'tmp' ):
37
45
plt .plot (t , u , 'b-' )
@@ -46,8 +54,6 @@ def visualize(u, t, title='', filename='tmp'):
46
54
plt .savefig (filename + '.pdf' )
47
55
plt .show ()
48
56
49
- import sympy as sym
50
-
51
57
def test_constant ():
52
58
"""Verify a constant solution."""
53
59
u_exact = lambda t : I
@@ -68,24 +74,24 @@ def test_constant():
68
74
69
75
def lhs_eq (t , m , b , s , u , damping = 'linear' ):
70
76
"""Return lhs of differential equation as sympy expression."""
71
- v = sym .diff (u , t )
77
+ v = sp .diff (u , t )
72
78
if damping == 'linear' :
73
- return m * sym .diff (u , t , t ) + b * v + s (u )
79
+ return m * sp .diff (u , t , t ) + b * v + s (u )
74
80
else :
75
- return m * sym .diff (u , t , t ) + b * v * sym .Abs (v ) + s (u )
81
+ return m * sp .diff (u , t , t ) + b * v * sp .Abs (v ) + s (u )
76
82
77
83
def test_quadratic ():
78
84
"""Verify a quadratic solution."""
79
85
I = 1.2 ; V = 3 ; m = 2 ; b = 0.9
80
86
s = lambda u : 4 * u
81
- t = sym .Symbol ('t' )
87
+ t = sp .Symbol ('t' )
82
88
dt = 0.2
83
89
T = 2
84
90
85
91
q = 2 # arbitrary constant
86
92
u_exact = I + V * t + q * t ** 2
87
- F = sym .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'linear' ))
88
- u_exact = sym .lambdify (t , u_exact , modules = 'numpy' )
93
+ F = sp .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'linear' ))
94
+ u_exact = sp .lambdify (t , u_exact , modules = 'numpy' )
89
95
u1 , t1 = solver (I , V , m , b , s , F , dt , T , 'linear' )
90
96
diff = np .abs (u_exact (t1 ) - u1 ).max ()
91
97
tol = 1E-13
@@ -94,8 +100,8 @@ def test_quadratic():
94
100
# In the quadratic damping case, u_exact must be linear
95
101
# in order exactly recover this solution
96
102
u_exact = I + V * t
97
- F = sym .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'quadratic' ))
98
- u_exact = sym .lambdify (t , u_exact , modules = 'numpy' )
103
+ F = sp .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'quadratic' ))
104
+ u_exact = sp .lambdify (t , u_exact , modules = 'numpy' )
99
105
u2 , t2 = solver (I , V , m , b , s , F , dt , T , 'quadratic' )
100
106
diff = np .abs (u_exact (t2 ) - u2 ).max ()
101
107
assert diff < tol
@@ -127,11 +133,11 @@ def test_mms():
127
133
"""Use method of manufactured solutions."""
128
134
m = 4. ; b = 1
129
135
w = 1.5
130
- t = sym .Symbol ('t' )
131
- u_exact = 3 * sym .exp (- 0.2 * t )* sym .cos (1.2 * t )
136
+ t = sp .Symbol ('t' )
137
+ u_exact = 3 * sp .exp (- 0.2 * t )* sp .cos (1.2 * t )
132
138
I = u_exact .subs (t , 0 ).evalf ()
133
- V = sym .diff (u_exact , t ).subs (t , 0 ).evalf ()
134
- u_exact_py = sym .lambdify (t , u_exact , modules = 'numpy' )
139
+ V = sp .diff (u_exact , t ).subs (t , 0 ).evalf ()
140
+ u_exact_py = sp .lambdify (t , u_exact , modules = 'numpy' )
135
141
s = lambda u : u ** 3
136
142
dt = 0.2
137
143
T = 6
@@ -140,14 +146,14 @@ def test_mms():
140
146
# Run grid refinements and compute exact error
141
147
for i in range (5 ):
142
148
F_formula = lhs_eq (t , m , b , s , u_exact , 'linear' )
143
- F = sym .lambdify (t , F_formula )
149
+ F = sp .lambdify (t , F_formula )
144
150
u1 , t1 = solver (I , V , m , b , s , F , dt , T , 'linear' )
145
151
error = np .sqrt (np .sum ((u_exact_py (t1 ) - u1 )** 2 )* dt )
146
152
errors_linear .append ((dt , error ))
147
153
148
154
F_formula = lhs_eq (t , m , b , s , u_exact , 'quadratic' )
149
155
#print sym.latex(F_formula, mode='plain')
150
- F = sym .lambdify (t , F_formula )
156
+ F = sp .lambdify (t , F_formula )
151
157
u2 , t2 = solver (I , V , m , b , s , F , dt , T , 'quadratic' )
152
158
error = np .sqrt (np .sum ((u_exact_py (t2 ) - u2 )** 2 )* dt )
153
159
errors_quadratic .append ((dt , error ))
0 commit comments