# Copyright (C) 2025 Alessandro Santini
# SPDX-License-Identifier: MIT
# Credits for the original implementations: Cecilio García Quirós
"""
Amplitude coefficient computation for IMRPhenomTHM.
======================================================
This module implements the pAmp class functionality from phenomxpy,
computing all the coefficients needed for the IMR amplitude ansatze.
"""
from typing import Tuple
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from ..utils.utility import solve_3x3_explicit
from . import fits
from .internals import WaveformParams
from .phase import PhaseCoeffs, _inspiral_ansatz_domega, imr_omega
[docs]
class AmplitudeCoeffs(eqx.Module):
    """
    All amplitude coefficients for a given mode.

    Container for the per-mode quantities produced by
    ``compute_amplitude_coeffs_22`` / ``compute_amplitude_coeffs_hm`` and read
    back by ``imr_amplitude`` when evaluating the inspiral, intermediate and
    ringdown amplitude ansatze.
    """

    # Mode identifier as supplied by the caller (22 for the dominant mode;
    # e.g. 21, 33, 44, 55 for higher modes).
    mode: int | Array
    # PN coefficients arrays (ready for ansatz)
    # Real: [ampN, amp0half, amp1, amp1half, amp2, amp2half, amp3, amp3half, amplog]
    pn_real_coeffs: Array
    # Imag: [amp0half, amp1, amp1half, amp2, amp2half, amp3, amp3half]
    pn_imag_coeffs: Array
    # Pseudo-PN coefficients (3 coefficients); in the inspiral ansatz they
    # multiply x^4, x^(9/2) and x^5 respectively.
    inspC1: float | Array
    inspC2: float | Array
    inspC3: float | Array
    # Ringdown coefficients
    # alpha1RD is also the rate used inside the intermediate ansatz; the
    # ``*_prec`` values together with ``c3`` are what the ringdown ansatz reads.
    # In this module the ``*_prec`` values are computed the same way as their
    # non-``_prec`` counterparts.
    alpha1RD: float | Array
    alpha1RD_prec: float | Array
    # Amplitude at t = tshift (taken from a fit).
    ampPeak: float | Array
    c1_prec: float | Array
    c2_prec: float | Array
    c3: float | Array
    c4_prec: float | Array
    # Intermediate coefficients: weights of the constant, sech, sech(2.)^(1/7)
    # and (t - tshift)^2 terms of the intermediate ansatz, in that order.
    mergerC1: float | Array
    mergerC2: float | Array
    mergerC3: float | Array
    mergerC4: float | Array
    # Cuts and times
    # ``imr_amplitude`` uses the inspiral ansatz for t < inspiral_cut, the
    # ringdown ansatz for t >= ringdown_cut and the intermediate one in between;
    # ``tshift`` is the time offset subtracted inside the latter two ansatze.
    inspiral_cut: float | Array
    ringdown_cut: float | Array
    tshift: float | Array
    # Prefactor (2 * eta * sqrt(16 * pi / 5)) multiplying the inspiral ansatz.
    fac0: float | Array
    # Phase offset from amplitude: orientation of the complex inspiral
    # amplitude at inspiral_cut and (minus) its time derivative. Both are set
    # to zero for the 22 mode.
    omegaCutPNAMP: float | Array
    phiCutPNAMP: float | Array
@jax.jit
def _compute_pn_amplitude_coeffs(
eta: float | Array,
delta: float | Array,
chi1: float | Array,
chi2: float | Array,
m1: float | Array,
m2: float | Array,
mode: int | Array,
) -> Tuple[Array, Array, Array]:
"""
Compute PN amplitude coefficients for a specific mode.
Returns (pn_real_coeffs, pn_imag_coeffs, fac0).
"""
# Derived quantities
m1_2 = m1 * m1
m2_2 = m2 * m2
s1z = chi1
s2z = chi2
chis = 0.5 * (s1z + s2z)
chia = 0.5 * (s1z - s2z)
Sc = m1_2 * s1z + m2_2 * s2z
Sigmac = m2 * s2z - m1 * s1z
eta2 = eta * eta
eta3 = eta2 * eta
fac0 = 2.0 * eta * jnp.sqrt(16.0 * jnp.pi / 5.0)
# Initialize all to 0.0
ampN = 0.0
amp0halfPNreal = 0.0
amp0halfPNimag = 0.0
amp1PNreal = 0.0
amp1PNimag = 0.0
amp1halfPNreal = 0.0
amp1halfPNimag = 0.0
amp2PNreal = 0.0
amp2PNimag = 0.0
amp2halfPNreal = 0.0
amp2halfPNimag = 0.0
amp3PNreal = 0.0
amp3PNimag = 0.0
amp3halfPNreal = 0.0
amp3halfPNimag = 0.0
amplog = 0.0
# Mode 22
def get_22():
S0 = m1 * s1z + m2 * s2z
return (
1.0, # ampN
0.0, # amp0halfPNreal
0.0, # amp0halfPNimag
-107.0 / 42.0 + (55.0 * eta) / 42.0, # amp1PNreal
0.0, # amp1PNimag
(-4.0 * chis) / 3.0
- (4.0 * chia * delta) / 3.0
+ (4.0 * chis * eta) / 3.0
+ 2.0 * jnp.pi, # amp1halfPNreal
0.0, # amp1halfPNimag
-2173.0 / 1512.0
- (1069.0 * eta) / 216.0
+ (2047.0 * eta2) / 1512.0
+ S0**2, # amp2PNreal
0.0, # amp2PNimag
(-107.0 * jnp.pi) / 21.0 + (34.0 * eta * jnp.pi) / 21.0, # amp2halfPNreal
-24.0 * eta, # amp2halfPNimag
(
27027409.0 / 646800.0
- (278185.0 * eta) / 33264.0
- (20261.0 * eta2) / 2772.0
+ (114635.0 * eta3) / 99792.0
- (856.0 * 0.5772156649015329) / 105.0
+ (2.0 * jnp.pi**2) / 3.0
+ (41.0 * eta * jnp.pi**2) / 96.0
), # amp3PNreal
(428.0 * jnp.pi) / 105.0, # amp3PNimag
(-2173.0 * jnp.pi) / 756.0
- (2495.0 * eta * jnp.pi) / 378.0
+ (40.0 * eta2 * jnp.pi) / 27.0, # amp3halfPNreal
(14333.0 * eta) / 162.0 - (4066.0 * eta2) / 945.0, # amp3halfPNimag
-428.0 / 105.0, # amplog
)
# Mode 21
def get_21():
return (
0.0, # ampN
delta / 3.0, # amp0halfPNreal
0.0, # amp0halfPNimag
-0.5 * chia - (chis * delta) / 2.0, # amp1PNreal
0.0, # amp1PNimag
(-17.0 * delta) / 84.0 + (5.0 * delta * eta) / 21.0, # amp1halfPNreal
0.0, # amp1halfPNimag
(delta * jnp.pi) / 3.0
- (43.0 * delta * Sc) / 21.0
- (79.0 * Sigmac) / 42.0
+ (139.0 * eta * Sigmac) / 42.0, # amp2PNreal
-1.0 / 6.0 * delta - (delta * jnp.log(16.0)) / 6.0, # amp2PNimag
(-43.0 * delta) / 378.0
- (509.0 * delta * eta) / 378.0
+ (79.0 * delta * eta2) / 504.0, # amp2halfPNreal
0.0, # amp2halfPNimag
(-17.0 * delta * jnp.pi) / 84.0
+ (delta * eta * jnp.pi) / 14.0, # amp3PNreal
(17.0 * delta) / 168.0
- (353.0 * delta * eta) / 84.0
+ (17.0 * delta * jnp.log(16.0)) / 168.0
- (delta * eta * jnp.log(4096.0)) / 84.0, # amp3PNimag
0.0, # amp3halfPNreal
0.0, # amp3halfPNimag
0.0, # amplog
)
# Mode 33
def get_33():
return (
0.0, # ampN
(3.0 * jnp.sqrt(15.0 / 14.0) * delta) / 4.0, # amp0halfPNreal
0.0, # amp0halfPNimag
0.0, # amp1PNreal
0.0, # amp1PNimag
-3.0 * jnp.sqrt(15.0 / 14.0) * delta
+ (3.0 * jnp.sqrt(15.0 / 14.0) * delta * eta) / 2.0, # amp1halfPNreal
0.0, # amp1halfPNimag
(
(9.0 * jnp.sqrt(15.0 / 14.0) * delta * jnp.pi) / 4.0
- (3.0 * jnp.sqrt(105.0 / 2.0) * delta * Sc) / 8.0
- (9.0 * jnp.sqrt(15.0 / 14.0) * Sigmac) / 8.0
+ (27.0 * jnp.sqrt(15.0 / 14.0) * eta * Sigmac) / 8.0
), # amp2PNreal
(-9.0 * jnp.sqrt(21.0 / 10.0) * delta) / 4.0
+ (9.0 * jnp.sqrt(15.0 / 14.0) * delta * jnp.log(3.0 / 2.0))
/ 2.0, # amp2PNimag
(
(369.0 * jnp.sqrt(3.0 / 70.0) * delta) / 88.0
- (919.0 * jnp.sqrt(3.0 / 70.0) * delta * eta) / 22.0
+ (887.0 * jnp.sqrt(3.0 / 70.0) * delta * eta2) / 88.0
), # amp2halfPNreal
0.0, # amp2halfPNimag
0.0, # amp3PNreal
0.0, # amp3PNimag
0.0, # amp3halfPNreal
0.0, # amp3halfPNimag
0.0, # amplog
)
# Mode 44
def get_44():
return (
0.0, # ampN
0.0, # amp0halfPNreal
0.0, # amp0halfPNimag
(8.0 * jnp.sqrt(5.0 / 7.0)) / 9.0
- (8.0 * jnp.sqrt(5.0 / 7.0) * eta) / 3.0, # amp1PNreal
0.0, # amp1PNimag
0.0, # amp1halfPNreal
0.0, # amp1halfPNimag
-2372.0 / (99.0 * jnp.sqrt(35.0))
+ (5092.0 * jnp.sqrt(5.0 / 7.0) * eta) / 297.0
- (100.0 * jnp.sqrt(35.0) * eta2) / 99.0, # amp2PNreal
0.0, # amp2PNimag
(32.0 * jnp.sqrt(5.0 / 7.0) * jnp.pi) / 9.0
- (32.0 * jnp.sqrt(5.0 / 7.0) * eta * jnp.pi) / 3.0, # amp2halfPNreal
(
(-16.0 * jnp.sqrt(7.0 / 5.0)) / 3.0
+ (1193.0 * eta) / (9.0 * jnp.sqrt(35.0))
+ (64.0 * jnp.sqrt(5.0 / 7.0) * jnp.log(2.0)) / 9.0
- (64.0 * jnp.sqrt(5.0 / 7.0) * eta * jnp.log(2.0)) / 3.0
), # amp2halfPNimag
(
1068671.0 / (45045.0 * jnp.sqrt(35.0))
- (1088119.0 * eta) / (6435.0 * jnp.sqrt(35.0))
+ (293758.0 * eta2) / (1053.0 * jnp.sqrt(35.0))
- (226097.0 * eta3) / (3861.0 * jnp.sqrt(35.0))
), # amp3PNreal
0.0, # amp3PNimag
0.0, # amp3halfPNreal
0.0, # amp3halfPNimag
0.0, # amplog
)
# Mode 55
def get_55():
return (
0.0, # ampN
0.0, # amp0halfPNreal
0.0, # amp0halfPNimag
0.0, # amp1PNreal
0.0, # amp1PNimag
(625.0 * delta) / (96.0 * jnp.sqrt(66.0))
- (625.0 * delta * eta) / (48.0 * jnp.sqrt(66.0)), # amp1halfPNreal
0.0, # amp1halfPNimag
0.0, # amp2PNreal
0.0, # amp2PNimag
(
(-164375.0 * delta) / (3744.0 * jnp.sqrt(66.0))
+ (26875.0 * delta * eta) / (234.0 * jnp.sqrt(66.0))
- (2500.0 * jnp.sqrt(2.0 / 33.0) * delta * eta2) / 117.0
), # amp2halfPNreal
0.0, # amp2halfPNimag
(3125.0 * delta * jnp.pi) / (96.0 * jnp.sqrt(66.0))
- (3125.0 * delta * eta * jnp.pi) / (48.0 * jnp.sqrt(66.0)), # amp3PNreal
(
(-113125.0 * delta) / (1344.0 * jnp.sqrt(66.0))
+ (17639.0 * delta * eta) / (80.0 * jnp.sqrt(66.0))
+ (3125.0 * delta * jnp.log(5.0 / 2.0)) / (48.0 * jnp.sqrt(66.0))
- (3125.0 * delta * eta * jnp.log(5.0 / 2.0)) / (24.0 * jnp.sqrt(66.0))
), # amp3PNimag
0.0, # amp3halfPNreal
0.0, # amp3halfPNimag
0.0, # amplog
)
# Default (zeros)
def get_default():
return (0.0,) * 16
# Select based on mode
# We use lax.switch or nested conds. Since modes are integers, we can use select/cond.
# But mode is passed as Int.
# We can use a simple python dispatch if mode is static, but here it might be traced.
# However, usually mode is static in these contexts.
# If mode is traced, we need lax.cond.
# Map mode integer to index 0..4
# 22->0, 21->1, 33->2, 44->3, 55->4
# Using lax.cond chain
res = jax.lax.cond(
mode == 22,
get_22,
lambda: jax.lax.cond(
mode == 21,
get_21,
lambda: jax.lax.cond(
mode == 33,
get_33,
lambda: jax.lax.cond(
mode == 44,
get_44,
lambda: jax.lax.cond(mode == 55, get_55, get_default),
),
),
),
)
(
ampN,
amp0halfPNreal,
amp0halfPNimag,
amp1PNreal,
amp1PNimag,
amp1halfPNreal,
amp1halfPNimag,
amp2PNreal,
amp2PNimag,
amp2halfPNreal,
amp2halfPNimag,
amp3PNreal,
amp3PNimag,
amp3halfPNreal,
amp3halfPNimag,
amplog,
) = res
pn_real = jnp.array(
[
ampN,
amp0halfPNreal,
amp1PNreal,
amp1halfPNreal,
amp2PNreal,
amp2halfPNreal,
amp3PNreal,
amp3halfPNreal,
amplog,
]
)
pn_imag = jnp.array(
[
amp0halfPNimag,
amp1PNimag,
amp1halfPNimag,
amp2PNimag,
amp2halfPNimag,
amp3PNimag,
amp3halfPNimag,
]
)
return pn_real, pn_imag, fac0
[docs]
@jax.jit
def compute_amplitude_coeffs_22(
    wf_pafams: WaveformParams,
    phase_coeffs: PhaseCoeffs,
) -> AmplitudeCoeffs:
    """
    Compute all amplitude coefficients for the 22 mode.

    The computation proceeds in three stages: (1) PN coefficients plus three
    pseudo-PN coefficients fixed by collocation values at t = -2000, -250 and
    -150; (2) ringdown coefficients from the final-state and peak/``c3``
    fits; (3) intermediate-region coefficients obtained by solving a 4x4
    matching system between ``inspiral_cut`` and ``tshift``.

    Parameters
    ----------
    wf_pafams : WaveformParams
        Waveform parameters.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the 22 mode.

    Returns
    -------
    AmplitudeCoeffs
        All amplitude coefficients for the 22 mode. ``omegaCutPNAMP`` and
        ``phiCutPNAMP`` are set to zero for this mode.
    """
    mode = 22
    eta = wf_pafams.eta
    chi1 = wf_pafams.chi1
    chi2 = wf_pafams.chi2

    pn_real, pn_imag, fac0 = _compute_pn_amplitude_coeffs(
        eta,
        wf_pafams.delta,
        chi1,
        chi2,
        wf_pafams.m1,
        wf_pafams.m2,
        mode,
    )

    # --- Inspiral: fit the pseudo-PN terms to three collocation values ---
    tinsppoints = jnp.array([-2000.0, -250.0, -150.0])
    ampInspCP = jnp.array(
        [fits.inspiral_amp_cp(eta, chi1, chi2, mode, k) for k in (1, 2, 3)]
    )
    inspC1, inspC2, inspC3 = _solve_inspiral_amplitude_system(
        tinsppoints, ampInspCP, eta, pn_real, pn_imag, phase_coeffs, fac0
    )
    pseudo_pn = jnp.array([inspC1, inspC2, inspC3])

    # --- Ringdown ---
    af = fits.final_spin_2017(eta, chi1, chi2)
    Mf = fits.final_mass_2017(eta, chi1, chi2)
    alpha1RD = 2.0 * jnp.pi * fits.fdamp_22(af) / Mf
    alpha2RD = 2.0 * jnp.pi * fits.fdamp_n2_22(af) / Mf
    # Only the ``*_prec`` values are stored in the result; here they are
    # simply the non-precessing ones.
    alpha1RD_prec = alpha1RD
    ampPeak = fits.peak_amp_22(eta, chi1, chi2)
    c3 = fits.rd_amp_c3_22(eta, chi1, chi2)
    tanhc3 = jnp.tanh(c3)
    coshc3 = jnp.cosh(c3)

    # c2 starts as half the difference of the two damping rates and is
    # replaced by -alpha1RD / (2 tanh(c3)) whenever it exceeds the magnitude
    # of that bound.
    # NOTE(review): compute_amplitude_coeffs_hm assigns -|bound| instead of
    # -bound; the two differ when tanh(c3) < 0 -- confirm which is intended.
    c2_prec = 0.5 * (alpha2RD - alpha1RD)
    limit_prec = 0.5 * alpha1RD_prec / tanhc3
    c2_prec = jnp.where(c2_prec > jnp.abs(limit_prec), -limit_prec, c2_prec)
    c1_prec = ampPeak * alpha1RD_prec * coshc3 * coshc3 / c2_prec
    # Chosen so that the ringdown ansatz equals ampPeak at t = tshift.
    c4_prec = ampPeak - c1_prec * tanhc3

    # --- Intermediate region ---
    inspiral_cut = -150.0
    tshift = fits.tshift_22(eta, chi1, chi2)
    ringdown_cut = tshift
    ampMergerCP1 = fits.intermediate_amp_cp1(eta, chi1, chi2, mode)
    tcpMerger = -25.0
    # The fifth return value (amplitude derivative at the cut) is not stored.
    mergerC1, mergerC2, mergerC3, mergerC4, _ = _solve_intermediate_amplitude_system(
        inspiral_cut,
        tcpMerger,
        tshift,
        alpha1RD,
        ampPeak,
        ampMergerCP1,
        eta,
        pn_real,
        pn_imag,
        pseudo_pn,
        phase_coeffs,
        fac0,
    )

    return AmplitudeCoeffs(
        mode=mode,
        pn_real_coeffs=pn_real,
        pn_imag_coeffs=pn_imag,
        inspC1=inspC1,
        inspC2=inspC2,
        inspC3=inspC3,
        alpha1RD=alpha1RD,
        alpha1RD_prec=alpha1RD_prec,
        c1_prec=c1_prec,
        c2_prec=c2_prec,
        c3=c3,
        c4_prec=c4_prec,
        mergerC1=mergerC1,
        mergerC2=mergerC2,
        mergerC3=mergerC3,
        mergerC4=mergerC4,
        inspiral_cut=inspiral_cut,
        ringdown_cut=ringdown_cut,
        tshift=tshift,
        fac0=fac0,
        ampPeak=ampPeak,
        omegaCutPNAMP=jnp.array(0.0),
        phiCutPNAMP=jnp.array(0.0),
    )
[docs]
@jax.jit
def compute_amplitude_coeffs_hm(
    wf_pafams: WaveformParams,
    phase_coeffs_22: PhaseCoeffs,
    mode: int | Array,
) -> AmplitudeCoeffs:
    """
    Compute all amplitude coefficients for a given higher mode.

    Follows the same three stages as the 22-mode routine (inspiral
    pseudo-PN fit, ringdown coefficients, intermediate matching system) using
    the mode-dependent fits, and additionally records the orientation of the
    complex inspiral amplitude at the inspiral cut (``phiCutPNAMP``) and minus
    its time derivative (``omegaCutPNAMP``).

    Parameters
    ----------
    wf_pafams : WaveformParams
        Waveform parameters.
    phase_coeffs_22 : PhaseCoeffs
        Phase coefficients for the 22 mode.
    mode : int | Array
        The higher mode to compute (e.g., 21, 33, 44, 55).

    Returns
    -------
    AmplitudeCoeffs
        All amplitude coefficients for the specified higher mode.
    """
    eta = wf_pafams.eta
    chi1 = wf_pafams.chi1
    chi2 = wf_pafams.chi2

    pn_real, pn_imag, fac0 = _compute_pn_amplitude_coeffs(
        eta,
        wf_pafams.delta,
        chi1,
        chi2,
        wf_pafams.m1,
        wf_pafams.m2,
        mode,
    )

    # --- Inspiral: fit the pseudo-PN terms to three collocation values ---
    tinsppoints = jnp.array([-2000.0, -250.0, -150.0])
    ampInspCP = jnp.array(
        [fits.inspiral_amp_cp(eta, chi1, chi2, mode, k) for k in (1, 2, 3)]
    )
    inspC1, inspC2, inspC3 = _solve_inspiral_amplitude_system(
        tinsppoints, ampInspCP, eta, pn_real, pn_imag, phase_coeffs_22, fac0
    )
    pseudo_pn = jnp.array([inspC1, inspC2, inspC3])

    # --- Ringdown ---
    af = fits.final_spin_2017(eta, chi1, chi2)
    Mf = fits.final_mass_2017(eta, chi1, chi2)
    alpha1RD = 2.0 * jnp.pi * fits.fdamp(af, mode) / Mf
    alpha2RD = 2.0 * jnp.pi * fits.fdamp_n2(af, mode) / Mf
    # Only the ``*_prec`` values are stored in the result; here they are
    # simply the non-precessing ones.
    alpha1RD_prec = alpha1RD
    ampPeak = fits.peak_amp(eta, chi1, chi2, mode)
    c3 = fits.rd_amp_c3(eta, chi1, chi2, mode)
    tanhc3 = jnp.tanh(c3)
    coshc3 = jnp.cosh(c3)

    # c2 starts as half the difference of the two damping rates and is
    # replaced by -|alpha1RD / (2 tanh(c3))| whenever it exceeds that bound.
    # NOTE(review): compute_amplitude_coeffs_22 assigns the *signed*
    # -alpha1RD / (2 tanh(c3)) instead; the two differ when tanh(c3) < 0 --
    # confirm which is intended.
    c2_prec = 0.5 * (alpha2RD - alpha1RD)
    limit_prec = jnp.abs(0.5 * alpha1RD_prec / tanhc3)
    c2_prec = jnp.where(c2_prec > limit_prec, -limit_prec, c2_prec)
    c1_prec = ampPeak * alpha1RD_prec * coshc3 * coshc3 / c2_prec
    # Chosen so that the ringdown ansatz equals ampPeak at t = tshift.
    c4_prec = ampPeak - c1_prec * tanhc3

    # --- Intermediate region ---
    inspiral_cut = -150.0
    # NOTE(review): the original author left an open question here -- the
    # reference implementation (phenomxpy) is said to take tshift from the
    # 22-mode fit, while this call is mode dependent. Confirm.
    tshift = fits.tshift(eta, chi1, chi2, mode)
    ringdown_cut = tshift
    ampMergerCP1 = fits.intermediate_amp_cp1(eta, chi1, chi2, mode)
    tcpMerger = -25.0
    # The fifth return value (amplitude derivative at the cut) is not stored.
    mergerC1, mergerC2, mergerC3, mergerC4, _ = _solve_intermediate_amplitude_system(
        inspiral_cut,
        tcpMerger,
        tshift,
        alpha1RD,
        ampPeak,
        ampMergerCP1,
        eta,
        pn_real,
        pn_imag,
        pseudo_pn,
        phase_coeffs_22,
        fac0,
    )

    # --- Orientation of the complex inspiral amplitude at the cut ---
    omega_cut = imr_omega(inspiral_cut, eta, phase_coeffs_22)
    x_cut = jnp.power(omega_cut * 0.5, 2.0 / 3.0)
    amp_cut = _inspiral_ansatz_amplitude(x_cut, fac0, pn_real, pn_imag, pseudo_pn)
    phiCutPNAMP = jnp.arctan2(jnp.imag(amp_cut), jnp.real(amp_cut))
    # Shift by pi when the real part is negative (the reference code checks
    # the sign of the real part with copysign).
    phiCutPNAMP = jnp.where(
        jnp.real(amp_cut) < 0, phiCutPNAMP + jnp.pi, phiCutPNAMP
    )

    omegaCutPNAMP = -jnp.real(
        _der_complex_amp_orientation(
            inspiral_cut,
            eta,
            pn_real,
            pn_imag,
            pseudo_pn,
            phase_coeffs_22,
            fac0,
            return_phase=True,
        )
    )

    return AmplitudeCoeffs(
        mode=mode,
        pn_real_coeffs=pn_real,
        pn_imag_coeffs=pn_imag,
        inspC1=inspC1,
        inspC2=inspC2,
        inspC3=inspC3,
        alpha1RD=alpha1RD,
        alpha1RD_prec=alpha1RD_prec,
        c1_prec=c1_prec,
        c2_prec=c2_prec,
        c3=c3,
        c4_prec=c4_prec,
        mergerC1=mergerC1,
        mergerC2=mergerC2,
        mergerC3=mergerC3,
        mergerC4=mergerC4,
        inspiral_cut=inspiral_cut,
        ringdown_cut=ringdown_cut,
        tshift=tshift,
        fac0=fac0,
        ampPeak=ampPeak,
        omegaCutPNAMP=omegaCutPNAMP,
        phiCutPNAMP=phiCutPNAMP,
    )
[docs]
@jax.jit
def imr_amplitude(
    time: Array,
    eta: Array,
    amp_coeffs: AmplitudeCoeffs,
    phase_coeffs_22: PhaseCoeffs,
) -> Array:
    """
    Compute IMR amplitude at given times for a specific mode.

    Each time sample is assigned to one of three regions -- inspiral
    (``t < inspiral_cut``), intermediate (``inspiral_cut <= t < ringdown_cut``)
    or ringdown (``t >= ringdown_cut``) -- and evaluated with the matching
    ansatz.

    Parameters
    ----------
    time : Array
        Times at which to compute the amplitude (any shape).
    eta : Array
        Symmetric mass ratio.
    amp_coeffs : AmplitudeCoeffs
        Amplitude coefficients for the mode.
    phase_coeffs_22 : PhaseCoeffs
        Phase coefficients for the 22 mode.

    Returns
    -------
    Array
        Computed (complex) amplitude at the given times, with the same shape
        as ``time``.
    """

    def inspiral_region(t):
        # The inspiral ansatz is a function of x, obtained from the
        # 22-mode frequency at this time.
        omega_22 = imr_omega(t, eta, phase_coeffs_22)
        x = jnp.power(omega_22 * 0.5, 2.0 / 3.0)
        pseudo_pn = jnp.array(
            [amp_coeffs.inspC1, amp_coeffs.inspC2, amp_coeffs.inspC3]
        )
        return _inspiral_ansatz_amplitude(
            x,
            amp_coeffs.fac0,
            amp_coeffs.pn_real_coeffs,
            amp_coeffs.pn_imag_coeffs,
            pseudo_pn,
        )

    def intermediate_region(t):
        return _intermediate_ansatz_amplitude(
            t,
            amp_coeffs.mergerC1,
            amp_coeffs.mergerC2,
            amp_coeffs.mergerC3,
            amp_coeffs.mergerC4,
            amp_coeffs.alpha1RD,
            amp_coeffs.tshift,
        )

    def ringdown_region(t):
        return _ringdown_ansatz_amplitude(
            t,
            amp_coeffs.c1_prec,
            amp_coeffs.c2_prec,
            amp_coeffs.c3,
            amp_coeffs.c4_prec,
            amp_coeffs.alpha1RD_prec,
            amp_coeffs.tshift,
        )

    def amplitude_at(t: Array) -> Array:
        # Count how many region boundaries ``t`` has passed:
        # 0 -> inspiral, 1 -> intermediate, 2 -> ringdown.
        past_inspiral = t >= amp_coeffs.inspiral_cut
        past_merger = t >= amp_coeffs.ringdown_cut
        region = past_inspiral.astype(jnp.int32) + past_merger.astype(jnp.int32)
        return jax.lax.switch(
            region,
            [inspiral_region, intermediate_region, ringdown_region],
            t,
        )

    # Flatten, evaluate sample by sample, and restore the input shape.
    times = jnp.asarray(time)
    original_shape = jnp.shape(times)
    flat_amplitudes = jax.vmap(amplitude_at)(jnp.reshape(times, (-1,)))
    return jnp.reshape(flat_amplitudes, original_shape)
[docs]
def imr_amplitude_dot(
    time: Array,
    eta: Array,
    amp_coeffs: AmplitudeCoeffs,
    phase_coeffs_22: PhaseCoeffs,
    return_amplitude: bool = False,
) -> Array | Tuple[Array, Array]:
    """
    Compute the IMR amplitude time derivative :math:`\\dot{A}(t)` at given times for a specific mode, using JAX automatic differentiation.

    The derivative is obtained with a single forward-mode pass
    (``jax.jvp``) through ``imr_amplitude`` using a tangent of ones, so the
    amplitude itself comes out of the same evaluation.

    Parameters
    ----------
    time : Array
        Times at which to compute the amplitude.
    eta : Array
        Symmetric mass ratio.
    amp_coeffs : AmplitudeCoeffs
        Amplitude coefficients for the mode.
    phase_coeffs_22 : PhaseCoeffs
        Phase coefficients for the 22 mode.
    return_amplitude : bool, default False
        Whether to return the amplitude as well.

    Returns
    -------
    Array | Tuple[Array, Array]
        Computed amplitude time derivative at the given times. If `return_amplitude` is True, the function returns the tuple ``(amplitude, derivative)`` instead.
    """

    def amplitude_of_time(t):
        return imr_amplitude(t, eta, amp_coeffs, phase_coeffs_22)

    # A unit tangent per sample yields dA/dt sample by sample, because each
    # output element depends only on its own input time.
    unit_tangent = jnp.ones_like(time)
    amplitude, amplitude_dot = jax.jvp(amplitude_of_time, (time,), (unit_tangent,))

    if not return_amplitude:
        return amplitude_dot
    return amplitude, amplitude_dot
# =============================================================================
# Helper functions
# =============================================================================
@jax.jit
def _pn_ansatz_amplitude(
    x: Array,
    fac0: Array,
    pn_real: jnp.ndarray,
    pn_imag: jnp.ndarray,
) -> Array:
    """
    Evaluate PN ansatz amplitude at x = (omega/2)^(2/3).

    Returns ``fac0 * x * (R + i I)`` where ``R`` is the real series
    ``pn_real[0] + sum_k pn_real[k] x^(k/2)`` (k = 1..7) plus the log term
    ``pn_real[8] * log(16 x) * x^3``, and ``I`` is
    ``sum_k pn_imag[k-1] x^(k/2)`` (k = 1..7).
    """
    sqrt_x = jnp.sqrt(x)
    x_sq = x * x
    x_cu = x_sq * x
    # x^(1/2), x, x^(3/2), x^2, x^(5/2), x^3, x^(7/2)
    half_powers = (
        sqrt_x,
        x,
        x * sqrt_x,
        x_sq,
        x_sq * sqrt_x,
        x_cu,
        x_cu * sqrt_x,
    )

    # Real series: leading term, half-integer powers, then the log term.
    real_part = pn_real[0]
    for order, power in enumerate(half_powers, start=1):
        real_part = real_part + pn_real[order] * power
    real_part = real_part + pn_real[8] * jnp.log(16.0 * x) * x_cu

    # Imaginary series starts at x^(1/2).
    imag_part = pn_imag[0] * half_powers[0]
    for order in range(1, 7):
        imag_part = imag_part + pn_imag[order] * half_powers[order]

    return fac0 * x * (real_part + 1j * imag_part)
@jax.jit
def _inspiral_ansatz_amplitude(
    x: Array,
    fac0: Array,
    pn_real: jnp.ndarray,
    pn_imag: jnp.ndarray,
    pseudo_pn: jnp.ndarray,
) -> Array:
    """
    Evaluate inspiral ansatz amplitude (PN + pseudo-PN).

    The result is the PN ansatz plus
    ``fac0 * x * (pseudo_pn[0] x^4 + pseudo_pn[1] x^(9/2) + pseudo_pn[2] x^5)``;
    the pseudo-PN correction is purely real.
    """
    # Higher-order powers used by the pseudo-PN correction.
    x_sq = x * x
    x_4 = x_sq * x_sq
    x_4half = x_4 * jnp.sqrt(x)
    x_5 = x_4 * x
    correction = pseudo_pn[0] * x_4 + pseudo_pn[1] * x_4half + pseudo_pn[2] * x_5

    pn_part = _pn_ansatz_amplitude(x, fac0, pn_real, pn_imag)
    return pn_part + fac0 * x * correction
@jax.jit
def _intermediate_ansatz_amplitude(
    time: Array,
    c1: Array,
    c2: Array,
    c3: Array,
    c4: Array,
    alpha: Array,
    tshift: Array,
) -> Array:
    """
    Evaluate intermediate amplitude ansatz.

    With ``dt = time - tshift`` the value is
    ``c1 + c2 sech(alpha dt) + c3 sech(2 alpha dt)^(1/7) + c4 dt^2``,
    returned as a complex number with zero imaginary part.
    """
    elapsed = time - tshift
    arg = alpha * elapsed
    sech_single = 1.0 / jnp.cosh(arg)
    sech_double = 1.0 / jnp.cosh(2.0 * arg)
    root_term = jnp.power(sech_double, 1.0 / 7.0)
    # "+ 0.0j" promotes the result to complex so that all three region
    # ansatze share one dtype.
    return c1 + c2 * sech_single + c3 * root_term + c4 * elapsed * elapsed + 0.0j
@jax.jit
def _ringdown_ansatz_amplitude(
    time: Array,
    c1: Array,
    c2: Array,
    c3: Array,
    c4: Array,
    alpha: Array,
    tshift: Array,
) -> Array:
    """
    Evaluate ringdown amplitude ansatz.

    With ``dt = time - tshift`` the value is
    ``exp(-alpha dt) * (c1 tanh(c2 dt + c3) + c4)``, returned as a complex
    number with zero imaginary part.
    """
    elapsed = time - tshift
    decay = jnp.exp(-alpha * elapsed)
    envelope = c1 * jnp.tanh(c2 * elapsed + c3) + c4
    # "+ 0.0j" promotes the result to complex so that all three region
    # ansatze share one dtype.
    return decay * envelope + 0.0j
def _solve_inspiral_amplitude_system(
    times: jnp.ndarray,
    amp_vals: jnp.ndarray,
    eta: float | Array,
    pn_real: jnp.ndarray,
    pn_imag: jnp.ndarray,
    phase_coeffs: PhaseCoeffs,
    fac0: float | Array,
) -> tuple:
    """
    Solve for inspiral pseudo-PN coefficients.

    At each of the three collocation times the inspiral ansatz is required to
    reproduce the target amplitude in ``amp_vals``. After subtracting the
    real part of the pure-PN amplitude and dividing by ``fac0 * x``, this
    leaves a 3x3 linear system in the coefficients of x^4, x^(9/2) and x^5.

    Returns
    -------
    tuple
        ``(inspC1, inspC2, inspC3)``.
    """
    # x = (omega / 2)^(2/3) at the collocation times.
    omega_cp = imr_omega(times, eta, phase_coeffs)
    x_cp = jnp.power(0.5 * omega_cp, 2.0 / 3.0)
    sqrt_x_cp = jnp.sqrt(x_cp)
    x4_cp = x_cp * x_cp * x_cp * x_cp

    # Real part of the ansatz with the pseudo-PN terms switched off.
    no_pseudo_pn = jnp.zeros(3)

    def pn_only_real(x):
        return jnp.real(
            _inspiral_ansatz_amplitude(x, fac0, pn_real, pn_imag, no_pseudo_pn)
        )

    pn_baseline = jax.vmap(pn_only_real)(x_cp)
    rhs = (1.0 / fac0 / x_cp) * (amp_vals - pn_baseline)

    # Row i holds [x_i^4, x_i^(9/2), x_i^5].
    def collocation_row(i):
        return jnp.array([x4_cp[i], x4_cp[i] * sqrt_x_cp[i], x4_cp[i] * x_cp[i]])

    system = jnp.array([collocation_row(i) for i in range(3)])

    coeffs = solve_3x3_explicit(system, rhs)
    return coeffs[0], coeffs[1], coeffs[2]
def _der_complex_amp_orientation(
    time: float | Array,
    eta: float | Array,
    pn_real: jnp.ndarray,
    pn_imag: jnp.ndarray,
    pseudo_pn: jnp.ndarray,
    phase_coeffs: PhaseCoeffs,
    fac0: float | Array,
    return_phase: bool = False,
) -> Array:
    """
    Compute derivative of complex inspiral amplitude.

    Writing the inspiral amplitude as ``fac0 * x * (R(x) + i I(x))`` with
    ``x = (omega / 2)^(2/3)``, the time derivative is assembled by the chain
    rule ``d/dt = d/dx * dx/domega * domega/dt``.

    Parameters
    ----------
    time : float | Array
        Scalar time at which to evaluate the derivative.
    eta : float | Array
        Symmetric mass ratio.
    pn_real, pn_imag, pseudo_pn : jnp.ndarray
        Coefficients of the inspiral amplitude ansatz.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the 22 mode (used for omega and domega/dt).
    fac0 : float | Array
        Amplitude prefactor; only used when ``return_phase`` is False.
    return_phase : bool, default False
        If True, return the time derivative of ``atan2(I, R)``; otherwise
        return the time derivative of ``fac0 * x * |R + i I|``.

    Returns
    -------
    Array
        The requested time derivative.
    """
    omega = imr_omega(time, eta, phase_coeffs)
    x = jnp.power(omega * 0.5, 2.0 / 3.0)

    # Half-integer powers of x up to x^5.
    sqrt_x = jnp.sqrt(x)
    x_1h = x * sqrt_x
    x_2 = x * x
    x_2h = x_2 * sqrt_x
    x_3 = x_2 * x
    x_3h = x_3 * sqrt_x
    x_4 = x_2 * x_2
    x_4h = x_4 * sqrt_x
    x_5 = x_3 * x_2
    log16x = jnp.log(16.0 * x)

    # R(x): PN series, log term and pseudo-PN terms.
    re_amp = (
        pn_real[0]
        + pn_real[1] * sqrt_x
        + pn_real[2] * x
        + pn_real[3] * x_1h
        + pn_real[4] * x_2
        + pn_real[5] * x_2h
        + pn_real[6] * x_3
        + pn_real[7] * x_3h
        + pn_real[8] * log16x * x_3
        + pseudo_pn[0] * x_4
        + pseudo_pn[1] * x_4h
        + pseudo_pn[2] * x_5
    )
    # I(x)
    im_amp = (
        pn_imag[0] * sqrt_x
        + pn_imag[1] * x
        + pn_imag[2] * x_1h
        + pn_imag[3] * x_2
        + pn_imag[4] * x_2h
        + pn_imag[5] * x_3
        + pn_imag[6] * x_3h
    )

    # dR/dx; note d/dx[log(16 x) x^3] = x^2 (1 + 3 log(16 x)).
    d_re_amp = (
        0.5 * pn_real[1] / sqrt_x
        + pn_real[2]
        + 1.5 * pn_real[3] * sqrt_x
        + 2.0 * pn_real[4] * x
        + 2.5 * pn_real[5] * x_1h
        + 3.0 * pn_real[6] * x_2
        + 3.5 * pn_real[7] * x_2h
        + pn_real[8] * x_2 * (1.0 + 3.0 * log16x)
        + 4.0 * pseudo_pn[0] * x_3
        + 4.5 * pseudo_pn[1] * x_3h
        + 5.0 * pseudo_pn[2] * x_4
    )
    # dI/dx
    d_im_amp = (
        0.5 * pn_imag[0] / sqrt_x
        + pn_imag[1]
        + 1.5 * pn_imag[2] * sqrt_x
        + 2.0 * pn_imag[3] * x
        + 2.5 * pn_imag[4] * x_1h
        + 3.0 * pn_imag[5] * x_2
        + 3.5 * pn_imag[6] * x_2h
    )

    # dx/domega for x = (omega / 2)^(2/3).
    dx_domega = jnp.cbrt(2.0 / omega) / 3.0

    # domega/dt comes from the 22-mode frequency model: the inspiral ansatz
    # before the phase inspiral cut, the merger expression afterwards.
    def domega_inspiral(t):
        pn_part = jnp.array(
            [
                phase_coeffs.omega1PN,
                phase_coeffs.omega1halfPN,
                phase_coeffs.omega2PN,
                phase_coeffs.omega2halfPN,
                phase_coeffs.omega3PN,
                phase_coeffs.omega3halfPN,
            ]
        )
        pseudo_part = jnp.array(
            [
                phase_coeffs.omegaInspC1,
                phase_coeffs.omegaInspC2,
                phase_coeffs.omegaInspC3,
                phase_coeffs.omegaInspC4,
                phase_coeffs.omegaInspC5,
                phase_coeffs.omegaInspC6,
            ]
        )
        return _inspiral_ansatz_domega(t, eta, pn_part, pseudo_part)

    def domega_merger(t):
        # Closed-form derivative of the merger frequency expression, written
        # in terms of arcsinh(alpha1RD * t).
        scaled_t = phase_coeffs.alpha1RD * t
        asinh_t = jnp.arcsinh(scaled_t)
        return (
            -phase_coeffs.omegaRING
            / jnp.sqrt(1.0 + scaled_t**2)
            * (
                phase_coeffs.domegaPeak
                + phase_coeffs.alpha1RD
                * (
                    2.0 * phase_coeffs.omegaMergerC1 * asinh_t
                    + 3.0 * phase_coeffs.omegaMergerC2 * asinh_t * asinh_t
                    + 4.0 * phase_coeffs.omegaMergerC3 * asinh_t**3
                )
            )
        )

    domega_dt = jax.lax.cond(
        time < phase_coeffs.inspiral_cut, domega_inspiral, domega_merger, time
    )

    modulus = jnp.abs(re_amp + 1j * im_amp)

    if not return_phase:
        # d/dx [x |R + i I|] = (R (x R' + R) + I (x I' + I)) / |R + i I|
        return (
            fac0
            * (re_amp * (d_re_amp * x + re_amp) + im_amp * (d_im_amp * x + im_amp))
            / modulus
            * dx_domega
            * domega_dt
        )

    # d/dx atan2(I, R) = (I' R - R' I) / |R + i I|^2
    return (
        (d_im_amp * re_amp - d_re_amp * im_amp)
        / (modulus * modulus)
        * dx_domega
        * domega_dt
    )
def _solve_intermediate_amplitude_system(
    tCut: float | Array,
    tcpMerger: float | Array,
    tshift: float | Array,
    alpha1RD: float | Array,
    ampPeak: float | Array,
    ampMergerCP1: float | Array,
    eta: float | Array,
    pn_real: jnp.ndarray,
    pn_imag: jnp.ndarray,
    pseudo_pn: jnp.ndarray,
    phase_coeffs: PhaseCoeffs,
    fac0: Array,
) -> tuple:
    """
    Solve for intermediate amplitude coefficients.

    The intermediate ansatz is, with ``dt = t - tshift``,
    ``c1 + c2 sech(alpha1RD dt) + c3 sech(2 alpha1RD dt)^(1/7) + c4 dt^2``.
    Its four coefficients are fixed by four conditions:

    0. value at ``tCut`` equals the signed modulus of the inspiral amplitude;
    1. value at ``tcpMerger`` equals ``ampMergerCP1``;
    2. value at ``tshift`` (the peak) equals ``ampPeak``;
    3. time derivative at ``tCut`` equals that of the inspiral amplitude.

    Parameters
    ----------
    tCut : float | Array
        Time of the inspiral/intermediate boundary.
    tcpMerger : float | Array
        Time of the intermediate collocation point.
    tshift : float | Array
        Peak time of the mode amplitude.
    alpha1RD : float | Array
        Rate used inside the sech terms of the ansatz.
    ampPeak : float | Array
        Amplitude at ``tshift``.
    ampMergerCP1 : float | Array
        Amplitude at ``tcpMerger``.
    eta : float | Array
        Symmetric mass ratio.
    pn_real, pn_imag, pseudo_pn : jnp.ndarray
        Coefficients of the inspiral amplitude ansatz.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the 22 mode.
    fac0 : Array
        Amplitude prefactor.

    Returns
    -------
    tuple
        ``(mergerC1, mergerC2, mergerC3, mergerC4, dampMECO)`` where
        ``dampMECO`` is the signed inspiral amplitude derivative at ``tCut``.
    """
    # Inspiral amplitude at the cut, via x = (omega / 2)^(2/3).
    omega_cut = imr_omega(tCut, eta, phase_coeffs)
    x_cut = jnp.power(omega_cut * 0.5, 2.0 / 3.0)
    ampinsp_cplx = _inspiral_ansatz_amplitude(x_cut, fac0, pn_real, pn_imag, pseudo_pn)
    # Use the modulus, carrying the sign of the real part.
    ampinsp = jnp.copysign(jnp.abs(ampinsp_cplx), jnp.real(ampinsp_cplx))

    # Row 0: match amplitude at tCut.
    phi = alpha1RD * (tCut - tshift)
    phi2 = 2.0 * phi
    sech1 = 1.0 / jnp.cosh(phi)
    sech2 = 1.0 / jnp.cosh(phi2)
    row_0 = jnp.array([1.0, sech1, jnp.power(sech2, 1.0 / 7.0), (tCut - tshift) ** 2])

    # Row 1: match amplitude at tcpMerger.
    phib = alpha1RD * (tcpMerger - tshift)
    sech1b = 1.0 / jnp.cosh(phib)
    sech2b = 1.0 / jnp.cosh(2.0 * phib)
    row_1 = jnp.array(
        [1.0, sech1b, jnp.power(sech2b, 1.0 / 7.0), (tcpMerger - tshift) ** 2]
    )

    # Row 2: match amplitude at the peak (t = tshift). There dt = 0, so both
    # sech terms equal 1 and the quadratic term (t - tshift)^2 vanishes; its
    # weight must be 0, otherwise the ansatz evaluates to ampPeak - c4 at the
    # peak and is discontinuous with the ringdown branch.
    row_2 = jnp.array([1.0, 1.0, 1.0, 0.0])

    # Row 3: match the time derivative at tCut.
    dampMECO = jnp.copysign(1.0, jnp.real(ampinsp_cplx)) * _der_complex_amp_orientation(
        tCut, eta, pn_real, pn_imag, pseudo_pn, phase_coeffs, fac0, return_phase=False
    )
    tanh_phi = jnp.tanh(phi)
    sinh_phi2 = jnp.sinh(phi2)
    # d/dt sech(alpha dt) = -alpha sech tanh
    aux1 = -alpha1RD * sech1 * tanh_phi
    # d/dt sech(2 alpha dt)^(1/7) = -(2/7) alpha sinh(2 alpha dt) sech^(8/7)
    aux2 = (-2.0 / 7.0) * alpha1RD * sinh_phi2 * jnp.power(sech2, 8.0 / 7.0)
    # d/dt dt^2 = 2 dt
    aux3 = 2.0 * (tCut - tshift)
    row_3 = jnp.array([0.0, aux1, aux2, aux3])

    matrix = jnp.array([row_0, row_1, row_2, row_3])
    B = jnp.array([ampinsp, ampMergerCP1, ampPeak, dampMECO])
    solution = jnp.linalg.solve(matrix, B)
    return solution[0], solution[1], solution[2], solution[3], dampMECO