# Source code for phentax.core.phase

# Copyright (C) 2025 Alessandro Santini
# SPDX-License-Identifier: MIT

# Credits for the original implementations: Cecilio García Quirós

"""
Phase and omega coefficient computation for IMRPhenomTHM.

This module implements the pPhase class functionality from phenomxpy,
computing all the coefficients needed for the IMR omega and phase ansatze.
"""

from typing import Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
from jaxtyping import Array

from ..utils.utility import solve_3x3_explicit
from . import collocation, fits, pn_coeffs
from .internals import WaveformParams, compute_wf_length_params

jax.config.update("jax_enable_x64", True)


class PhaseCoeffs(eqx.Module):
    """
    All phase/omega coefficients for a given mode.

    Contains PN coefficients, pseudo-PN coefficients, ringdown parameters,
    and intermediate region coefficients. Instances are built by
    ``compute_phase_coeffs_22`` (22 mode) and ``compute_phase_coeffs_hm``
    (higher modes).

    Parameters
    ----------
    mode : int | Array
        Mode number (e.g., 22, 33, 44, etc.). ``mode % 10`` is used as m.
    omega1PN : float | Array
        1PN coefficient for omega inspiral.
    omega1halfPN : float | Array
        1.5PN coefficient for omega inspiral.
    omega2PN : float | Array
        2PN coefficient for omega inspiral.
    omega2halfPN : float | Array
        2.5PN coefficient for omega inspiral.
    omega3PN : float | Array
        3PN coefficient for omega inspiral.
    omega3halfPN : float | Array
        3.5PN coefficient for omega inspiral.
    omegaInspC1 : float | Array
        1st pseudo-PN coefficient for omega inspiral.
    omegaInspC2 : float | Array
        2nd pseudo-PN coefficient for omega inspiral.
    omegaInspC3 : float | Array
        3rd pseudo-PN coefficient for omega inspiral.
    omegaInspC4 : float | Array
        4th pseudo-PN coefficient for omega inspiral.
    omegaInspC5 : float | Array
        5th pseudo-PN coefficient for omega inspiral.
    omegaInspC6 : float | Array
        6th pseudo-PN coefficient for omega inspiral.
    omegaRING : float | Array
        Ringdown angular frequency (2*pi*fring / Mfinal in the builders).
    alpha1RD : float | Array
        Ringdown damping rate (2*pi*fdamp / Mfinal in the builders).
    omegaRING_prec : float | Array
        Ringdown frequency for precessing case (set equal to ``omegaRING``
        by the builders in this module).
    omegaPeak : float | Array
        Peak frequency.
    c1 : float | Array
        1st ringdown ansatz coefficient.
    c2 : float | Array
        2nd ringdown ansatz coefficient.
    c3 : float | Array
        3rd ringdown ansatz coefficient.
    c4 : float | Array
        4th ringdown ansatz coefficient.
    c1_prec : float | Array
        1st ringdown ansatz coefficient for precessing case.
    omegaMergerC1 : float | Array
        1st intermediate region coefficient.
    omegaMergerC2 : float | Array
        2nd intermediate region coefficient.
    omegaMergerC3 : float | Array
        3rd intermediate region coefficient.
    omegaCut : float | Array
        Omega at inspiral cut.
    domegaCut : float | Array
        domega/dt at inspiral cut. As stored by the builders in this module
        it is normalized: -(domega/dt) / omegaRING.
    domegaPeak : float | Array
        domega/dt at peak, stored normalized as -(domega/dt) / omegaRING
        (computed from the ringdown ansatz at t = 0).
    inspiral_cut : float | Array
        Transition time inspiral -> intermediate.
    ringdown_cut : float | Array
        Transition time intermediate -> ringdown (= 0, peak time).
    tt0 : float | Array
        t0 from fit.
    tEarly : float | Array
        Early time for phase offset.
    omegaCutPNAMP : float | Array
        Omega contribution from complex amplitude at transition time
        ``pAmp.inspiral_cut``.
    phiCutPNAMP : float | Array
        Phase contribution from complex amplitude at transition time
        ``pAmp.inspiral_cut``.
    phOffInsp : float | Array
        Phase offset for inspiral.
    phOffMerger : float | Array
        Phase offset for intermediate region.
    phOffRD : float | Array
        Phase offset for ringdown.
    phoff : float | Array
        Phase offset for different modes.
    phiref0 : float | Array
        Reference phase of the 22 mode at t=tref.
    powers_of_5 : Array
        Powers 5^(n/8) for n=0..7 for phase computation.
    """

    mode: int | Array

    # PN coefficients (TaylorT3)
    omega1PN: float | Array
    omega1halfPN: float | Array
    omega2PN: float | Array
    omega2halfPN: float | Array
    omega3PN: float | Array
    omega3halfPN: float | Array

    # Pseudo-PN coefficients (6 coefficients for omega inspiral)
    omegaInspC1: float | Array
    omegaInspC2: float | Array
    omegaInspC3: float | Array
    omegaInspC4: float | Array
    omegaInspC5: float | Array
    omegaInspC6: float | Array

    # Ringdown quantities
    omegaRING: float | Array  # 2*pi*fring (divided by Mfinal)
    alpha1RD: float | Array  # 2*pi*fdamp (divided by Mfinal)
    omegaRING_prec: float | Array  # For precessing case

    # Peak frequency
    omegaPeak: float | Array

    # Ringdown ansatz coefficients
    c1: float | Array
    c2: float | Array
    c3: float | Array
    c4: float | Array
    c1_prec: float | Array

    # Intermediate ansatz coefficients
    omegaMergerC1: float | Array
    omegaMergerC2: float | Array
    omegaMergerC3: float | Array

    # Intermediate region values
    omegaCut: float | Array  # omega at inspiral cut
    domegaCut: float | Array  # normalized domega/dt at inspiral cut
    domegaPeak: float | Array  # normalized domega/dt at peak

    # Times and cuts
    inspiral_cut: float | Array  # tCut: transition time inspiral -> intermediate
    ringdown_cut: float | Array  # = 0 (peak time)
    tt0: float | Array  # t0 from fit
    tEarly: float | Array  # Early time for phase offset
    omegaCutPNAMP: (
        float | Array
    )  # Omega contribution from complex amplitude at transition time ``pAmp.inspiral_cut``.
    phiCutPNAMP: (
        float | Array
    )  # Phase contribution from complex amplitude at transition time ``pAmp.inspiral_cut``.

    # Phase continuity offsets
    phOffInsp: float | Array
    phOffMerger: float | Array
    phOffRD: float | Array

    # Phase offset for different modes
    phoff: float | Array

    # Reference phase at t=tref
    phiref0: float | Array

    # Powers of 5 for phase computation
    powers_of_5: Array
def _compute_pn_and_pseudo_pn(
    wf_params: WaveformParams,
) -> Tuple[
    jnp.ndarray,
    collocation.OmegaPseudoPNCoeffs,
    float | Array,
    float | Array,
    Array,
]:
    """
    Build the mode-independent inspiral ingredients.

    The TaylorT3 (PN) omega coefficients and the pseudo-PN coefficients fitted
    to the omega collocation points do not depend on the mode, so they are
    computed here once and shared by the 22 and higher-mode builders.

    Parameters
    ----------
    wf_params : WaveformParams
        Waveform parameters (eta, spins, masses, delta are read).

    Returns
    -------
    tuple
        ``(omega_pn_array, pseudo_pn, tt0, tEarly, powers_of_5)``:
        the 1PN..3.5PN omega coefficients stacked in that order, the
        pseudo-PN coefficients, the ``tt0``/``tEarly`` times returned by the
        collocation-point computation, and ``[5**(k/8) for k in 0..7]``.
    """
    eta, chi1, chi2 = wf_params.eta, wf_params.chi1, wf_params.chi2

    # 5**(k/8) for k = 0..7, reused by the phase ansatze
    eighth_root_of_5 = jnp.power(5.0, 1.0 / 8.0)
    powers_of_5 = jnp.array(
        [1.0] + [eighth_root_of_5**k for k in range(1, 8)]
    )

    # TaylorT3 omega coefficients, ordered 1PN, 1.5PN, ..., 3.5PN
    pn = pn_coeffs.compute_omega_pn_coeffs(
        eta, chi1, chi2, wf_params.delta, wf_params.m1, wf_params.m2
    )
    pn_order = (
        "omega1PN",
        "omega1halfPN",
        "omega2PN",
        "omega2halfPN",
        "omega3PN",
        "omega3halfPN",
    )
    omega_pn_array = jnp.array([getattr(pn, name) for name in pn_order])

    # Pseudo-PN coefficients fitted through the omega collocation points
    collocation_values, tt0, tEarly = collocation.compute_omega_collocation_points(
        eta, chi1, chi2, omega_pn_array
    )
    pseudo_pn = collocation.compute_omega_pseudo_pn_coeffs(
        omega_pn_array, collocation_values
    )

    return omega_pn_array, pseudo_pn, tt0, tEarly, powers_of_5
@jax.jit
def compute_phase_coeffs_22(
    wf_params: WaveformParams,
) -> tuple[WaveformParams, PhaseCoeffs]:
    """
    Compute all phase coefficients for the 22 mode.

    Besides the coefficients, this resolves ``wf_params.Mt_min`` and
    ``wf_params.Mt_ref`` when they are NaN, by root finding the time at which
    the 22 frequency equals ``Mf_min`` / ``Mf_ref``. If ``Mf_ref == Mf_min``,
    ``Mt_ref`` reuses ``Mt_min``. It then recomputes the waveform length
    parameters and stores the 22 phase at ``Mt_ref`` in ``phiref0``.

    Parameters
    ----------
    wf_params : WaveformParams
        Waveform parameters for the waveform.

    Returns
    -------
    tuple[WaveformParams, PhaseCoeffs]
        Updated waveform parameters and phase coefficients for mode 22.
    """
    # Common coefficients
    omega_pn_array, pseudo_pn, tt0, tEarly, powers_of_5 = _compute_pn_and_pseudo_pn(
        wf_params
    )

    # Final state
    af = fits.final_spin_2017(wf_params.eta, wf_params.chi1, wf_params.chi2)
    Mfinal = fits.final_mass_2017(wf_params.eta, wf_params.chi1, wf_params.chi2)

    # Ringdown quantities
    omegaRING = 2.0 * jnp.pi * fits.fring_22(af) / Mfinal
    alpha1RD = 2.0 * jnp.pi * fits.fdamp_22(af) / Mfinal
    omegaRING_prec = omegaRING
    omegaPeak = fits.peak_freq_22(wf_params.eta, wf_params.chi1, wf_params.chi2)
    c2 = fits.rd_freq_d2_22(wf_params.eta, wf_params.chi1, wf_params.chi2)
    c3 = fits.rd_freq_d3_22(wf_params.eta, wf_params.chi1, wf_params.chi2)
    c4 = 0.0
    # c1 is chosen so that the ringdown omega ansatz equals omegaPeak at t = 0.
    c1 = (1.0 + c3 + c4) * (omegaRING - omegaPeak) / c2 / (c3 + 2.0 * c4)
    c1_prec = c1

    # Cuts
    inspiral_cut = -26.982976386771437 / wf_params.eta
    ringdown_cut = 0.0

    # Intermediate region
    pseudo_pn_array = jnp.array(
        [
            pseudo_pn.c1,
            pseudo_pn.c2,
            pseudo_pn.c3,
            pseudo_pn.c4,
            pseudo_pn.c5,
            pseudo_pn.c6,
        ]
    )
    omegaCut = _inspiral_ansatz_omega_single(
        inspiral_cut, wf_params.eta, omega_pn_array, pseudo_pn_array, m=2
    )
    domegaCut = _inspiral_ansatz_domega(
        inspiral_cut, wf_params.eta, omega_pn_array, pseudo_pn_array, m=2
    )
    domegaPeak = -_ringdown_ansatz_domega(0.0, c1, c2, c3, c4) / omegaRING
    tcpMerger = -5.0 / (wf_params.eta * jnp.power(0.95, 8))
    omegaMergerCP = (
        1.0
        - fits.intermediate_freq_cp1_22(wf_params.eta, wf_params.chi1, wf_params.chi2)
        / omegaRING
    )
    # The intermediate system is posed in normalized quantities
    # (1 - omega/omegaRING and -(domega/dt)/omegaRING).
    omegaCutBar = 1.0 - omegaCut / omegaRING
    domegaCut = -domegaCut / omegaRING
    omegaMergerC1, omegaMergerC2, omegaMergerC3 = _solve_intermediate_omega_system(
        alpha1RD,
        inspiral_cut,
        tcpMerger,
        omegaCutBar,
        omegaMergerCP,
        domegaCut,
        domegaPeak,
        omegaPeak,
        omegaRING,
    )

    # Phase offsets
    thetabarini = jnp.power(wf_params.eta * (tt0 - tEarly), -1.0 / 8.0)
    pn_phase_at_thetabarini = _pn_ansatz_phase(
        thetabarini, wf_params.eta, powers_of_5, omega_pn_array
    )
    inspiral_phase_at_tEarly = _inspiral_ansatz_phase_value_22(
        tEarly,
        wf_params.eta,
        powers_of_5,
        omega_pn_array,
        pseudo_pn_array,
        0.0,
    )
    phOffInsp = pn_phase_at_thetabarini - inspiral_phase_at_tEarly

    inspiral_phase_at_cut = _inspiral_ansatz_phase_value_22(
        inspiral_cut,
        wf_params.eta,
        powers_of_5,
        omega_pn_array,
        pseudo_pn_array,
        phOffInsp,
    )
    intermediate_phase_at_cut = _intermediate_ansatz_phase_value(
        inspiral_cut,
        alpha1RD,
        omegaMergerC1,
        omegaMergerC2,
        omegaMergerC3,
        omegaPeak,
        domegaPeak,
        omegaRING,
        0.0,
    )
    phOffMerger = inspiral_phase_at_cut - intermediate_phase_at_cut

    phOffRD = _intermediate_ansatz_phase_value(
        0.0,
        alpha1RD,
        omegaMergerC1,
        omegaMergerC2,
        omegaMergerC3,
        omegaPeak,
        domegaPeak,
        omegaRING,
        phOffMerger,
    )

    phase_coeffs_22 = PhaseCoeffs(
        mode=22,
        omega1PN=omega_pn_array[0],
        omega1halfPN=omega_pn_array[1],
        omega2PN=omega_pn_array[2],
        omega2halfPN=omega_pn_array[3],
        omega3PN=omega_pn_array[4],
        omega3halfPN=omega_pn_array[5],
        omegaInspC1=pseudo_pn.c1,
        omegaInspC2=pseudo_pn.c2,
        omegaInspC3=pseudo_pn.c3,
        omegaInspC4=pseudo_pn.c4,
        omegaInspC5=pseudo_pn.c5,
        omegaInspC6=pseudo_pn.c6,
        omegaRING=omegaRING,
        alpha1RD=alpha1RD,
        omegaRING_prec=omegaRING_prec,
        omegaPeak=omegaPeak,
        c1=c1,
        c2=c2,
        c3=c3,
        c4=c4,
        c1_prec=c1_prec,
        omegaMergerC1=omegaMergerC1,
        omegaMergerC2=omegaMergerC2,
        omegaMergerC3=omegaMergerC3,
        omegaCut=omegaCut,
        domegaCut=domegaCut,
        domegaPeak=domegaPeak,
        inspiral_cut=inspiral_cut,
        ringdown_cut=ringdown_cut,
        tt0=tt0,
        tEarly=tEarly,
        omegaCutPNAMP=jnp.array(0.0),
        phiCutPNAMP=jnp.array(0.0),
        phoff=jnp.array(0.0),
        phiref0=jnp.array(0.0),
        phOffInsp=phOffInsp,
        phOffMerger=phOffMerger,
        phOffRD=phOffRD,
        powers_of_5=powers_of_5,
    )

    def _time_of_frequency(freq):
        # Keyword arguments are required here: the positional slot after
        # ``t_low`` is ``t_high``, so passing ``atol``/``rtol`` positionally
        # would use the tolerance as the bisection upper bound.
        return get_time_of_frequency(
            freq,
            wf_params.eta,
            phase_coeffs_22,
            t_low=wf_params.t_low,
            atol=wf_params.atol,
            rtol=wf_params.rtol,
        )

    # NaN marks "not provided": solve for the start time. isnan yields a
    # boolean tracer, so this also works under vmap.
    Mt_min = jax.lax.cond(
        jnp.isnan(wf_params.Mt_min),
        lambda _: _time_of_frequency(wf_params.Mf_min),
        lambda _: wf_params.Mt_min,
        operand=None,
    )
    wf_params = eqx.tree_at(lambda p: p.Mt_min, wf_params, Mt_min)

    # If fref == fmin, reuse Mt_min to avoid a second root solve.
    Mt_ref_candidate = jax.lax.cond(
        jnp.isnan(wf_params.Mt_ref),
        lambda: jax.lax.cond(
            wf_params.Mf_min == wf_params.Mf_ref,
            lambda: wf_params.Mt_min,
            lambda: wf_params.Mt_ref,
        ),
        lambda: wf_params.Mt_ref,
    )
    # Still NaN only when Mt_ref was unset and fref != fmin: solve for it.
    Mt_ref = jax.lax.cond(
        jnp.isnan(Mt_ref_candidate),
        lambda _: _time_of_frequency(wf_params.Mf_ref),
        lambda _: Mt_ref_candidate,
        operand=None,
    )
    wf_params = eqx.tree_at(lambda p: p.Mt_ref, wf_params, Mt_ref)

    # Waveform length parameters depend on the (possibly updated) Mt_min.
    wf_params = compute_wf_length_params(wf_params)

    # Phase of the 22 mode at the reference time
    phiref0 = imr_phase(Mt_ref, wf_params.eta, phase_coeffs_22)
    return wf_params, eqx.tree_at(lambda p: p.phiref0, phase_coeffs_22, phiref0)
@jax.jit
def compute_phase_coeffs_hm(
    wf_params: WaveformParams,
    phase_coeffs_22: PhaseCoeffs,
    OmegaCutPNAMP: Array,
    PhiCutPNAMP: Array,
    mode: int | Array,
) -> PhaseCoeffs:
    """
    Compute all phase/omega coefficients for a higher-order mode.

    The inspiral of a higher mode is tied to the 22 mode. Its omega, omega
    derivative and phase at the inspiral cut (t = -150) are the 22-mode
    values scaled by m/2, with m = mode % 10. The intermediate region is
    then matched to those values and to the mode's own peak/ringdown fits.

    Parameters
    ----------
    wf_params : WaveformParams
        Waveform parameters for the waveform.
    phase_coeffs_22 : PhaseCoeffs
        Phase coefficients for the 22 mode.
    OmegaCutPNAMP : Array
        Omega contribution from the complex amplitude at the transition time
        ``pAmp.inspiral_cut`` for HM modes.
    PhiCutPNAMP : Array
        Phase contribution from the complex amplitude at the transition time
        ``pAmp.inspiral_cut`` for HM modes.
    mode : int | Array
        Mode number (e.g., 33, 44, etc.).

    Returns
    -------
    PhaseCoeffs
        Phase/omega coefficients for ``mode``.
    """
    eta, chi1, chi2 = wf_params.eta, wf_params.chi1, wf_params.chi2
    half_m = (mode % 10) / 2.0

    omega_pn_array, pseudo_pn, tt0, tEarly, powers_of_5 = _compute_pn_and_pseudo_pn(
        wf_params
    )

    # --- Remnant black hole and ringdown fits for this mode ---
    af = fits.final_spin_2017(eta, chi1, chi2)
    Mfinal = fits.final_mass_2017(eta, chi1, chi2)
    omegaRING = 2.0 * jnp.pi * fits.fring(af, mode) / Mfinal
    alpha1RD = 2.0 * jnp.pi * fits.fdamp(af, mode) / Mfinal
    omegaPeak = fits.peak_freq(eta, chi1, chi2, mode)
    c2 = fits.rd_freq_d2(eta, chi1, chi2, mode)
    c3 = fits.rd_freq_d3(eta, chi1, chi2, mode)
    c4 = 0.0
    # c1 makes the ringdown omega ansatz equal omegaPeak at t = 0.
    c1 = (1.0 + c3 + c4) * (omegaRING - omegaPeak) / c2 / (c3 + 2.0 * c4)

    # --- Region boundaries and merger collocation time (fixed for HMs) ---
    inspiral_cut = -150.0
    ringdown_cut = 0.0
    tcpMerger = -25.0

    # --- Boundary data for the intermediate region ---
    # omega and domega/dt at the cut follow the 22 mode scaled by m/2.
    omegaCut = half_m * imr_omega(
        inspiral_cut, eta=eta, phase_coeffs=phase_coeffs_22
    )
    domega22_at_cut = compute_domega_cut(
        inspiral_cut, phase_coeffs_22.inspiral_cut, eta, phase_coeffs_22
    )
    # Normalized as -(domega/dt) / omegaRING, like domegaPeak below.
    domegaCut = -half_m * domega22_at_cut / omegaRING
    domegaPeak = -_ringdown_ansatz_domega(0.0, c1, c2, c3, c4) / omegaRING
    omegaCutBar = 1.0 - (omegaCut + OmegaCutPNAMP) / omegaRING
    omegaMergerCP = (
        1.0 - fits.intermediate_freq_cp1(eta, chi1, chi2, mode) / omegaRING
    )

    omegaMergerC1, omegaMergerC2, omegaMergerC3 = _solve_intermediate_omega_system(
        alpha1RD,
        inspiral_cut,
        tcpMerger,
        omegaCutBar,
        omegaMergerCP,
        domegaCut,
        domegaPeak,
        omegaPeak,
        omegaRING,
    )

    # --- Phase continuity at the inspiral cut and at the peak ---
    merger_args = (
        alpha1RD,
        omegaMergerC1,
        omegaMergerC2,
        omegaMergerC3,
        omegaPeak,
        domegaPeak,
        omegaRING,
    )
    inspiral_phase_at_cut = half_m * imr_phase(
        inspiral_cut, eta=eta, phase_coeffs=phase_coeffs_22
    )
    merger_phase_at_cut = _intermediate_ansatz_phase_value(
        inspiral_cut, *merger_args, 0.0
    )
    phOffMerger = inspiral_phase_at_cut - merger_phase_at_cut
    phOffRD = _intermediate_ansatz_phase_value(0.0, *merger_args, phOffMerger)

    return PhaseCoeffs(
        mode=mode,
        omega1PN=omega_pn_array[0],
        omega1halfPN=omega_pn_array[1],
        omega2PN=omega_pn_array[2],
        omega2halfPN=omega_pn_array[3],
        omega3PN=omega_pn_array[4],
        omega3halfPN=omega_pn_array[5],
        omegaInspC1=pseudo_pn.c1,
        omegaInspC2=pseudo_pn.c2,
        omegaInspC3=pseudo_pn.c3,
        omegaInspC4=pseudo_pn.c4,
        omegaInspC5=pseudo_pn.c5,
        omegaInspC6=pseudo_pn.c6,
        omegaRING=omegaRING,
        alpha1RD=alpha1RD,
        omegaRING_prec=omegaRING,
        omegaPeak=omegaPeak,
        c1=c1,
        c2=c2,
        c3=c3,
        c4=c4,
        c1_prec=c1,
        omegaMergerC1=omegaMergerC1,
        omegaMergerC2=omegaMergerC2,
        omegaMergerC3=omegaMergerC3,
        omegaCut=omegaCut,
        domegaCut=domegaCut,
        domegaPeak=domegaPeak,
        inspiral_cut=inspiral_cut,
        ringdown_cut=ringdown_cut,
        tt0=tt0,
        tEarly=tEarly,
        omegaCutPNAMP=OmegaCutPNAMP,
        phiCutPNAMP=PhiCutPNAMP,
        phoff=_get_offset(mode),
        phiref0=phase_coeffs_22.phiref0,
        phOffInsp=0.0,
        phOffMerger=phOffMerger,
        phOffRD=phOffRD,
        powers_of_5=powers_of_5,
    )
def imr_omega(
    time: float | Array, eta: float | Array, phase_coeffs: PhaseCoeffs
) -> float | Array:
    """
    Compute the angular frequency :math:`\\omega(t) = 2\\pi f(t)` of a mode.

    Each time sample is routed to the inspiral, intermediate or ringdown
    omega ansatz according to ``phase_coeffs.inspiral_cut`` and
    ``phase_coeffs.ringdown_cut``.

    Parameters
    ----------
    time : float | Array
        Time(s) at which to evaluate the frequency; any shape.
    eta : float | Array
        Symmetric mass ratio.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the mode.

    Returns
    -------
    Array
        Omega value(s), with the same shape as ``time``.
    """
    pc = phase_coeffs
    m_index = pc.mode % 10

    pn_coeffs_arr = jnp.array(
        [
            pc.omega1PN,
            pc.omega1halfPN,
            pc.omega2PN,
            pc.omega2halfPN,
            pc.omega3PN,
            pc.omega3halfPN,
        ]
    )
    pseudo_coeffs_arr = jnp.array(
        [
            pc.omegaInspC1,
            pc.omegaInspC2,
            pc.omegaInspC3,
            pc.omegaInspC4,
            pc.omegaInspC5,
            pc.omegaInspC6,
        ]
    )

    def _inspiral(t):
        return _inspiral_ansatz_omega_single(
            t, eta, pn_coeffs_arr, pseudo_coeffs_arr, m_index
        )

    def _intermediate(t):
        return _intermediate_ansatz_omega_single(
            t,
            pc.alpha1RD,
            pc.omegaPeak,
            pc.domegaPeak,
            pc.omegaRING,
            pc.omegaMergerC1,
            pc.omegaMergerC2,
            pc.omegaMergerC3,
        )

    def _ringdown(t):
        return _ringdown_ansatz_omega_single(
            t, pc.c1, pc.c2, pc.c3, pc.c4, pc.omegaRING
        )

    region_fns = [_inspiral, _intermediate, _ringdown]

    @jax.jit
    def _omega_at(t: Array) -> Array:
        # Region index: 0 = inspiral, 1 = intermediate, 2 = ringdown.
        region = (t >= pc.inspiral_cut).astype(jnp.int32) + (
            t >= pc.ringdown_cut
        ).astype(jnp.int32)
        return jax.lax.switch(region, region_fns, t)

    # Evaluate sample by sample on a flattened view, then restore the shape.
    time = jnp.asarray(time)
    flat_omegas = jax.vmap(_omega_at)(time.reshape(-1))
    return flat_omegas.reshape(time.shape)
def imr_omega_dot(
    time: float | Array, eta: float | Array, phase_coeffs: PhaseCoeffs
) -> float | Array:
    """
    Compute the frequency derivative :math:`\\dot{\\omega}(t)` with JAX autodiff.

    ``jax.grad`` requires a scalar output, so the derivative is taken one
    time sample at a time and vectorized with ``jax.vmap``. This lets
    ``time`` be a scalar or an array of any shape.

    Parameters
    ----------
    time : float | Array
        Time(s) at which to evaluate the frequency derivative. Integer input
        is converted to floating point, because ``jax.grad`` only
        differentiates floating-point inputs.
    eta : float | Array
        Symmetric mass ratio.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the mode.

    Returns
    -------
    Array
        d(omega)/dt at the given time(s), with the same shape as ``time``.
    """
    time = jnp.asarray(time)
    if not jnp.issubdtype(time.dtype, jnp.floating):
        time = time.astype(float)

    def _domega_at(t):
        # imr_omega on a 0-d time returns a 0-d array, as jax.grad requires.
        return jax.grad(lambda s: imr_omega(s, eta, phase_coeffs))(t)

    flat = jnp.reshape(time, (-1,))
    return jnp.reshape(jax.vmap(_domega_at)(flat), jnp.shape(time))
def imr_phase(
    time: float | Array,
    eta: float | Array,
    phase_coeffs: PhaseCoeffs,
    phase_22: float | Array = 0.0,
) -> Array:
    """
    Compute the phase at given times for a given mode.

    Parameters
    ----------
    time : float | Array
        Time(s) at which to compute the phase; any shape.
    eta : float | Array
        Symmetric mass ratio.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the mode.
    phase_22 : float | Array, optional
        Phase of the (2,2) mode at the same times (default is 0.0). It is
        only used in the inspiral region of non-22 modes, where the phase is
        ``(m / 2) * phase_22``. A scalar is broadcast to every time sample
        with its own dtype kept. An array must hold one value per sample.

    Returns
    -------
    Array
        Phase value(s) at the given time(s), with the same shape as ``time``.
        In the intermediate and ringdown regions ``phase_coeffs.phiCutPNAMP``
        is subtracted.
    """
    # Prepare coefficient arrays for helper functions
    omega_pn_coeffs = jnp.array(
        [
            phase_coeffs.omega1PN,
            phase_coeffs.omega1halfPN,
            phase_coeffs.omega2PN,
            phase_coeffs.omega2halfPN,
            phase_coeffs.omega3PN,
            phase_coeffs.omega3halfPN,
        ]
    )
    omega_pseudo_pn_coeffs = jnp.array(
        [
            phase_coeffs.omegaInspC1,
            phase_coeffs.omegaInspC2,
            phase_coeffs.omegaInspC3,
            phase_coeffs.omegaInspC4,
            phase_coeffs.omegaInspC5,
            phase_coeffs.omegaInspC6,
        ]
    )

    @jax.jit
    def _phase_scalar(t: Array, _phase_22: float | Array) -> Array:
        # Region index: 0 = inspiral, 1 = intermediate, 2 = ringdown.
        is_post_inspiral = t >= phase_coeffs.inspiral_cut
        is_ringdown = t >= phase_coeffs.ringdown_cut
        region_idx = is_post_inspiral.astype(jnp.int32) + is_ringdown.astype(jnp.int32)

        def _inspiral(t, _phase_22):
            return _inspiral_ansatz_phase_value(
                t,
                eta,
                phase_coeffs.powers_of_5,
                omega_pn_coeffs,
                omega_pseudo_pn_coeffs,
                phase_coeffs.phOffInsp,
                phase_coeffs.mode,
                phase_22=_phase_22,
            )

        def _intermediate(t, _):
            val = _intermediate_ansatz_phase_value(
                t,
                phase_coeffs.alpha1RD,
                phase_coeffs.omegaMergerC1,
                phase_coeffs.omegaMergerC2,
                phase_coeffs.omegaMergerC3,
                phase_coeffs.omegaPeak,
                phase_coeffs.domegaPeak,
                phase_coeffs.omegaRING,
                phase_coeffs.phOffMerger,
            )
            return val - phase_coeffs.phiCutPNAMP

        def _ringdown(t, _):
            val = _ringdown_ansatz_phase_value(
                t,
                phase_coeffs.c1_prec,
                phase_coeffs.c2,
                phase_coeffs.c3,
                phase_coeffs.c4,
                phase_coeffs.omegaRING_prec,
                phase_coeffs.phOffRD,
            )
            return val - phase_coeffs.phiCutPNAMP

        # Every branch receives _phase_22 so that all signatures match.
        return jax.lax.switch(
            region_idx, [_inspiral, _intermediate, _ringdown], t, _phase_22
        )

    # Vectorize over a flattened time array
    time = jnp.asarray(time)
    phase_22 = jnp.asarray(phase_22)
    time_shape = jnp.shape(time)
    time_flat = jnp.reshape(time, (-1,))

    if jnp.ndim(phase_22) == 0:
        # broadcast_to keeps phase_22's dtype. full_like(time_flat, ...)
        # would cast it to time's dtype and truncate it for integer times.
        phase_22_flat = jnp.broadcast_to(phase_22, time_flat.shape)
    else:
        phase_22_flat = jnp.reshape(phase_22, (-1,))

    phases_flat = jax.vmap(_phase_scalar, in_axes=(0, 0))(time_flat, phase_22_flat)
    return jnp.reshape(phases_flat, time_shape)
@jax.jit
def get_time_of_frequency(
    freq: float | Array,
    eta: float | Array,
    phase_coeffs: PhaseCoeffs,
    t_low: float | Array = 0.0,
    t_high: float | Array = 500.0,
    atol: float = 1e-12,
    rtol: float = 1e-12,
) -> float | Array:
    """
    Find the time at which the mode reaches a given frequency.

    Solves ``2 * pi * freq - imr_omega(t) = 0`` for ``t`` by bisection on
    ``[t_low, t_high]``. Trial times earlier than ``phase_coeffs.tEarly``
    are shifted by ``-phase_coeffs.tt0`` before omega is evaluated.

    Parameters
    ----------
    freq : float | Array
        (Dimensionless) frequency whose time is sought.
    eta : float | Array
        Symmetric mass ratio.
    phase_coeffs : PhaseCoeffs
        Phase coefficients for the mode.
    t_low : float | Array, optional
        Lower bound of the time bracket (default 0.0). When it is 0, it is
        replaced by ``-0.015 * freq**(-2.7)``.
    t_high : float | Array, optional
        Upper bound of the time bracket (default 500.0).
    atol : float | Array, optional
        Absolute bisection tolerance.
    rtol : float | Array, optional
        Relative bisection tolerance.

    Returns
    -------
    float | Array
        The root ``t`` found by ``optimistix.root_find`` (at most 1000 steps).
    """
    # A zero lower bound means "pick one from the frequency". The factor is
    # slightly larger than that of the initial guess below.
    lower_bound = jax.lax.cond(
        t_low == 0,
        lambda: -0.015 * freq ** (-2.7),
        lambda: t_low,
    )

    def _residual(t, target_freq):
        shifted_t = jax.lax.cond(
            t < phase_coeffs.tEarly,
            lambda s: s - phase_coeffs.tt0,
            lambda s: s,
            t,
        )
        return 2 * jnp.pi * target_freq - imr_omega(shifted_t, eta, phase_coeffs)

    bisection = optx.Bisection(  # type: ignore
        atol=atol,
        rtol=rtol,
    )
    solution: optx.Solution = optx.root_find(
        _residual,
        bisection,
        args=freq,
        y0=-0.01 * freq ** (-2.7),
        options={"lower": lower_bound, "upper": t_high},
        max_steps=1000,
    )
    return solution.value
# ============================================================================= # Helper functions # =============================================================================
@jax.jit
def compute_domega_cut(
    tCut: float | Array,
    tCut_threshold: float | Array,
    eta: float | Array,
    phase_coeffs_22: PhaseCoeffs,
) -> float | Array:
    """
    Compute d(omega)/dt of the 22 mode at ``tCut``.

    The 22 inspiral omega ansatz is differentiated when
    ``tCut < tCut_threshold``. Otherwise the closed-form derivative of the
    22 intermediate omega ansatz is used. The choice is made with
    ``jax.lax.cond``.

    Parameters
    ----------
    tCut : float | Array
        Time at which the derivative is evaluated.
    tCut_threshold : float | Array
        Threshold time to switch between the inspiral and merger branches
        (callers pass the 22 mode's ``inspiral_cut``).
    eta : float | Array
        Symmetric mass ratio.
    phase_coeffs_22 : PhaseCoeffs
        Phase coefficients of the 22 mode.

    Returns
    -------
    float | Array
        d(omega_22)/dt at ``tCut`` (not normalized by ``omegaRING``).
    """
    pc = phase_coeffs_22

    def inspiral_branch(t):
        pn_arr = jnp.array(
            [
                pc.omega1PN,
                pc.omega1halfPN,
                pc.omega2PN,
                pc.omega2halfPN,
                pc.omega3PN,
                pc.omega3halfPN,
            ]
        )
        pseudo_arr = jnp.array(
            [
                pc.omegaInspC1,
                pc.omegaInspC2,
                pc.omegaInspC3,
                pc.omegaInspC4,
                pc.omegaInspC5,
                pc.omegaInspC6,
            ]
        )
        return _inspiral_ansatz_domega(t, eta, pn_arr, pseudo_arr)

    def merger_branch(t):
        # omega = omegaRING * (1 - w(x)), x = arcsinh(alpha1RD * t), so
        # domega/dt = -omegaRING * dw/dx * alpha1RD / sqrt(1 + (alpha1RD t)^2).
        x = jnp.arcsinh(pc.alpha1RD * t)
        prefactor = -pc.omegaRING / jnp.sqrt(1.0 + (pc.alpha1RD * t) ** 2)
        poly = (
            2.0 * pc.omegaMergerC1 * x
            + 3.0 * pc.omegaMergerC2 * x * x
            + 4.0 * pc.omegaMergerC3 * x**3
        )
        return prefactor * (pc.domegaPeak + pc.alpha1RD * poly)

    return jax.lax.cond(tCut < tCut_threshold, inspiral_branch, merger_branch, tCut)
@jax.jit
def _get_offset(mode: int | Array) -> Array:
    """Get mode-dependent offset for phase computation.

    Returns -pi/2 for mode 33, pi for 44, pi/2 for 55 and 21, and 0 for any
    other mode (including 22).
    """
    phoff = jax.lax.cond(
        mode == 33,
        lambda: -jnp.pi * 0.5,
        lambda: jax.lax.cond(
            mode == 44,
            lambda: jnp.pi,
            lambda: jax.lax.cond(
                mode == 55,
                lambda: jnp.pi * 0.5,
                lambda: jax.lax.cond(
                    mode == 21,
                    lambda: jnp.pi * 0.5,
                    lambda: jnp.array(0.0),
                ),
            ),
        ),
    )
    return phoff


# =============================================================================
# Helper functions for ansatz evaluation
# =============================================================================


@jax.jit
def _inspiral_ansatz_omega_single(
    time: Array,
    eta: Array,
    omega_pn_coeffs: jnp.ndarray,
    omega_pseudo_pn_coeffs: jnp.ndarray,
    m: int | Array = 2,
) -> Array:
    """Evaluate inspiral omega ansatz at a single time.

    With theta = (-eta * time / 5)**(-1/8), this is the TaylorT3 omega plus
    (2 * theta**3 / 8) * sum_k c_k * theta**(7 + k), k = 1..6, all scaled
    by m / 2. It expects time < 0 (for eta > 0); otherwise the fractional
    power yields NaN.
    """
    theta = jnp.power(-eta * time / 5.0, -1.0 / 8.0)

    # TaylorT3 part
    taylort3 = collocation.pn_ansatz_omega(theta, omega_pn_coeffs)

    # Pseudo-PN part
    theta8 = jnp.power(theta, 8)
    theta9 = theta8 * theta
    theta10 = theta9 * theta
    theta11 = theta10 * theta
    theta12 = theta11 * theta
    theta13 = theta12 * theta

    fac = theta * theta * theta / 8.0

    pseudo_pn_sum = (
        omega_pseudo_pn_coeffs[0] * theta8
        + omega_pseudo_pn_coeffs[1] * theta9
        + omega_pseudo_pn_coeffs[2] * theta10
        + omega_pseudo_pn_coeffs[3] * theta11
        + omega_pseudo_pn_coeffs[4] * theta12
        + omega_pseudo_pn_coeffs[5] * theta13
    )

    # Scale the 22 frequency by m/2 for other modes.
    return (taylort3 + 2.0 * fac * pseudo_pn_sum) * (m / 2.0)


@jax.jit
def _intermediate_ansatz_omega_single(
    time: Array,
    alpha1RD: Array,
    omegaPeak: Array,
    domegaPeak: Array,
    omegaRING: Array,
    omegaMergerC1: Array,
    omegaMergerC2: Array,
    omegaMergerC3: Array,
) -> Array:
    """Evaluate intermediate omega ansatz at a single time.

    With x = arcsinh(alpha1RD * time), returns omegaRING * (1 - w) where
    w = 1 - omegaPeak/omegaRING + (domegaPeak/alpha1RD) x + C1 x^2 + C2 x^3
    + C3 x^4. Hence omega(0) = omegaPeak.
    """
    x = jnp.arcsinh(alpha1RD * time)
    # Horner form of the normalized ansatz w(x) = 1 - omega/omegaRING
    w = (
        1
        - omegaPeak / omegaRING
        + x
        * (
            domegaPeak / alpha1RD
            + x * (omegaMergerC1 + x * (omegaMergerC2 + x * omegaMergerC3))
        )
    )
    return omegaRING * (1 - w)


@jax.jit
def _ringdown_ansatz_omega_single(
    time: Array,
    c1: Array,
    c2: Array,
    c3: Array,
    c4: Array,
    omegaRING: Array,
) -> Array:
    """Evaluate ringdown omega ansatz at a single time.

    With e = exp(-c2 * time):
    omega = omegaRING - c1 * c2 * (2 c4 e^2 + c3 e) / (1 + c4 e^2 + c3 e).
    This is the time derivative of ``_ringdown_ansatz_phase_value``.
    """
    expC = jnp.exp(-c2 * time)
    expC2 = expC * expC

    num = -c1 * c2 * (2 * c4 * expC2 + c3 * expC)
    den = 1 + c4 * expC2 + c3 * expC

    return num / den + omegaRING


@jax.jit
def _inspiral_ansatz_domega(
    time: Array,
    eta: Array,
    omega_pn_coeffs: jnp.ndarray,
    omega_pseudo_pn_coeffs: jnp.ndarray,
    m: int | Array = 2,
) -> Array:
    """Compute d(omega)/dt for inspiral ansatz at a single time.

    Uses ``jax.grad``, so ``time`` must be a floating-point scalar.
    """
    # Use JAX autodiff
    return jax.grad(
        lambda t: _inspiral_ansatz_omega_single(
            t, eta, omega_pn_coeffs, omega_pseudo_pn_coeffs, m
        )
    )(time)


@jax.jit
def _ringdown_ansatz_domega(
    time: Array,
    c1: Array,
    c2: Array,
    c3: Array,
    c4: Array,
) -> Array:
    """Compute d(omega)/dt for ringdown ansatz at a single time.

    Closed form of the derivative of ``_ringdown_ansatz_omega_single``,
    written in terms of E = exp(+c2 * time):
    c1 c2^2 E (4 c4 E + c3 (c4 + E^2)) / (c4 + E (c3 + E))^2.
    """
    expC = jnp.exp(c2 * time)
    expC2 = expC * expC

    num = c1 * c2 * c2 * expC * (4 * c4 * expC + c3 * (c4 + expC2))
    den = c4 + expC * (c3 + expC)

    return num / (den * den)


@jax.jit
def _solve_intermediate_omega_system(
    alpha1RD: Array,
    tCut: Array,
    tcpMerger: Array,
    omegaCutBar: Array,
    omegaMergerCP: Array,
    domegaCut: Array,
    domegaPeak: Array,
    omegaPeak: Array,
    omegaRING: Array,
) -> tuple:
    """Solve 3x3 linear system for intermediate omega coefficients.

    Finds (omegaMergerC1, omegaMergerC2, omegaMergerC3) such that the
    normalized intermediate ansatz w = 1 - omega/omegaRING (in
    x = arcsinh(alpha1RD * t), see ``_intermediate_ansatz_omega_single``)
    satisfies three conditions:

    - w(tCut) = omegaCutBar
    - w(tcpMerger) = omegaMergerCP
    - dw/dt(tCut) = domegaCut

    Returns
    -------
    tuple
        ``(omegaMergerC1, omegaMergerC2, omegaMergerC3)``.
    """
    # System from Eqs. 28, 29, 31 in arXiv:2012.11923
    ascut = jnp.arcsinh(alpha1RD * tCut)
    ascut2 = ascut * ascut
    ascut3 = ascut * ascut2
    ascut4 = ascut * ascut3

    bascut = jnp.arcsinh(alpha1RD * tcpMerger)
    bascut2 = bascut * bascut
    bascut3 = bascut * bascut2
    bascut4 = bascut * bascut3

    # dx/dt at tCut is alpha1RD / dencut
    dencut = jnp.sqrt(1.0 + tCut * tCut * alpha1RD * alpha1RD)

    # Build matrix and RHS. The RHS moves the known (C-independent) terms
    # of w and dw/dt to the right-hand side.
    B = jnp.array(
        [
            omegaCutBar
            - (1.0 - omegaPeak / omegaRING)
            - (domegaPeak / alpha1RD) * ascut,
            omegaMergerCP
            - (1.0 - omegaPeak / omegaRING)
            - (domegaPeak / alpha1RD) * bascut,
            domegaCut - domegaPeak / dencut,
        ]
    )

    matrix = jnp.array(
        [
            jnp.array([ascut2, ascut3, ascut4]),
            jnp.array([bascut2, bascut3, bascut4]),
            jnp.array(
                [
                    2.0 * alpha1RD * ascut / dencut,
                    3.0 * alpha1RD * ascut2 / dencut,
                    4.0 * alpha1RD * ascut3 / dencut,
                ]
            ),
        ]
    )

    # Solve (explicit 3x3 solver used instead of jnp.linalg.solve(matrix, B))
    solution = solve_3x3_explicit(matrix, B)  # jnp.linalg.solve(matrix, B)

    return solution[0], solution[1], solution[2]


@jax.jit
def _pn_ansatz_phase(
    thetabar: Array,
    eta: Array,
    powers_of_5: jnp.ndarray,
    omega_pn_coeffs: jnp.ndarray,
) -> Array:
    """Evaluate PN ansatz phase at thetabar.

    ``thetabar`` is rescaled by 5**(1/8) (``powers_of_5[1]``) before use.
    """
    # This is the integration of the PN omega ansatz
    # From phenomxpy numba_pn_ansatz_22_phase
    thetabar = thetabar * powers_of_5[1]
    thetabar2 = thetabar * thetabar
    thetabar3 = thetabar * thetabar2
    thetabar4 = thetabar * thetabar3
    thetabar5 = thetabar * thetabar4
    thetabar6 = thetabar * thetabar5
    thetabar7 = thetabar * thetabar6

    logthetabar = jnp.log(thetabar)

    aux = (
        1
        / eta
        / thetabar5
        * (
            -168
            - 280 * omega_pn_coeffs[0] * thetabar2
            - 420 * omega_pn_coeffs[1] * thetabar3
            - 840 * omega_pn_coeffs[2] * thetabar4
            + 840 * omega_pn_coeffs[3] * (logthetabar - 0.125 * jnp.log(5)) * thetabar5
            - 321 * thetabar6
            + 840 * omega_pn_coeffs[4] * thetabar6
            + 321 * logthetabar * thetabar6
            + 420 * omega_pn_coeffs[5] * thetabar7
        )
    ) / 84.0

    return aux


@jax.jit
def _inspiral_ansatz_phase_value_22(
    time: Array,
    eta: Array,
    powers_of_5: jnp.ndarray,
    omega_pn_coeffs: jnp.ndarray,
    omega_pseudo_pn_coeffs: jnp.ndarray,
    phOffInsp: Array,
) -> Array:
    """Evaluate inspiral ansatz phase at a single time (22 mode).

    Closed-form expression in thetabar = (-eta * time)**(-1/8), plus the
    constant offset ``phOffInsp``. It takes log(-time), so it expects
    time < 0.
    """
    thetabar = jnp.power(-eta * time, -1.0 / 8.0)
    thetabar2 = thetabar * thetabar
    thetabar3 = thetabar * thetabar2
    thetabar4 = thetabar * thetabar3
    thetabar5 = thetabar * thetabar4
    thetabar6 = thetabar * thetabar5
    thetabar7 = thetabar * thetabar6

    logmtime = jnp.log(-time)
    # log of 5**(1/8) * thetabar, i.e. log((-eta * time / 5)**(-1/8))
    log_theta_bar = jnp.log(jnp.power(5, 0.125)) - 0.125 * (jnp.log(eta) + logmtime)

    aux = (
        -(
            1
            / powers_of_5[5]
            / (eta * eta)
            / time
            / thetabar7
            * (
                3 * (-107 + 280 * omega_pn_coeffs[4]) * powers_of_5[6]
                + 321 * log_theta_bar * powers_of_5[6]
                + 420 * omega_pn_coeffs[5] * thetabar * powers_of_5[7]
                + 56 * (25 * omega_pseudo_pn_coeffs[0] + 3 * eta * time) * thetabar2
                + 1050 * omega_pseudo_pn_coeffs[1] * powers_of_5[1] * thetabar3
                + 280
                * (3 * omega_pseudo_pn_coeffs[2] + eta * omega_pn_coeffs[0] * time)
                * powers_of_5[2]
                * thetabar4
                + 140
                * (5 * omega_pseudo_pn_coeffs[3] + 3 * eta * omega_pn_coeffs[1] * time)
                * powers_of_5[3]
                * thetabar5
                + 120
                * (5 * omega_pseudo_pn_coeffs[4] + 7 * eta * omega_pn_coeffs[2] * time)
                * powers_of_5[4]
                * thetabar6
                + 525 * omega_pseudo_pn_coeffs[5] * powers_of_5[5] * thetabar7
                + 105
                * eta
                * omega_pn_coeffs[3]
                * time
                * logmtime
                * powers_of_5[5]
                * thetabar7
            )
        )
        / 84.0
    )

    return aux + phOffInsp


@jax.jit
def _inspiral_ansatz_phase_hm(
    phase_22: Array,
    mode: int | Array,
):
    """Compute inspiral ansatz phase for HM modes based on the 22 mode phase.

    Returns (m / 2) * phase_22 with m = mode % 10.
    """
    m = mode % 10
    return (m / 2.0) * phase_22


def _inspiral_ansatz_phase_value(
    time: Array,
    eta: Array,
    powers_of_5: jnp.ndarray,
    omega_pn_coeffs: jnp.ndarray,
    omega_pseudo_pn_coeffs: jnp.ndarray,
    phOffInsp: Array,
    mode: int | Array,
    phase_22: float | Array = 0.0,
) -> Array:
    """Evaluate inspiral ansatz phase at a single time for given mode.

    For mode 22 the closed-form 22 inspiral phase is used, with
    ``phOffInsp`` added. For any other mode the result is
    (m / 2) * ``phase_22``, and the time and coefficient arguments and
    ``phOffInsp`` are not used.
    """
    phase = jax.lax.cond(
        mode == 22,
        lambda: _inspiral_ansatz_phase_value_22(
            time,
            eta,
            powers_of_5,
            omega_pn_coeffs,
            omega_pseudo_pn_coeffs,
            phOffInsp,
        ),
        lambda: _inspiral_ansatz_phase_hm(
            phase_22,
            mode,
        ),
    )
    return phase


@jax.jit
def _intermediate_ansatz_phase_value(
    time: Array,
    alpha1RD: Array,
    omegaMergerC1: Array,
    omegaMergerC2: Array,
    omegaMergerC3: Array,
    omegaPeak: Array,
    domegaPeak: Array,
    omegaRING: Array,
    phOffMerger: Array,
) -> Array:
    """Evaluate intermediate ansatz phase at a single time.

    Closed-form antiderivative in time of ``_intermediate_ansatz_omega_single``
    (x = arcsinh(alpha1RD * time)), plus the constant offset ``phOffMerger``.
    """
    x = jnp.arcsinh(alpha1RD * time)
    x2 = x * x
    x3 = x * x2
    x4 = x * x3
    # sqrt(1 + (alpha1RD t)^2); its derivative supplies the dx/dt factor.
    term1 = jnp.sqrt(1.0 + (alpha1RD * alpha1RD) * time * time)

    aux = omegaRING * time * (
        1.0
        - (
            2.0 * omegaMergerC1
            + 24.0 * omegaMergerC3
            + (6.0 * omegaMergerC2 + domegaPeak / alpha1RD) * x
            + (1.0 - omegaPeak / omegaRING)
            + (omegaMergerC1 + 12.0 * omegaMergerC3) * x2
            + omegaMergerC2 * x3
            + omegaMergerC3 * x4
        )
    ) - (omegaRING / alpha1RD) * term1 * (
        -domegaPeak / alpha1RD
        - 6.0 * omegaMergerC2
        - x * (2.0 * omegaMergerC1 + 24.0 * omegaMergerC3)
        - 3.0 * omegaMergerC2 * x2
        - 4.0 * omegaMergerC3 * x3
    )

    return aux + phOffMerger


@jax.jit
def _ringdown_ansatz_phase_value(
    time: Array,
    c1_prec: Array,
    c2: Array,
    c3: Array,
    c4: Array,
    omegaRING_prec: Array,
    phOffRD: Array,
) -> Array:
    """Evaluate ringdown ansatz phase at a single time.

    phase = c1 * log((1 + c3 e + c4 e^2) / (1 + c3 + c4)) + omegaRING * time
    + phOffRD, with e = exp(-c2 * time). The log term vanishes at time = 0,
    so phase(0) = phOffRD.
    """
    expC = jnp.exp(-c2 * time)

    num = 1 + c3 * expC + c4 * expC * expC
    den = 1 + c3 + c4
    aux = jnp.log(num / den)

    return c1_prec * aux + omegaRING_prec * time + phOffRD