"""Time-dependent PDE solvers: time integrators and KdV equation solver."""
from __future__ import annotations
import numpy as np
from typing import Callable
from numba import njit
# FFT backend shared by every spectral operation in this module. It is bound
# to numpy.fft directly (minimal overhead); the code below only relies on its
# fft, ifft and fftfreq functions.
fft_backend = np.fft
# Numba JIT kernels to eliminate Python overhead in element-wise operations
@njit(cache=True)
def _jit_nonlinear_term(u: np.ndarray, ux: np.ndarray) -> np.ndarray:
    """Return the element-wise product ``u * ux`` (Numba-compiled).

    Parameters
    ----------
    u : np.ndarray
        First factor.
    ux : np.ndarray
        Second factor, multiplied element-wise with ``u``.

    Returns
    -------
    np.ndarray
        The product ``u * ux``.
    """
    product = u * ux
    return product
@njit(cache=True)
def _jit_combine_rhs(nonlinear: np.ndarray, uxxx: np.ndarray) -> np.ndarray:
    """Return ``-6 * nonlinear - uxxx`` element by element (Numba-compiled).

    Parameters
    ----------
    nonlinear : np.ndarray
        Values of the nonlinear term.
    uxxx : np.ndarray
        Values subtracted from the scaled nonlinear term.

    Returns
    -------
    np.ndarray
        New array allocated with ``np.empty_like(nonlinear)``.
    """
    out = np.empty_like(nonlinear)
    n_points = nonlinear.size
    # Flat integer indexing up to ``size``: presumably intended for 1-D input.
    for idx in range(n_points):
        out[idx] = -6.0 * nonlinear[idx] - uxxx[idx]
    return out
@njit(cache=True, inline="always")
def _jit_rk4_combine(
    u: np.ndarray,
    k1: np.ndarray,
    k2: np.ndarray,
    k3: np.ndarray,
    k4: np.ndarray,
    dt: float,
) -> np.ndarray:
    """Fused final RK4 update: ``u + dt/6 * (k1 + 2*k2 + 2*k3 + k4)``.

    Parameters
    ----------
    u : np.ndarray
        Solution at the start of the step.
    k1, k2, k3, k4 : np.ndarray
        The four stage slopes.
    dt : float
        Time step.

    Returns
    -------
    np.ndarray
        New array allocated with ``np.empty_like(u)``.
    """
    out = np.empty_like(u)
    sixth_dt = dt / 6.0
    n_points = u.size
    # Flat integer indexing up to ``size``: presumably intended for 1-D input.
    for idx in range(n_points):
        slope_sum = k1[idx] + 2.0 * k2[idx] + 2.0 * k3[idx] + k4[idx]
        out[idx] = u[idx] + sixth_dt * slope_sum
    return out
@njit(cache=True, inline="always")
def _jit_rk4_stage(u: np.ndarray, k: np.ndarray, factor: float) -> np.ndarray:
    """Build an intermediate RK4 state ``u + factor * k``.

    Parameters
    ----------
    u : np.ndarray
        Solution at the start of the step.
    k : np.ndarray
        Stage slope.
    factor : float
        Scalar multiplying ``k``.

    Returns
    -------
    np.ndarray
        New array allocated with ``np.empty_like(u)``.
    """
    out = np.empty_like(u)
    n_points = u.size
    # Flat integer indexing up to ``size``: presumably intended for 1-D input.
    for idx in range(n_points):
        out[idx] = u[idx] + factor * k[idx]
    return out
@njit(cache=True, inline="always")
def _jit_rk3_stage1(u: np.ndarray, k1: np.ndarray, dt: float) -> np.ndarray:
    """First RK3 stage: ``u + dt * k1``.

    Parameters
    ----------
    u : np.ndarray
        Solution at the start of the step.
    k1 : np.ndarray
        First stage slope.
    dt : float
        Time step.

    Returns
    -------
    np.ndarray
        New array allocated with ``np.empty_like(u)``.
    """
    out = np.empty_like(u)
    n_points = u.size
    # Flat integer indexing up to ``size``: presumably intended for 1-D input.
    for idx in range(n_points):
        out[idx] = u[idx] + dt * k1[idx]
    return out
@njit(cache=True, inline="always")
def _jit_rk3_stage2(
    u: np.ndarray, u1: np.ndarray, k2: np.ndarray, dt: float
) -> np.ndarray:
    """Second RK3 stage: ``0.75*u + 0.25*u1 + 0.25*dt*k2``.

    Parameters
    ----------
    u : np.ndarray
        Solution at the start of the step.
    u1 : np.ndarray
        State produced by the first stage.
    k2 : np.ndarray
        Second stage slope.
    dt : float
        Time step.

    Returns
    -------
    np.ndarray
        New array allocated with ``np.empty_like(u)``.
    """
    out = np.empty_like(u)
    quarter_dt = 0.25 * dt
    n_points = u.size
    # Flat integer indexing up to ``size``: presumably intended for 1-D input.
    for idx in range(n_points):
        out[idx] = 0.75 * u[idx] + 0.25 * u1[idx] + quarter_dt * k2[idx]
    return out
@njit(cache=True, inline="always")
def _jit_rk3_stage3(
    u: np.ndarray, u2: np.ndarray, k3: np.ndarray, dt: float
) -> np.ndarray:
    """Third RK3 stage: ``(1/3)*u + (2/3)*u2 + (2/3)*dt*k3``.

    Parameters
    ----------
    u : np.ndarray
        Solution at the start of the step.
    u2 : np.ndarray
        State produced by the second stage.
    k3 : np.ndarray
        Third stage slope.
    dt : float
        Time step.

    Returns
    -------
    np.ndarray
        New array allocated with ``np.empty_like(u)``.
    """
    out = np.empty_like(u)
    one_third = 1.0 / 3.0
    two_thirds = 2.0 / 3.0
    two_thirds_dt = two_thirds * dt
    n_points = u.size
    # Flat integer indexing up to ``size``: presumably intended for 1-D input.
    for idx in range(n_points):
        out[idx] = one_third * u[idx] + two_thirds * u2[idx] + two_thirds_dt * k3[idx]
    return out
# =============================================================================
# Time Integration Methods
# =============================================================================
class TimeIntegrator:
    """Common interface for the time-stepping schemes in this module.

    Subclasses override :meth:`step`; this base class only stores
    descriptive metadata about the scheme.
    """

    def __init__(self, name: str, order: int, stages: int = 1):
        """
        Store the scheme metadata.

        Parameters
        ----------
        name : str
            Name of the method
        order : int
            Order of accuracy
        stages : int
            Number of stages (for RK) or steps (for LMM)
        """
        self.name, self.order, self.stages = name, order, stages

    def step(self, rhs: Callable, u: np.ndarray, t: float, dt: float) -> np.ndarray:
        """
        Advance the solution by one time step.

        Parameters
        ----------
        rhs : Callable
            Right-hand side function f(u, t)
        u : np.ndarray
            Current solution
        t : float
            Current time
        dt : float
            Time step

        Returns
        -------
        np.ndarray
            Solution at next time step

        Raises
        ------
        NotImplementedError
            Always, in this base class; concrete integrators override it.
        """
        raise NotImplementedError
# =============================================================================
# Runge-Kutta Methods
# =============================================================================
class RK3(TimeIntegrator):
    """Three-stage, third-order strong-stability-preserving Runge-Kutta (SSP-RK3).

    Each step is built from forward-Euler-like updates combined convexly
    with the starting state, which is the construction that gives the
    scheme its strong-stability-preserving character and makes it a common
    choice for hyperbolic PDEs.

    Notes
    -----
    With ``f = rhs`` the step computes::

        u1      = u + dt * f(u, t)
        u2      = 3/4 * u + 1/4 * u1 + 1/4 * dt * f(u1, t + dt)
        u_next  = 1/3 * u + 2/3 * u2 + 2/3 * dt * f(u2, t + dt/2)

    References
    ----------
    Engsig-Karup, "Lecture 5: Initial Value Problems", p. 63
    """

    def __init__(self):
        super().__init__("RK3", order=3, stages=3)

    def step(self, rhs: Callable, u: np.ndarray, t: float, dt: float) -> np.ndarray:
        """Advance ``u`` from ``t`` to ``t + dt`` with SSP-RK3.

        Parameters
        ----------
        rhs : Callable
            Right-hand side function f(u, t)
        u : np.ndarray
            Current solution (not modified)
        t : float
            Current time
        dt : float
            Time step

        Returns
        -------
        np.ndarray
            Solution at ``t + dt``; a new array with the dtype of ``u``.
        """
        quarter_dt = 0.25 * dt
        two_thirds = 2.0 / 3.0
        two_thirds_dt = two_thirds * dt

        # Two scratch buffers are reused across the stages to limit allocations.
        buf_second = np.empty_like(u)
        buf_first = np.empty_like(u)

        # Stage 1: buf_first = u + dt * f(u, t)
        slope = rhs(u, t)
        np.multiply(dt, slope, out=buf_first)
        np.add(buf_first, u, out=buf_first)

        # Stage 2: buf_second = 3/4 u + 1/4 u1 + 1/4 dt f(u1, t + dt)
        slope = rhs(buf_first, t + dt)
        np.multiply(0.75, u, out=buf_second)
        np.add(buf_second, 0.25 * buf_first, out=buf_second)
        np.add(buf_second, quarter_dt * slope, out=buf_second)

        # Stage 3: overwrite buf_first with 1/3 u + 2/3 u2 + 2/3 dt f(u2, t + dt/2)
        slope = rhs(buf_second, t + 0.5 * dt)
        np.multiply(1.0 / 3.0, u, out=buf_first)
        np.add(buf_first, two_thirds * buf_second, out=buf_first)
        np.add(buf_first, two_thirds_dt * slope, out=buf_first)
        return buf_first
class RK4(TimeIntegrator):
    """Classical explicit four-stage, fourth-order Runge-Kutta scheme (ERK4).

    A widely used general-purpose explicit integrator that balances
    accuracy against cost.

    Notes
    -----
    The right-hand side is evaluated four times per step: once at ``t``,
    twice at ``t + dt/2`` and once at ``t + dt``. Unlike SSP schemes it
    carries no special stability-preservation property.

    References
    ----------
    Engsig-Karup, "Lecture 5: Initial Value Problems", p. 58
    """

    def __init__(self):
        super().__init__("RK4", order=4, stages=4)

    def step(self, rhs: Callable, u: np.ndarray, t: float, dt: float) -> np.ndarray:
        """Advance ``u`` from ``t`` to ``t + dt`` with classical RK4.

        Parameters
        ----------
        rhs : Callable
            Right-hand side function f(u, t)
        u : np.ndarray
            Current solution
        t : float
            Current time
        dt : float
            Time step

        Returns
        -------
        np.ndarray
            Solution at ``t + dt``.
        """
        half_dt = 0.5 * dt
        t_mid = t + half_dt

        # Stage states and the final blend are delegated to the JIT kernels.
        k1 = rhs(u, t)
        state_2 = _jit_rk4_stage(u, k1, half_dt)
        k2 = rhs(state_2, t_mid)
        state_3 = _jit_rk4_stage(u, k2, half_dt)
        k3 = rhs(state_3, t_mid)
        state_4 = _jit_rk4_stage(u, k3, dt)
        k4 = rhs(state_4, t + dt)

        return _jit_rk4_combine(u, k1, k2, k3, k4, dt)
def get_time_integrator(name: str, **kwargs) -> TimeIntegrator:
    """
    Look up and instantiate a time integrator from its name.

    The name is lower-cased and stripped of every non-alphanumeric
    character before matching, so ``"RK-4"`` and ``"rk4"`` are equivalent.

    Parameters
    ----------
    name : str
        Integrator identifier: "rk4" or "rk3"
    kwargs :
        Extra keyword arguments (currently unused, kept for API compatibility)

    Returns
    -------
    TimeIntegrator
        A new instance of the requested integrator

    Raises
    ------
    ValueError
        If the normalized name is neither "rk4" nor "rk3".
    """
    key = "".join(filter(str.isalnum, name.lower()))
    if key == "rk4":
        return RK4()
    if key == "rk3":
        return RK3()
    raise ValueError(f"Unknown time integrator '{name}'. Available: 'rk4', 'rk3'")
# =============================================================================
# KdV Equation Solver
# =============================================================================
def soliton(x: np.ndarray, t: float, c: float, x0: float = 0.0) -> np.ndarray:
    """
    Evaluate the single-soliton profile of the KdV equation.

    The profile is ``(c/2) * sech^2( sqrt(c)/2 * (x - c*t - x0) )``: a hump
    of height ``c/2`` centred at ``x0 + c*t``.

    Parameters
    ----------
    x : np.ndarray
        Spatial coordinates
    t : float
        Time
    c : float
        Soliton speed parameter
    x0 : float, optional
        Initial position offset (default: 0.0)

    Returns
    -------
    np.ndarray
        Soliton amplitude at each spatial point
    """
    travelling_coord = x - c * t - x0
    sech_argument = 0.5 * np.sqrt(c) * travelling_coord
    return 0.5 * c / np.cosh(sech_argument) ** 2
def two_soliton_initial(
    x: np.ndarray, c1: float, x01: float, c2: float, x02: float
) -> np.ndarray:
    """
    Build the initial condition for a two-soliton collision run.

    The result is the plain sum of two single-soliton profiles, each
    evaluated at t=0.

    Parameters
    ----------
    x : np.ndarray
        Spatial coordinates
    c1 : float
        Speed parameter of first soliton
    x01 : float
        Initial position of first soliton
    c2 : float
        Speed parameter of second soliton
    x02 : float
        Initial position of second soliton

    Returns
    -------
    np.ndarray
        Initial condition u(x, 0)
    """
    first_wave = soliton(x, 0.0, c1, x01)
    second_wave = soliton(x, 0.0, c2, x02)
    return first_wave + second_wave
class ManufacturedSolution:
    """
    Manufactured solution for convergence testing.

    Supplies a closed-form field u_exact(x, t) together with the forcing
    f(x, t) that makes it an exact solution of the forced KdV equation

        u_t + 6*u*u_x + u_xxx = f(x, t)

    The forcing is obtained by substituting u_exact into the left-hand side.

    Parameters
    ----------
    amplitude : float
        Amplitude A of the solution
    wavenumber : float
        Spatial wavenumber k. NOTE(review): for the field to be periodic on
        the solver domain, k presumably has to be compatible with the domain
        length; confirm against the caller.
    frequency : float
        Temporal angular frequency omega

    Notes
    -----
    The manufactured solution has the form:

        u(x,t) = A * sin(k*x) * sin(omega*t)

    It is smooth in both variables and vanishes at t = 0.
    """

    def __init__(
        self, amplitude: float = 1.0, wavenumber: float = 1.0, frequency: float = 0.1
    ):
        """Store amplitude (``A``), wavenumber (``k``) and frequency (``omega``)."""
        self.A, self.k, self.omega = amplitude, wavenumber, frequency

    def u_exact(self, x: np.ndarray, t: float) -> np.ndarray:
        """
        Evaluate the exact solution u(x,t) = A * sin(k*x) * sin(omega*t).

        Parameters
        ----------
        x : np.ndarray
            Spatial coordinates
        t : float
            Time

        Returns
        -------
        np.ndarray
            Exact solution u(x,t)
        """
        spatial_part = np.sin(self.k * x)
        temporal_part = np.sin(self.omega * t)
        return self.A * spatial_part * temporal_part

    def source(self, x: np.ndarray, t: float) -> np.ndarray:
        """
        Evaluate the forcing f(x,t) = u_t + 6*u*u_x + u_xxx.

        For u(x,t) = A * sin(k*x) * sin(omega*t) the pieces are:

        - u_t   =  A * omega * sin(k*x) * cos(omega*t)
        - u_x   =  A * k     * cos(k*x) * sin(omega*t)
        - u_xxx = -A * k^3   * cos(k*x) * sin(omega*t)

        Parameters
        ----------
        x : np.ndarray
            Spatial coordinates
        t : float
            Time

        Returns
        -------
        np.ndarray
            Source term f(x,t)
        """
        amp, wavenum, omega = self.A, self.k, self.omega

        # Trigonometric factors shared by several terms.
        sin_x = np.sin(wavenum * x)
        cos_x = np.cos(wavenum * x)
        sin_t = np.sin(omega * t)
        cos_t = np.cos(omega * t)

        field = amp * sin_x * sin_t
        d_dt = amp * omega * sin_x * cos_t
        d_dx = amp * wavenum * cos_x * sin_t
        d_dxxx = -amp * (wavenum**3) * cos_x * sin_t

        return d_dt + 6.0 * field * d_dx + d_dxxx
class KdVSolver:
    """Korteweg-de Vries equation solver using Fourier spectral methods.

    Solves the KdV equation u_t + 6u*u_x + u_xxx = 0 on the periodic
    domain [-L, L) using Fourier collocation for the spatial
    discretization. The nonlinear term can optionally be dealiased with
    the 3/2-rule (zero padding).

    Notes
    -----
    Spatial derivatives are computed in Fourier space by multiplying the
    coefficients by ik (first derivative) and (ik)^3 (third derivative),
    where k is the angular wavenumber.

    The 3/2-rule evaluates the product u*u_x on a grid of about 3/2 times
    the original resolution, so that the high-frequency content generated
    by the product does not fold back onto the retained modes.
    """

    def __init__(self, N: int, L: float, dealias: bool = False):
        """
        Initialize the KdV solver.

        Parameters
        ----------
        N : int
            Number of Fourier modes (grid points)
        L : float
            Half-length of spatial domain [-L, L]
        dealias : bool, optional
            Apply 3/2-rule (zero-padding) dealiasing to the nonlinear term
            (default: False)
        """
        self.N = N
        self.L = L
        self.dealias = dealias
        # Uniform periodic grid: the right end point is the image of the left.
        self.x = np.linspace(-L, L, N, endpoint=False)
        self.dx = 2 * L / N
        # Angular wavenumbers in FFT ordering, and the derivative symbols.
        self.k = fft_backend.fftfreq(N, d=self.dx) * 2 * np.pi
        self.ik = 1j * self.k
        self.ik3 = self.ik**3

    def _dealias_product(self, u_hat: np.ndarray, v_hat: np.ndarray) -> np.ndarray:
        """
        Compute the dealiased product u*v using the 3/2-rule.

        The coefficients are zero-padded to M = (3*N)//2 points, multiplied
        in physical space on that finer grid, transformed back and
        truncated to the original N modes.

        Parameters
        ----------
        u_hat : np.ndarray
            Fourier coefficients of first function (length N, FFT ordering)
        v_hat : np.ndarray
            Fourier coefficients of second function (same length)

        Returns
        -------
        np.ndarray
            Fourier coefficients (length N) of the dealiased product u*v
        """
        N = len(u_hat)
        M = (3 * N) // 2
        # Split point between the leading block (non-negative frequencies)
        # and the trailing block of the FFT-ordered coefficients.
        n_low = (N + 1) // 2
        n_high = N // 2

        # Insert zeros in the middle of frequency space:
        # [leading block, zeros, trailing block]
        gap = np.zeros(M - N)
        u_hat_pad = np.concatenate([u_hat[:n_low], gap, u_hat[n_low:]])
        v_hat_pad = np.concatenate([v_hat[:n_low], gap, v_hat[n_low:]])

        # Multiply in physical space on the finer grid.
        w_pad = fft_backend.ifft(u_hat_pad) * fft_backend.ifft(v_hat_pad)
        w_pad_hat = fft_backend.fft(w_pad)

        # Each length-M ifft of length-N coefficients scales the field by
        # N/M, and the forward length-M fft carries a factor M where the
        # length-N convention expects N. The net correction is M/N. That
        # equals 3/2 only when N is even; for odd N, M = (3N)//2 and a
        # hard-coded 3/2 would mis-scale the product.
        scale = M / N
        return scale * np.concatenate([w_pad_hat[:n_low], w_pad_hat[M - n_high :]])

    def rhs(
        self, u: np.ndarray, t: float, source_term: Callable | None = None
    ) -> np.ndarray:
        """
        Compute right-hand side of the semi-discrete KdV equation.

        RHS = -6*u*u_x - u_xxx + f(x, t)

        Parameters
        ----------
        u : np.ndarray
            Solution at current time
        t : float
            Current time
        source_term : Callable[[np.ndarray, float], np.ndarray] | None, optional
            Optional source term function f(x, t) for manufactured solutions;
            it is called with the solver grid ``self.x`` and ``t``.

        Returns
        -------
        np.ndarray
            Time derivative du/dt (a newly allocated real array)
        """
        # A single forward FFT feeds both derivative evaluations.
        u_hat = fft_backend.fft(u)
        ux_hat = self.ik * u_hat

        if self.dealias:
            nonlinear_hat = self._dealias_product(u_hat, ux_hat)
            nonlinear = fft_backend.ifft(nonlinear_hat).real
        else:
            nonlinear = u * fft_backend.ifft(ux_hat).real

        uxxx = fft_backend.ifft(self.ik3 * u_hat).real

        # Build the result as a fresh array rather than scaling `nonlinear`
        # in place: in the dealiased branch it is a strided `.real` view of
        # a complex buffer, which should not be handed to callers.
        dudt = -6.0 * nonlinear - uxxx

        if source_term is not None:
            dudt += source_term(self.x, t)
        return dudt

    def get_spectrum(self, u: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute Fourier spectrum for spectral analysis.

        Parameters
        ----------
        u : np.ndarray
            Solution field

        Returns
        -------
        k : np.ndarray
            Wave numbers (FFT ordering)
        magnitude : np.ndarray
            Magnitude of the (unnormalized) Fourier coefficients
        phase : np.ndarray
            Phase angle of Fourier coefficients
        """
        u_hat = fft_backend.fft(u)
        return self.k, np.abs(u_hat), np.angle(u_hat)

    def compute_eigenvalues(self, u_max: float) -> np.ndarray:
        """
        Compute eigenvalues of the frozen-coefficient linearized KdV operator.

        For u_t = -6u*u_x - u_xxx, freezing the coefficient at u_max gives
        the operator ``-6*u_max*D1 - D3``. Both derivative operators are
        diagonal in Fourier space, so the eigenvalues are simply
        ``-6*u_max*(ik) - (ik)^3`` for each wavenumber k of the grid.
        This is useful for stability analysis.

        Parameters
        ----------
        u_max : float
            Maximum amplitude for frozen-coefficient approximation

        Returns
        -------
        np.ndarray
            Complex eigenvalues of the linearized operator
        """
        return -6 * u_max * self.ik - self.ik3

    def solve(
        self,
        u0: np.ndarray,
        t_final: float,
        dt: float,
        save_every: int = 1,
        integrator: TimeIntegrator | None = None,
        measure_performance: bool = False,
    ) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, dict]:
        """
        Solve KdV equation from t=0 to t=t_final.

        Parameters
        ----------
        u0 : np.ndarray
            Initial condition (converted to a float array; not modified)
        t_final : float
            Final time (non-negative)
        dt : float
            Time step (positive)
        save_every : int, optional
            Save solution every N steps, by default 1
        integrator : TimeIntegrator, optional
            Time integration method, by default RK4()
        measure_performance : bool, optional
            Measure and return performance metrics, by default False

        Returns
        -------
        t_saved : np.ndarray
            Times at which solution was saved
        u_saved : np.ndarray
            Saved solutions (shape: [n_saves, N])
        performance : dict, optional
            Performance metrics (returned if measure_performance=True):
            - 'wall_time_s': Sum of the per-step times in seconds
            - 'mean_step_time_ms': Mean time per step in milliseconds
            - 'std_step_time_ms': Standard deviation of step times
            - 'nsteps': Total number of time steps

        Raises
        ------
        ValueError
            If ``dt`` is not positive, ``save_every`` is smaller than 1, or
            ``t_final`` is negative.

        Notes
        -----
        Every step has length ``dt``; when ``t_final`` is not a multiple of
        ``dt`` the last step ends past ``t_final``. The final state is
        only stored when the step count is a multiple of ``save_every``.
        """
        import time

        if dt <= 0:
            raise ValueError("dt must be positive")
        if save_every < 1:
            raise ValueError("save_every must be a positive integer")
        if t_final < 0:
            raise ValueError("t_final must be non-negative")
        if integrator is None:
            integrator = get_time_integrator("rk4")

        # A bare ceil(t_final / dt) adds a spurious extra step when round-off
        # lands the ratio just above an integer (2.1 / 0.7 evaluates to
        # 3.0000000000000004), so snap near-integer ratios first.
        ratio = float(t_final) / float(dt)
        nearest = round(ratio)
        if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
            n_steps = int(nearest)
        else:
            n_steps = int(np.ceil(ratio))
        n_saves = n_steps // save_every + 1

        # Work on a floating-point copy: the stepping kernels allocate their
        # outputs with the dtype of `u`, so integer input would be truncated.
        u = np.array(u0, dtype=float)
        t = 0.0
        t_saved = np.zeros(n_saves)
        u_saved = np.zeros((n_saves, self.N))
        t_saved[0] = t
        u_saved[0] = u

        step_times = []
        save_idx = 1
        for step in range(n_steps):
            if measure_performance:
                t_start = time.perf_counter()
            u = integrator.step(self.rhs, u, t, dt)
            # Recompute the time from the step index instead of accumulating
            # `t += dt`, which would let round-off drift grow with the step count.
            t = (step + 1) * dt
            if measure_performance:
                step_times.append(time.perf_counter() - t_start)
            if (step + 1) % save_every == 0 and save_idx < n_saves:
                t_saved[save_idx] = t
                u_saved[save_idx] = u
                save_idx += 1

        if measure_performance:
            step_times_ms = np.array(step_times) * 1000  # seconds -> ms
            performance = {
                "wall_time_s": np.sum(step_times),
                "mean_step_time_ms": np.mean(step_times_ms),
                "std_step_time_ms": np.std(step_times_ms),
                "nsteps": n_steps,
            }
            return t_saved[:save_idx], u_saved[:save_idx], performance
        return t_saved[:save_idx], u_saved[:save_idx]

    @staticmethod
    def compute_conserved_quantities(
        u: np.ndarray, dx: float
    ) -> tuple[float, float, float]:
        """
        Compute conserved quantities for KdV equation.

        Mass:     M = integral of u dx
        Momentum: V = integral of u^2 dx
        Energy:   E = integral of (u_x^2 / 2 - u^3) dx

        The integrals are approximated by grid sums times ``dx``, with
        u_x obtained spectrally.

        Parameters
        ----------
        u : np.ndarray
            Solution field
        dx : float
            Grid spacing

        Returns
        -------
        M : float
            Mass
        V : float
            Momentum
        E : float
            Energy
        """
        N = len(u)
        ik = 1j * (fft_backend.fftfreq(N, d=dx) * 2 * np.pi)
        # Spectral first derivative.
        ux = fft_backend.ifft(ik * fft_backend.fft(u)).real

        M = np.sum(u) * dx
        V = np.sum(u**2) * dx
        E = np.sum(0.5 * ux**2 - u**3) * dx
        return M, V, E

    @staticmethod
    def stable_dt(
        N: int,
        L: float,
        u_max: float,
        *,
        integrator_name: str = "rk4",
        dealiased: bool = False,
    ) -> float:
        """
        Compute a stable time step via absolute stability and frozen coefficients.

        With u frozen at u_max, the semi-discrete eigenvalues are
        lambda_k = i*k*(k^2 - 6*u_max), so |lambda_k| = |k| * |k^2 - 6*u_max|.
        The step is chosen so that dt * |lambda| <= s(method), where
        s(method) is the extent of the method's stability region along the
        imaginary axis. |lambda| is evaluated at the largest wavenumber of
        the full grid (N//2) and of the dealiased grid (N//3), and the
        smaller of the two resulting steps is returned.

        Parameters
        ----------
        N, L : grid size and half-domain for [-L, L]
        u_max : max |u| expected (for 1-soliton with parameter c, use u_max = c/2)
        integrator_name : one of {"rk4", "rk3"}; names merely containing one
            of these (after lower-casing and dropping non-alphanumerics) are
            accepted as well
        dealiased : retained for API compatibility; the returned step is
            conservative across both aliased and de-aliased configurations.

        Returns
        -------
        float
            Suggested stable time step (``inf`` when both evaluated
            eigenvalue magnitudes are zero).
        """
        name = "".join(ch for ch in integrator_name.lower() if ch.isalnum())

        # Imaginary-axis stability limits s(method) (standard results).
        known_radii = {"rk4": 2.828, "rk3": 1.732}
        if name in known_radii:
            imag_axis_radius = known_radii[name]
        elif "rk4" in name:
            imag_axis_radius = 2.828
        elif "rk3" in name:
            imag_axis_radius = 1.732
        else:
            imag_axis_radius = 0.0  # unrecognized method: see _dt below

        def _lam_max(k_max: float) -> float:
            return k_max * abs(k_max**2 - 6.0 * abs(u_max))

        def _dt(lam: float) -> float:
            # Guard against degenerate configurations.
            if lam == 0.0:
                return np.inf
            if imag_axis_radius == 0.0:
                # Unknown integrator: fall back to a tiny step.
                return 1e-12 / lam
            return imag_axis_radius / lam

        kmax_alias = (np.pi / L) * (N // 2)
        kmax_dealias = (np.pi / L) * (N // 3)
        return min(_dt(_lam_max(kmax_alias)), _dt(_lam_max(kmax_dealias)))