Skip to content

Commit 4c7e185

Browse files
committed
Modified vib python file for consistency with notebook
1 parent 8f93b6c commit 4c7e185

File tree

1 file changed

+38
-32
lines changed

1 file changed

+38
-32
lines changed

src/vib/vib.py

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import numpy as np
2+
import sympy as sp
3+
from devito import Dimension, Constant, TimeFunction, Eq, solve, Operator
24
#import matplotlib.pyplot as plt
35
import scitools.std as plt
46

@@ -11,27 +13,33 @@ def solver(I, V, m, b, s, F, dt, T, damping='linear'):
1113
'quadratic', f(u')=b*u'*abs(u').
1214
F(t) and s(u) are Python functions.
1315
"""
14-
dt = float(dt); b = float(b); m = float(m) # avoid integer div.
16+
dt = float(dt)
17+
b = float(b)
18+
m = float(m)
1519
Nt = int(round(T/dt))
16-
u = np.zeros(Nt+1)
17-
t = np.linspace(0, Nt*dt, Nt+1)
20+
t = Dimension('t', spacing=Constant('h_t'))
21+
22+
u = TimeFunction(name='u', dimensions=(t,),
23+
shape=(Nt+1,), space_order=2)
24+
25+
u.data[0] = I
1826

19-
u[0] = I
2027
if damping == 'linear':
21-
u[1] = u[0] + dt*V + dt**2/(2*m)*(-b*V - s(u[0]) + F(t[0]))
28+
# dtc for central difference (default for time is forward, 1st order)
29+
eqn = m*u.dt2 + b*u.dtc + s(u) - F(u)
30+
stencil = Eq(u.forward, solve(eqn, u.forward))
2231
elif damping == 'quadratic':
23-
u[1] = u[0] + dt*V + \
24-
dt**2/(2*m)*(-b*V*abs(V) - s(u[0]) + F(t[0]))
32+
# fd_order set as backward derivative used is 1st order
33+
eqn = m*u.dt2 + b*u.dt*sp.Abs(u.dtl(fd_order=1)) + s(u) - F(u)
34+
stencil = Eq(u.forward, solve(eqn, u.forward))
35+
# First timestep needs to have the backward timestep substituted
36+
stencil_init = stencil.subs(u.backward, u.forward-2*t.spacing*V)
37+
op_init = Operator(stencil_init, name='first_timestep')
38+
op = Operator(stencil, name='main_loop')
39+
op_init.apply(h_t=dt, t_M=1)
40+
op.apply(h_t=dt, t_m=1, t_M=Nt-1)
2541

26-
for n in range(1, Nt):
27-
if damping == 'linear':
28-
u[n+1] = (2*m*u[n] + (b*dt/2 - m)*u[n-1] +
29-
dt**2*(F(t[n]) - s(u[n])))/(m + b*dt/2)
30-
elif damping == 'quadratic':
31-
u[n+1] = (2*m*u[n] - m*u[n-1] + b*u[n]*abs(u[n] - u[n-1])
32-
+ dt**2*(F(t[n]) - s(u[n])))/\
33-
(m + b*abs(u[n] - u[n-1]))
34-
return u, t
42+
return u.data, np.linspace(0, Nt*dt, Nt+1)
3543

3644
def visualize(u, t, title='', filename='tmp'):
3745
plt.plot(t, u, 'b-')
@@ -46,8 +54,6 @@ def visualize(u, t, title='', filename='tmp'):
4654
plt.savefig(filename + '.pdf')
4755
plt.show()
4856

49-
import sympy as sym
50-
5157
def test_constant():
5258
"""Verify a constant solution."""
5359
u_exact = lambda t: I
@@ -68,24 +74,24 @@ def test_constant():
6874

6975
def lhs_eq(t, m, b, s, u, damping='linear'):
7076
"""Return lhs of differential equation as sympy expression."""
71-
v = sym.diff(u, t)
77+
v = sp.diff(u, t)
7278
if damping == 'linear':
73-
return m*sym.diff(u, t, t) + b*v + s(u)
79+
return m*sp.diff(u, t, t) + b*v + s(u)
7480
else:
75-
return m*sym.diff(u, t, t) + b*v*sym.Abs(v) + s(u)
81+
return m*sp.diff(u, t, t) + b*v*sp.Abs(v) + s(u)
7682

7783
def test_quadratic():
7884
"""Verify a quadratic solution."""
7985
I = 1.2; V = 3; m = 2; b = 0.9
8086
s = lambda u: 4*u
81-
t = sym.Symbol('t')
87+
t = sp.Symbol('t')
8288
dt = 0.2
8389
T = 2
8490

8591
q = 2 # arbitrary constant
8692
u_exact = I + V*t + q*t**2
87-
F = sym.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'linear'))
88-
u_exact = sym.lambdify(t, u_exact, modules='numpy')
93+
F = sp.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'linear'))
94+
u_exact = sp.lambdify(t, u_exact, modules='numpy')
8995
u1, t1 = solver(I, V, m, b, s, F, dt, T, 'linear')
9096
diff = np.abs(u_exact(t1) - u1).max()
9197
tol = 1E-13
@@ -94,8 +100,8 @@ def test_quadratic():
94100
# In the quadratic damping case, u_exact must be linear
95101
# in order exactly recover this solution
96102
u_exact = I + V*t
97-
F = sym.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'quadratic'))
98-
u_exact = sym.lambdify(t, u_exact, modules='numpy')
103+
F = sp.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'quadratic'))
104+
u_exact = sp.lambdify(t, u_exact, modules='numpy')
99105
u2, t2 = solver(I, V, m, b, s, F, dt, T, 'quadratic')
100106
diff = np.abs(u_exact(t2) - u2).max()
101107
assert diff < tol
@@ -127,11 +133,11 @@ def test_mms():
127133
"""Use method of manufactured solutions."""
128134
m = 4.; b = 1
129135
w = 1.5
130-
t = sym.Symbol('t')
131-
u_exact = 3*sym.exp(-0.2*t)*sym.cos(1.2*t)
136+
t = sp.Symbol('t')
137+
u_exact = 3*sp.exp(-0.2*t)*sp.cos(1.2*t)
132138
I = u_exact.subs(t, 0).evalf()
133-
V = sym.diff(u_exact, t).subs(t, 0).evalf()
134-
u_exact_py = sym.lambdify(t, u_exact, modules='numpy')
139+
V = sp.diff(u_exact, t).subs(t, 0).evalf()
140+
u_exact_py = sp.lambdify(t, u_exact, modules='numpy')
135141
s = lambda u: u**3
136142
dt = 0.2
137143
T = 6
@@ -140,14 +146,14 @@ def test_mms():
140146
# Run grid refinements and compute exact error
141147
for i in range(5):
142148
F_formula = lhs_eq(t, m, b, s, u_exact, 'linear')
143-
F = sym.lambdify(t, F_formula)
149+
F = sp.lambdify(t, F_formula)
144150
u1, t1 = solver(I, V, m, b, s, F, dt, T, 'linear')
145151
error = np.sqrt(np.sum((u_exact_py(t1) - u1)**2)*dt)
146152
errors_linear.append((dt, error))
147153

148154
F_formula = lhs_eq(t, m, b, s, u_exact, 'quadratic')
149155
#print sym.latex(F_formula, mode='plain')
150-
F = sym.lambdify(t, F_formula)
156+
F = sp.lambdify(t, F_formula)
151157
u2, t2 = solver(I, V, m, b, s, F, dt, T, 'quadratic')
152158
error = np.sqrt(np.sum((u_exact_py(t2) - u2)**2)*dt)
153159
errors_quadratic.append((dt, error))

0 commit comments

Comments
 (0)