# Source code for phentax.core.pn_coeffs

# Copyright (C) 2025 Alessandro Santini
# SPDX-License-Identifier: MIT

# Credits for the original implementations: Cecilio García Quirós

"""
Post-Newtonian (PN) coefficients for IMRPhenomT(HM).

Contains the TaylorT3 omega PN coefficients and amplitude PN coefficients
for all supported modes. These are mode-dependent and spin-dependent.

References:
- TaylorT3 omega: Eq. A5 in arXiv:2012.11923 (IMRPhenomT paper)
- Amplitude PN: Eq. 9.4 Blanchet 2008, Eq. 43 Faye 2012, Eq. 4.17 Arun, Eq. 4.27 Buonanno
"""


import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array


class OmegaPNCoeffs(eqx.Module):
    """
    Container for the TaylorT3 omega PN coefficients, 1PN through 3.5PN.

    Every field holds either a Python float or an array. The declaration
    order below is the order of the generated constructor's arguments.
    """

    # Integer and half-integer orders, lowest first.
    omega1PN: float | Array  # 1PN
    omega1halfPN: float | Array  # 1.5PN
    omega2PN: float | Array  # 2PN
    omega2halfPN: float | Array  # 2.5PN
    omega3PN: float | Array  # 3PN
    omega3halfPN: float | Array  # 3.5PN
class AmpPNCoeffs(eqx.Module):
    """
    Container for the amplitude PN coefficients of a single mode.

    The coefficients are split into real and imaginary parts per PN order
    (0.5PN through 3.5PN), plus the leading-order term, the coefficient of
    the log term and an overall prefactor. Every field holds either a Python
    float or an array. The declaration order below is the order of the
    generated constructor's arguments.
    """

    # --- Real part ---------------------------------------------------------
    ampN: float | Array  # leading order
    amp0halfPNreal: float | Array
    amp1PNreal: float | Array
    amp1halfPNreal: float | Array
    amp2PNreal: float | Array
    amp2halfPNreal: float | Array
    amp3PNreal: float | Array
    amp3halfPNreal: float | Array
    amplog: float | Array  # coefficient of the log term

    # --- Imaginary part ----------------------------------------------------
    amp0halfPNimag: float | Array
    amp1PNimag: float | Array
    amp1halfPNimag: float | Array
    amp2PNimag: float | Array
    amp2halfPNimag: float | Array
    amp3PNimag: float | Array
    amp3halfPNimag: float | Array

    # --- Overall prefactor -------------------------------------------------
    fac0: float | Array
@jax.jit
def compute_omega_pn_coeffs(
    eta: float | Array,
    chi1: float | Array,
    chi2: float | Array,
    delta: float | Array,
    m1: float | Array,
    m2: float | Array,
) -> OmegaPNCoeffs:
    """
    Evaluate the TaylorT3 omega PN coefficients from 1PN to 3.5PN.

    Implements Eq. A5 of arXiv:2012.11923. The 3PN coefficient additionally
    contains the term ``eta^3 * 235925/1769472``, which the paper omits.

    Parameters
    ----------
    eta : Array
        Symmetric mass ratio.
    chi1, chi2 : Array
        Dimensionless spin z-components.
    delta : Array
        Mass difference ratio (m1-m2)/M.
    m1, m2 : Array
        Component masses as fractions of total mass (m1+m2=1).

    Returns
    -------
    OmegaPNCoeffs
        PN coefficients for TaylorT3 omega.
    """
    pi = jnp.pi

    # Powers of the inputs that appear at several orders.
    eta_sq = eta * eta
    eta_cu = eta * eta_sq
    chi1_sq = chi1 * chi1
    chi2_sq = chi2 * chi2
    chi2_cu = chi2 * chi2_sq

    # Spin combinations that recur in more than one coefficient.
    spin_sum = chi1 + chi2
    neg2_mass_weighted_spin = -2 * chi1 * m1 - 2 * chi2 * m2

    # ---- 1PN ----
    omega1PN = 743 / 2688 + (11 * eta) / 32

    # ---- 1.5PN ----
    omega1halfPN = (-19 * spin_sum * eta) / 80 + (
        -113 * neg2_mass_weighted_spin - 96 * pi
    ) / 320

    # ---- 2PN ----
    spin_poly_2pn = (
        56975 + 61236 * chi1_sq - 119448 * chi1 * chi2 + 61236 * chi2_sq
    )
    omega2PN = (
        (spin_poly_2pn * eta) / 258048
        + (371 * eta_sq) / 2048
        + (1855099 - 3429216 * chi1_sq * m1 - 3429216 * chi2_sq * m2) / 14450688
    )

    # ---- 2.5PN ----
    eta_bracket_2half = (
        -2 * (chi1 * (1213 - 63 * delta) + chi2 * (1213 + 63 * delta)) + 117 * pi
    )
    omega2halfPN = (
        (-17 * spin_sum * eta_sq) / 128
        + (-146597 * neg2_mass_weighted_spin - 46374 * pi) / 129024
        + (eta * eta_bracket_2half) / 2304
    )

    # ---- 3PN ----
    eta_sq_bracket_3pn = (
        -2318475 + 18767224 * chi1_sq - 54663952 * chi1 * chi2 + 18767224 * chi2_sq
    )
    eta_bracket_3pn = (
        632550449425
        + 35200873512 * chi1_sq
        - 28527282000 * chi1 * chi2
        + 9605339856 * chi1_sq * delta
        - 1512 * chi2_sq * (-23281001 + 6352738 * delta)
        + 34172264448 * spin_sum * pi
        - 22912243200 * pi**2
    )
    omega3PN = (
        -720817631400877 / 288412611379200
        - (16928263 * chi1_sq) / 137625600
        - (16928263 * chi2_sq) / 137625600
        - (16928263 * chi1_sq * delta) / 137625600
        + (16928263 * chi2_sq * delta) / 137625600
        + (eta_sq_bracket_3pn * eta_sq) / 137625600
        # eta^3 term absent from the paper's Eq. A5.
        + (235925 * eta_cu) / 1769472
        + (107 * jnp.euler_gamma) / 280
        - (6127 * chi1 * pi) / 12800
        - (6127 * chi2 * pi) / 12800
        - (6127 * chi1 * delta * pi) / 12800
        + (6127 * chi2 * delta * pi) / 12800
        + (eta * eta_bracket_3pn) / 104044953600
        + (53 * pi**2) / 200
        + (107 * jnp.log(2)) / 280
    )

    # ---- 3.5PN ----
    eta_sq_bracket_3half = (
        507654 * chi1 * chi2_sq
        - 838782 * chi2_cu
        + chi2 * (-840149 + 507654 * chi1_sq - 870576 * delta)
        + chi1 * (-840149 - 838782 * chi1_sq + 870576 * delta)
        + 1701228 * pi
    )
    eta_bracket_3half = (
        -1134 * chi2_cu * (-206917 + 71931 * delta)
        + chi1
        * (
            -1496368361
            - 429508815 * delta
            + 1134 * chi1_sq * (206917 + 71931 * delta)
        )
        - chi2 * (1496368361 - 429508815 * delta + 437064012 * chi1_sq * m1)
        - 437064012 * chi1 * chi2_sq * m2
        - 144
        * (488825 + 923076 * chi1_sq - 1782648 * chi1 * chi2 + 923076 * chi2_sq)
        * pi
    )
    eta_free_bracket_3half = (
        -2 * chi1 * (-6579635551 + 535759434 * chi1_sq) * m1
        + 13159271102 * chi2 * m2
        - 1071518868 * chi2_cu * m2
        + (-565550067 + 930460608 * chi1_sq * m1 + 930460608 * chi2_sq * m2) * pi
    )
    omega3halfPN = (
        (-12029 * spin_sum * eta_cu) / 92160
        + (eta_sq * eta_sq_bracket_3half) / 15482880
        + (eta * eta_bracket_3half) / 185794560
        + eta_free_bracket_3half / 1300561920
    )

    return OmegaPNCoeffs(
        omega1PN=omega1PN,
        omega1halfPN=omega1halfPN,
        omega2PN=omega2PN,
        omega2halfPN=omega2halfPN,
        omega3PN=omega3PN,
        omega3halfPN=omega3halfPN,
    )
@jax.jit
def compute_amp_pn_coeffs_22(
    eta: float | Array,
    chi1: float | Array,
    chi2: float | Array,
    delta: float | Array,
    m1: float | Array,
    m2: float | Array,
) -> AmpPNCoeffs:
    """
    Evaluate the amplitude PN coefficients of the (2,2) mode.

    Sources of the individual pieces: 3PN non-spinning from Eq. 9.4 of
    Blanchet 2008, 3.5PN non-spinning from Eq. 43 of Faye 2012, 1.5PN
    spinning from Eq. 4.17 of Arun, 2PN spinning from Eq. 4.27 of Buonanno.

    Parameters
    ----------
    eta : Array
        Symmetric mass ratio.
    chi1, chi2 : Array
        Dimensionless spin z-components.
    delta : Array
        Mass difference ratio (m1-m2)/M.
    m1, m2 : Array
        Component masses as fractions of total mass (m1+m2=1).

    Returns
    -------
    AmpPNCoeffs
        Amplitude PN coefficients for the 22 mode.
    """
    pi = jnp.pi
    eta_sq = eta * eta
    eta_cu = eta * eta_sq

    # Half-sum and half-difference of the spins, and the mass-weighted spin.
    chi_sym = 0.5 * (chi1 + chi2)
    chi_asym = 0.5 * (chi1 - chi2)
    mass_weighted_spin = m1 * chi1 + m2 * chi2

    prefactor = 2.0 * eta * jnp.sqrt(16.0 * pi / 5.0)

    # Only the non-vanishing coefficients get a local name; the ones that are
    # identically zero for this mode are passed as literals in the return.
    real_1pn = -107.0 / 42.0 + (55.0 * eta) / 42.0

    real_1half = (
        (-4.0 * chi_sym) / 3.0
        - (4.0 * chi_asym * delta) / 3.0
        + (4.0 * chi_sym * eta) / 3.0
        + 2.0 * pi
    )

    real_2pn = (
        -2173.0 / 1512.0
        - (1069.0 * eta) / 216.0
        + (2047.0 * eta_sq) / 1512.0
        + mass_weighted_spin * mass_weighted_spin
    )

    real_2half = (-107.0 * pi) / 21.0 + (34.0 * eta * pi) / 21.0
    imag_2half = -24.0 * eta

    real_3pn = (
        27027409.0 / 646800.0
        - (278185.0 * eta) / 33264.0
        - (20261.0 * eta_sq) / 2772.0
        + (114635.0 * eta_cu) / 99792.0
        # Euler's constant, taken from jax.numpy as in the omega coefficients.
        - (856.0 * jnp.euler_gamma) / 105.0
        + (2.0 * pi * pi) / 3.0
        + (41.0 * eta * pi * pi) / 96.0
    )
    imag_3pn = (428.0 * pi) / 105.0

    real_3half = (
        (-2173.0 * pi) / 756.0
        - (2495.0 * eta * pi) / 378.0
        + (40.0 * eta_sq * pi) / 27.0
    )
    imag_3half = (14333.0 * eta) / 162.0 - (4066.0 * eta_sq) / 945.0

    return AmpPNCoeffs(
        fac0=prefactor,
        ampN=1.0,
        amplog=-428.0 / 105.0,
        amp0halfPNreal=0.0,
        amp0halfPNimag=0.0,
        amp1PNreal=real_1pn,
        amp1PNimag=0.0,
        amp1halfPNreal=real_1half,
        amp1halfPNimag=0.0,
        amp2PNreal=real_2pn,
        amp2PNimag=0.0,
        amp2halfPNreal=real_2half,
        amp2halfPNimag=imag_2half,
        amp3PNreal=real_3pn,
        amp3PNimag=imag_3pn,
        amp3halfPNreal=real_3half,
        amp3halfPNimag=imag_3half,
    )
@jax.jit
def compute_amp_pn_coeffs_21(
    eta: float | Array,
    chi1: float | Array,
    chi2: float | Array,
    delta: float | Array,
    m1: float | Array,
    m2: float | Array,
) -> AmpPNCoeffs:
    """
    Evaluate the amplitude PN coefficients of the (2,1) mode.

    Parameters
    ----------
    eta : Array
        Symmetric mass ratio.
    chi1, chi2 : Array
        Dimensionless spin z-components.
    delta : Array
        Mass difference ratio (m1-m2)/M.
    m1, m2 : Array
        Component masses as fractions of total mass.

    Returns
    -------
    AmpPNCoeffs
        Amplitude PN coefficients for the 21 mode.
    """
    pi = jnp.pi
    log_16 = jnp.log(16.0)
    eta_sq = eta * eta

    # Spin combinations entering this mode.
    chi_asym = 0.5 * (chi1 - chi2)
    chi_sym = 0.5 * (chi1 + chi2)
    spin_c = m1 * m1 * chi1 + m2 * m2 * chi2
    sigma_c = m2 * chi2 - m1 * chi1

    prefactor = 2.0 * eta * jnp.sqrt(16.0 * pi / 5.0)

    # Only the non-vanishing coefficients get a local name; the ones that are
    # identically zero for this mode are passed as literals in the return.
    real_0half = delta / 3.0

    real_1pn = -0.5 * chi_asym - (chi_sym * delta) / 2.0

    real_1half = (-17.0 * delta) / 84.0 + (5.0 * delta * eta) / 21.0

    real_2pn = (
        (delta * pi) / 3.0
        - (43.0 * delta * spin_c) / 21.0
        - (79.0 * sigma_c) / 42.0
        + (139.0 * eta * sigma_c) / 42.0
    )
    imag_2pn = -delta / 6.0 - (delta * log_16) / 6.0

    real_2half = (
        (-43.0 * delta) / 378.0
        - (509.0 * delta * eta) / 378.0
        + (79.0 * delta * eta_sq) / 504.0
    )

    real_3pn = (-17.0 * delta * pi) / 84.0 + (delta * eta * pi) / 14.0
    imag_3pn = (
        (17.0 * delta) / 168.0
        - (353.0 * delta * eta) / 84.0
        + (17.0 * delta * log_16) / 168.0
        - (delta * eta * jnp.log(4096.0)) / 84.0
    )

    return AmpPNCoeffs(
        fac0=prefactor,
        ampN=0.0,
        amplog=0.0,
        amp0halfPNreal=real_0half,
        amp0halfPNimag=0.0,
        amp1PNreal=real_1pn,
        amp1PNimag=0.0,
        amp1halfPNreal=real_1half,
        amp1halfPNimag=0.0,
        amp2PNreal=real_2pn,
        amp2PNimag=imag_2pn,
        amp2halfPNreal=real_2half,
        amp2halfPNimag=0.0,
        amp3PNreal=real_3pn,
        amp3PNimag=imag_3pn,
        amp3halfPNreal=0.0,
        amp3halfPNimag=0.0,
    )
@jax.jit
def compute_amp_pn_coeffs_33(
    eta: float | Array,
    chi1: float | Array,
    chi2: float | Array,
    delta: float | Array,
    m1: float | Array,
    m2: float | Array,
) -> AmpPNCoeffs:
    """
    Evaluate the amplitude PN coefficients of the (3,3) mode.

    Parameters
    ----------
    eta : Array
        Symmetric mass ratio.
    chi1, chi2 : Array
        Dimensionless spin z-components.
    delta : Array
        Mass difference ratio (m1-m2)/M.
    m1, m2 : Array
        Component masses as fractions of total mass.

    Returns
    -------
    AmpPNCoeffs
        Amplitude PN coefficients for the 33 mode.
    """
    pi = jnp.pi

    # Numerical square-root factors appearing in this mode's coefficients.
    root_15_14 = jnp.sqrt(15.0 / 14.0)
    root_105_2 = jnp.sqrt(105.0 / 2.0)
    root_21_10 = jnp.sqrt(21.0 / 10.0)
    root_3_70 = jnp.sqrt(3.0 / 70.0)

    # Spin combinations entering the 2PN real part.
    spin_c = m1 * m1 * chi1 + m2 * m2 * chi2
    sigma_c = m2 * chi2 - m1 * chi1

    eta_sq = eta * eta
    prefactor = 2.0 * eta * jnp.sqrt(16.0 * pi / 5.0)

    # Only the non-vanishing coefficients get a local name; the ones that are
    # identically zero for this mode are passed as literals in the return.
    real_0half = (3.0 * root_15_14 * delta) / 4.0

    real_1half = -3.0 * root_15_14 * delta + (3.0 * root_15_14 * delta * eta) / 2.0

    real_2pn = (
        (9.0 * root_15_14 * delta * pi) / 4.0
        - (3.0 * root_105_2 * delta * spin_c) / 8.0
        - (9.0 * root_15_14 * sigma_c) / 8.0
        + (27.0 * root_15_14 * eta * sigma_c) / 8.0
    )
    imag_2pn = (-9.0 * root_21_10 * delta) / 4.0 + (
        9.0 * root_15_14 * delta * jnp.log(3.0 / 2.0)
    ) / 2.0

    real_2half = (
        (369.0 * root_3_70 * delta) / 88.0
        - (919.0 * root_3_70 * delta * eta) / 22.0
        + (887.0 * root_3_70 * delta * eta_sq) / 88.0
    )

    return AmpPNCoeffs(
        fac0=prefactor,
        ampN=0.0,
        amplog=0.0,
        amp0halfPNreal=real_0half,
        amp0halfPNimag=0.0,
        amp1PNreal=0.0,
        amp1PNimag=0.0,
        amp1halfPNreal=real_1half,
        amp1halfPNimag=0.0,
        amp2PNreal=real_2pn,
        amp2PNimag=imag_2pn,
        amp2halfPNreal=real_2half,
        amp2halfPNimag=0.0,
        amp3PNreal=0.0,
        amp3PNimag=0.0,
        amp3halfPNreal=0.0,
        amp3halfPNimag=0.0,
    )
@jax.jit
def compute_amp_pn_coeffs_44(
    eta: float | Array,
    chi1: float | Array,
    chi2: float | Array,
    delta: float | Array,
    m1: float | Array,
    m2: float | Array,
) -> AmpPNCoeffs:
    """
    Evaluate the amplitude PN coefficients of the (4,4) mode.

    Only ``eta`` enters the expressions below; the remaining arguments are
    accepted so that all per-mode functions share one signature.

    Parameters
    ----------
    eta : Array
        Symmetric mass ratio.
    chi1, chi2 : Array
        Dimensionless spin z-components (unused here).
    delta : Array
        Mass difference ratio (m1-m2)/M (unused here).
    m1, m2 : Array
        Component masses as fractions of total mass (unused here).

    Returns
    -------
    AmpPNCoeffs
        Amplitude PN coefficients for the 44 mode.
    """
    pi = jnp.pi
    log_2 = jnp.log(2.0)

    # Numerical square-root factors appearing in this mode's coefficients.
    root_5_7 = jnp.sqrt(5.0 / 7.0)
    root_35 = jnp.sqrt(35.0)
    root_7_5 = jnp.sqrt(7.0 / 5.0)

    eta_sq = eta * eta
    eta_cu = eta * eta_sq
    prefactor = 2.0 * eta * jnp.sqrt(16.0 * pi / 5.0)

    # Only the non-vanishing coefficients get a local name; the ones that are
    # identically zero for this mode are passed as literals in the return.
    real_1pn = (8.0 * root_5_7) / 9.0 - (8.0 * root_5_7 * eta) / 3.0

    real_2pn = (
        -2372.0 / (99.0 * root_35)
        + (5092.0 * root_5_7 * eta) / 297.0
        - (100.0 * root_35 * eta_sq) / 99.0
    )

    real_2half = (32.0 * root_5_7 * pi) / 9.0 - (32.0 * root_5_7 * eta * pi) / 3.0
    imag_2half = (
        (-16.0 * root_7_5) / 3.0
        + (1193.0 * eta) / (9.0 * root_35)
        + (64.0 * root_5_7 * log_2) / 9.0
        - (64.0 * root_5_7 * eta * log_2) / 3.0
    )

    real_3pn = (
        1068671.0 / (45045.0 * root_35)
        - (1088119.0 * eta) / (6435.0 * root_35)
        + (293758.0 * eta_sq) / (1053.0 * root_35)
        - (226097.0 * eta_cu) / (3861.0 * root_35)
    )

    return AmpPNCoeffs(
        fac0=prefactor,
        ampN=0.0,
        amplog=0.0,
        amp0halfPNreal=0.0,
        amp0halfPNimag=0.0,
        amp1PNreal=real_1pn,
        amp1PNimag=0.0,
        amp1halfPNreal=0.0,
        amp1halfPNimag=0.0,
        amp2PNreal=real_2pn,
        amp2PNimag=0.0,
        amp2halfPNreal=real_2half,
        amp2halfPNimag=imag_2half,
        amp3PNreal=real_3pn,
        amp3PNimag=0.0,
        amp3halfPNreal=0.0,
        amp3halfPNimag=0.0,
    )
@jax.jit
def compute_amp_pn_coeffs_55(
    eta: float | Array,
    chi1: float | Array,
    chi2: float | Array,
    delta: float | Array,
    m1: float | Array,
    m2: float | Array,
) -> AmpPNCoeffs:
    """
    Evaluate the amplitude PN coefficients of the (5,5) mode.

    Only ``eta`` and ``delta`` enter the expressions below; the remaining
    arguments are accepted so that all per-mode functions share one signature.

    Parameters
    ----------
    eta : Array
        Symmetric mass ratio.
    chi1, chi2 : Array
        Dimensionless spin z-components (unused here).
    delta : Array
        Mass difference ratio (m1-m2)/M.
    m1, m2 : Array
        Component masses as fractions of total mass (unused here).

    Returns
    -------
    AmpPNCoeffs
        Amplitude PN coefficients for the 55 mode.
    """
    pi = jnp.pi
    log_5_2 = jnp.log(5.0 / 2.0)

    # Numerical square-root factors appearing in this mode's coefficients.
    root_66 = jnp.sqrt(66.0)
    root_2_33 = jnp.sqrt(2.0 / 33.0)

    eta_sq = eta * eta
    prefactor = 2.0 * eta * jnp.sqrt(16.0 * pi / 5.0)

    # Only the non-vanishing coefficients get a local name; the ones that are
    # identically zero for this mode are passed as literals in the return.
    real_1half = (625.0 * delta) / (96.0 * root_66) - (625.0 * delta * eta) / (
        48.0 * root_66
    )

    real_2half = (
        (-164375.0 * delta) / (3744.0 * root_66)
        + (26875.0 * delta * eta) / (234.0 * root_66)
        - (2500.0 * root_2_33 * delta * eta_sq) / 117.0
    )

    real_3pn = (3125.0 * delta * pi) / (96.0 * root_66) - (
        3125.0 * delta * eta * pi
    ) / (48.0 * root_66)
    imag_3pn = (
        (-113125.0 * delta) / (1344.0 * root_66)
        + (17639.0 * delta * eta) / (80.0 * root_66)
        + (3125.0 * delta * log_5_2) / (48.0 * root_66)
        - (3125.0 * delta * eta * log_5_2) / (24.0 * root_66)
    )

    return AmpPNCoeffs(
        fac0=prefactor,
        ampN=0.0,
        amplog=0.0,
        amp0halfPNreal=0.0,
        amp0halfPNimag=0.0,
        amp1PNreal=0.0,
        amp1PNimag=0.0,
        amp1halfPNreal=real_1half,
        amp1halfPNimag=0.0,
        amp2PNreal=0.0,
        amp2PNimag=0.0,
        amp2halfPNreal=real_2half,
        amp2halfPNimag=0.0,
        amp3PNreal=real_3pn,
        amp3PNimag=imag_3pn,
        amp3halfPNreal=0.0,
        amp3halfPNimag=0.0,
    )
[docs] def compute_amp_pn_coeffs( eta: float | Array, chi1: float | Array, chi2: float | Array, delta: float | Array, m1: float | Array, m2: float | Array, mode: int, ) -> AmpPNCoeffs: """ Compute amplitude PN coefficients for a given mode. Parameters ---------- eta : Array Symmetric mass ratio. chi1, chi2 : Array Dimensionless spin z-components. delta : Array Mass difference ratio (m1-m2)/M. m1, m2 : Array Component masses as fractions of total mass. mode : int Mode key (22, 21, 33, 44, 55). Returns ------- AmpPNCoeffs Amplitude PN coefficients for the mode. """ # Use lax.switch for JIT compatibility def mode_22(): return compute_amp_pn_coeffs_22(eta, chi1, chi2, delta, m1, m2) def mode_21(): return compute_amp_pn_coeffs_21(eta, chi1, chi2, delta, m1, m2) def mode_33(): return compute_amp_pn_coeffs_33(eta, chi1, chi2, delta, m1, m2) def mode_44(): return compute_amp_pn_coeffs_44(eta, chi1, chi2, delta, m1, m2) def mode_55(): return compute_amp_pn_coeffs_55(eta, chi1, chi2, delta, m1, m2) # Mode index mapping: 22->0, 21->1, 33->2, 44->3, 55->4 mode_idx = jax.lax.cond( mode == 22, lambda: 0, lambda: jax.lax.cond( mode == 21, lambda: 1, lambda: jax.lax.cond( mode == 33, lambda: 2, lambda: jax.lax.cond(mode == 44, lambda: 3, lambda: 4), # 55 ), ), ) return jax.lax.switch(mode_idx, [mode_22, mode_21, mode_33, mode_44, mode_55])
# Lookup table of 5**(k/8) for k = 0..7, used in the phase computation.
POWERS_OF_5 = jnp.array([5.0 ** (k / 8.0) for k in range(8)])