Skip to content

API Reference

Core

InjectionConfig

Bases: BaseModel

Configuration for waveform injection source.

Source code in `src/pythiabns/core/config.py` (lines 21–29)
class InjectionConfig(BaseModel):
    """Configuration for waveform injection source."""

    # Where the injected signal comes from:
    # "nr": load from STRAIN_PATH via ID
    # "file": load from absolute/relative path
    # "analytic": simulate using a model from registry
    # NOTE(review): typed as a plain str, so values other than these three are
    # not rejected by this model — confirm they are validated where consumed.
    mode: str = "nr"
    # Meaning depends on `mode` (see above); None when not set.
    target: str | None = None  # NR ID, File Path, or Model Name
    # Injection parameter name -> float value; a fresh empty dict per
    # instance (default_factory avoids a shared mutable default).
    parameters: dict[str, float] = Field(default_factory=dict)

JobMatrix

Bases: BaseModel

Configuration for expanding into multiple simulations.

Source code in `src/pythiabns/core/config.py` (lines 32–48)
class JobMatrix(BaseModel):
    """Configuration for expanding into multiple simulations."""

    # The list-valued fields are the axes that get expanded; presumably one
    # simulation per combination — confirm in the expansion code.

    # Support both legacy 'waveform' and new modular 'injection'
    waveform: list[str] | None = None
    # Modular injection sources, one InjectionConfig per entry.
    injection: list[InjectionConfig] | None = None

    # Signal-to-noise ratios; defaults to the single value 50.0.
    snr: list[float] = Field(default_factory=lambda: [50.0])
    # Model names; required (no default).
    model: list[str]

    # Single (non-list) sampler and prior configuration.
    sampler: SamplerConfig
    priors: PriorConfig

    # Extra model parameters like nfreqs
    model_params: dict[str, Any] = Field(default_factory=dict)
    # Legacy: injection_parameters at top level
    injection_parameters: dict[str, float] = Field(default_factory=dict)

SimulationConfig

Bases: BaseModel

Configuration for a single simulation run.

Source code in `src/pythiabns/core/config.py` (lines 51–65)
class SimulationConfig(BaseModel):
    """Configuration for a single simulation run."""

    # Single instance of injection source
    injection: InjectionConfig

    # Scalar counterparts of JobMatrix's list-valued snr/model fields.
    snr: float
    model: str
    sampler: SamplerConfig
    priors: PriorConfig

    # Extra model parameters (e.g. nfreqs); fresh empty dict per instance.
    model_params: dict[str, Any] = Field(default_factory=dict)
    # Legacy support
    waveform: str | None = None
    injection_parameters: dict[str, float] = Field(default_factory=dict)

Registry

Generic registry for models, relations, etc. supporting metadata and variants.

Source code in `src/pythiabns/core/registry.py` (lines 7–87)
class Registry:
    """Generic registry for models, relations, etc. supporting metadata and variants."""

    def __init__(self):
        # Store as list of candidates: {name: [{'obj': obj, 'meta': meta}, ...]}
        self._registry: dict[str, list[dict[str, Any]]] = {}

    def register(self, name: str, **metadata):
        """Decorator to register a class or function with optional metadata."""

        def decorator(obj):
            if name not in self._registry:
                self._registry[name] = []

            # Check for duplicates? or just append?
            # Append allows variants.
            entry = {"obj": obj, "meta": metadata}
            self._registry[name].append(entry)
            logger.debug(f"Registered {name} with metadata {metadata}")
            return obj

        return decorator

    def get(self, name: str, **filters) -> Any | None:
        """
        Get object by name, optionally filtering by metadata.
        Example: get("lorentzian", nfreqs=3)
        """
        candidates = self._registry.get(name, [])
        if not candidates:
            return None

        matches = []
        for cand in candidates:
            meta = cand["meta"]
            # Check if all filters match metadata
            # If filter key not in meta, mismatch?
            # Or if meta has key and filter matches.
            is_match = True
            for k, v in filters.items():
                if k not in meta or meta[k] != v:
                    is_match = False
                    break
            if is_match:
                matches.append(cand)

        if len(matches) == 0:
            logger.warning(f"No match found for {name} with filters {filters} among {len(candidates)} candidates.")
            return None
        elif len(matches) > 1:
            # Ambiguous? Return last registered or first?
            # Original code seemed to rely on overwriting or specific key.
            # We return the last one (most recently registered)
            logger.debug(f"Multiple matches for {name}, returning last one.")
            return matches[-1]["obj"]
        else:
            return matches[0]["obj"]

    def get_metadata(self, name: str, **filters) -> dict[str, Any]:
        candidates = self._registry.get(name, [])
        # Same logic as get, but return meta
        if not candidates:
            return {}

        matches = []
        for cand in candidates:
            meta = cand["meta"]
            is_match = True
            for k, v in filters.items():
                if k not in meta or meta[k] != v:
                    is_match = False
                    break
            if is_match:
                matches.append(cand)

        if matches:
            return matches[-1]["meta"]
        return {}

    def list_available(self) -> list:
        return list(self._registry.keys())

get(name, **filters)

Get object by name, optionally filtering by metadata. Example: get("lorentzian", nfreqs=3)

Source code in `src/pythiabns/core/registry.py` (lines 30–63)
def get(self, name: str, **filters) -> Any | None:
    """
    Get object by name, optionally filtering by metadata.
    Example: get("lorentzian", nfreqs=3)
    """
    candidates = self._registry.get(name, [])
    if not candidates:
        return None

    matches = []
    for cand in candidates:
        meta = cand["meta"]
        # Check if all filters match metadata
        # If filter key not in meta, mismatch?
        # Or if meta has key and filter matches.
        is_match = True
        for k, v in filters.items():
            if k not in meta or meta[k] != v:
                is_match = False
                break
        if is_match:
            matches.append(cand)

    if len(matches) == 0:
        logger.warning(f"No match found for {name} with filters {filters} among {len(candidates)} candidates.")
        return None
    elif len(matches) > 1:
        # Ambiguous? Return last registered or first?
        # Original code seemed to rely on overwriting or specific key.
        # We return the last one (most recently registered)
        logger.debug(f"Multiple matches for {name}, returning last one.")
        return matches[-1]["obj"]
    else:
        return matches[0]["obj"]

register(name, **metadata)

Decorator to register a class or function with optional metadata.

Source code in `src/pythiabns/core/registry.py` (lines 14–28)
def register(self, name: str, **metadata):
    """
    Build a decorator that files an object under ``name`` with ``metadata``.

    Repeated registrations under one name are kept side by side as variants
    (nothing is overwritten). Registration happens when the decorator is
    applied, and the decorated object is returned untouched.
    """

    def decorator(obj):
        variants = self._registry.setdefault(name, [])
        variants.append({"obj": obj, "meta": metadata})
        logger.debug(f"Registered {name} with metadata {metadata}")
        return obj

    return decorator

generate_plots(result, config, outdir)

Generate plots based on configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | | `bilby.core.result.Result` object | required |
| `config` | | `PlottingConfig` object | required |
| `outdir` | `Path` | Path to output directory | required |
Source code in `src/pythiabns/core/plotting.py` (lines 7–38)
def generate_plots(result, config, outdir: Path):
    """
    Generate plots based on configuration.

    Each entry of ``config.plots`` is handled independently: a failure while
    producing one plot is logged with its traceback and does not stop the rest.

    Args:
        result: bilby.core.result.Result object
        config: PlottingConfig object; ``enabled``, ``plots`` and ``settings``
            (per-plot keyword arguments keyed by plot type) are read.
        outdir: Path to output directory (created if missing).
    """
    if not config.enabled:
        return

    outdir = Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)  # Should already exist but good check

    # Plot types this module can draw on its own, mapped to their plotter.
    plotters = {"corner": plot_corner, "trace": plot_trace}

    for plot_type in config.plots:
        try:
            if plot_type == "waveform":
                # Needs the waveform generator, which is not passed here
                # (possibly available via result metadata); handled elsewhere.
                logger.warning(
                    "Waveform plotting from this module requires access to the generator. Handled in spine.py?"
                )
            elif plot_type in plotters:
                plotters[plot_type](result, outdir, **config.settings.get(plot_type, {}))
            else:
                logger.warning(f"Unknown plot type: {plot_type}")
        except Exception as e:
            # Best effort per plot: log with traceback and continue.
            logger.exception(f"Failed to generate {plot_type} plot: {e}")

plot_waveform_posterior(result, waveform_generator, descriptors, outdir, n_samples=100)

Plot the waveform posterior against injection/data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | | Bilby `Result` | required |
| `waveform_generator` | | Initialized `WaveformGenerator` | required |
| `descriptors` | | Description of what to plot, e.g. `{'label': [list of ifos]}`; typically strain is plotted against time or frequency. Currently unused by the implementation. | required |
| `outdir` | | Output path | required |
Source code in `src/pythiabns/core/plotting.py` (lines 59–82)
def plot_waveform_posterior(result, waveform_generator, descriptors, outdir, n_samples=100):
    """
    Plot the waveform posterior against injection/data.

    NOTE: not implemented yet. The function selects the posterior draws it
    would plot and logs a warning; no plot file is written.

    Args:
        result: Bilby Result; ``result.posterior`` is read.
        waveform_generator: Initialized WaveformGenerator (not used yet).
        descriptors: Description of what to plot, e.g. ``{'label': [list of ifos]}``;
                     typically strain is plotted against time or frequency (not used yet).
        outdir: Output path (nothing is written to it yet).
        n_samples: Maximum number of posterior draws to plot.
    """
    logger.info("Generating waveform posterior plot...")

    posterior = result.posterior
    # DataFrame.sample returns a new frame rather than sampling in place, so
    # the subset must be bound to be of any use.
    if len(posterior) > n_samples:
        posterior = posterior.sample(n=n_samples)

    # TODO: compute waveforms for these samples, e.g. via
    # waveform_generator.time_domain_source_model(parameters). This is specific
    # to the domain (time/freq) and IFOs; a simple time-domain viewer would be
    # a first step.
    logger.warning(
        f"Waveform posterior plotting is not implemented yet; selected {len(posterior)} posterior "
        f"samples but wrote no plot to {outdir}."
    )

Models

EmpiricalRelation

Bases: ABC

Abstract base class for universal relations.

Source code in `src/pythiabns/models/relations.py` (lines 14–22)
class EmpiricalRelation(ABC):
    """Abstract base class for universal relations.

    Subclasses are instantiated without arguments by
    ``PriorFactory._apply_empirical_relations`` and must implement
    :meth:`predict`.
    """

    @abstractmethod
    def predict(self, m1: float, m2: float, eos_name: str) -> dict[str, float]:
        """
        Predict f_peak and potentially other properties.

        Args:
            m1: Mass of the first star (presumably in solar masses — confirm).
            m2: Mass of the second star (same units as ``m1``).
            eos_name: Name of the equation of state.

        Returns:
            Mapping of property name to predicted value. PriorFactory reads
            the ``"f_peak"`` key via ``dict.get``, so a missing key is tolerated.
        """
        ...

predict(m1, m2, eos_name) abstractmethod

Predict f_peak and potentially other properties.

Source code in `src/pythiabns/models/relations.py` (lines 17–22)
@abstractmethod
def predict(self, m1: float, m2: float, eos_name: str) -> dict[str, float]:
    """
    Predict f_peak and potentially other properties.

    Args:
        m1: Mass of the first star (presumably in solar masses — confirm).
        m2: Mass of the second star (same units as ``m1``).
        eos_name: Name of the equation of state.

    Returns:
        Mapping of property name to predicted value, e.g. ``"f_peak"``.
    """
    ...

WaveformModel

Bases: Protocol

Protocol that all waveform models must adhere to.

Source code in `src/pythiabns/models/interface.py` (lines 6–21)
@runtime_checkable
class WaveformModel(Protocol):
    """Protocol that all waveform models must adhere to.

    Structural interface: any callable with a compatible ``__call__``
    satisfies it without inheriting from it. Thanks to ``@runtime_checkable``,
    ``isinstance(obj, WaveformModel)`` works, but it only checks that
    ``__call__`` exists, not its signature or return value.
    """

    def __call__(self, frequency_array: np.ndarray, **params: float) -> dict[str, np.ndarray]:
        """
        Generate waveform polarizations.

        Args:
            frequency_array: Array of frequencies in Hz.
            **params: Model parameters (masses, spins, etc).

        Returns:
            Dict containing 'plus' and 'cross' keys with complex strain arrays
            (presumably one value per entry of frequency_array — confirm).
        """
        ...

__call__(frequency_array, **params)

Generate waveform polarizations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `frequency_array` | `ndarray` | Array of frequencies in Hz. | required |
| `**params` | `float` | Model parameters (masses, spins, etc). | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, ndarray]` | Dict containing 'plus' and 'cross' keys with complex strain arrays. |

Source code in `src/pythiabns/models/interface.py` (lines 10–21)
def __call__(self, frequency_array: np.ndarray, **params: float) -> dict[str, np.ndarray]:
    """
    Generate waveform polarizations.

    Protocol member: implementations provide the body; this stub only
    declares the expected signature.

    Args:
        frequency_array: Array of frequencies in Hz.
        **params: Model parameters (masses, spins, etc).

    Returns:
        Dict containing 'plus' and 'cross' keys with complex strain arrays
        (presumably one value per entry of frequency_array — confirm).
    """
    ...

Inference

PriorFactory

Factory to generate bilby PriorDicts based on configuration.

Source code in `src/pythiabns/inference/priors.py` (lines 12–77)
class PriorFactory:
    """Factory to generate bilby PriorDicts based on configuration."""

    # Spread of the f_peak prior after it is re-centred on an empirical-relation
    # prediction: half-width for Uniform priors, standard deviation for Gaussian
    # priors. Same units as the f_peak prior (presumably Hz — confirm).
    EMPIRICAL_FPEAK_HALF_WIDTH = 500
    EMPIRICAL_FPEAK_SIGMA = 200

    @staticmethod
    def create_priors(
        config: PriorConfig,
        model_name: str,
        metadata: dict[str, Any] | None = None,
        model_params: dict[str, Any] | None = None,
    ) -> bilby.core.prior.PriorDict:
        """Build the prior dictionary for ``model_name``.

        Loading order:
        1. If ``config.source`` is set, it is looked up under
           ``constants.PRIORS_PATH`` and then ``constants.PROJECT_ROOT``.
        2. If that produced no priors, ``<PRIORS_PATH>/<model_name>.priors`` is
           loaded when it exists.
        In ``"empirical"`` mode with ``metadata`` given, the ``f_peak`` prior is
        then re-centred using the relation named by ``config.source``.

        Args:
            config: Prior configuration (``source`` and ``mode`` are read).
            model_name: Registry name of the waveform model.
            metadata: Injection metadata read by the empirical relations.
            model_params: Registry filters selecting the model variant.

        Returns:
            The PriorDict, using the model's registered ``conversion_func`` (if
            any) as its conversion function.
        """
        # Fetch conversion function from registry
        reg_meta = registry.ModelRegistry.get_metadata(model_name, **(model_params or {}))
        conversion_func = reg_meta.get("conversion_func")

        # 1. Load base priors
        priors = bilby.core.prior.PriorDict(conversion_function=conversion_func)

        if config.source:
            for base_dir in (constants.PRIORS_PATH, constants.PROJECT_ROOT):
                candidate = base_dir / config.source
                if candidate.exists():
                    priors.from_file(str(candidate))
                    break
            else:
                # In "empirical" mode config.source names a relation, not a
                # file, so not finding a file is expected there.
                if config.mode != "empirical":
                    logger.warning(
                        f"Prior file {config.source} not found under {constants.PRIORS_PATH} or "
                        f"{constants.PROJECT_ROOT}; trying the default {model_name}.priors instead."
                    )

        default_file = constants.PRIORS_PATH / f"{model_name}.priors"
        if not priors and default_file.exists():
            priors.from_file(str(default_file))

        # 2. Apply Empirical Relations if requested
        if config.mode == "empirical" and metadata:
            PriorFactory._apply_empirical_relations(priors, config, metadata)

        return priors

    @staticmethod
    def _apply_empirical_relations(priors: bilby.core.prior.PriorDict, config: PriorConfig, metadata: dict[str, Any]):
        """Re-centre the ``f_peak`` prior (in place) on an empirical-relation prediction.

        The relation class is looked up in ``registry.RelationRegistry`` under
        ``config.source``. The masses come from ``metadata["id_mass_starA"]`` and
        ``metadata["id_mass_starB"]``, and the EOS name from
        ``metadata["id_eos"]`` (default ``"SLY"``). Only Uniform and Gaussian
        ``f_peak`` priors are replaced; their label, unit and boundary are kept.
        """
        method = config.source  # e.g. "VSB_R"
        relation_cls = registry.RelationRegistry.get(method)
        if not relation_cls:
            logger.warning(f"Relation {method} not found in registry.")
            return

        relation = relation_cls()
        eos = metadata.get("id_eos", "SLY")
        m1 = metadata.get("id_mass_starA")
        m2 = metadata.get("id_mass_starB")

        if m1 is None or m2 is None:
            logger.warning("Masses not found in metadata, skipping empirical priors.")
            return

        fpeak = relation.predict(m1, m2, eos).get("f_peak")
        if "f_peak" not in priors or not fpeak:
            return

        prior = priors["f_peak"]
        if not isinstance(prior, (bilby.core.prior.Uniform, bilby.core.prior.Gaussian)):
            return

        # Carry over labelling and boundary so plots and samplers (e.g. periodic
        # boundary handling) see the same metadata as the original prior.
        shared = {
            "name": "f_peak",
            "latex_label": prior.latex_label,
            "unit": prior.unit,
            "boundary": prior.boundary,
        }
        if isinstance(prior, bilby.core.prior.Uniform):
            half_width = PriorFactory.EMPIRICAL_FPEAK_HALF_WIDTH
            priors["f_peak"] = bilby.core.prior.Uniform(fpeak - half_width, fpeak + half_width, **shared)
        else:
            priors["f_peak"] = bilby.core.prior.Gaussian(
                mu=fpeak, sigma=PriorFactory.EMPIRICAL_FPEAK_SIGMA, **shared
            )

PostMergerLikelihood

Bases: GravitationalWaveTransient

Likelihood class for Post-Merger PE. Inherits from GravitationalWaveTransient to use standard detector response logic.

Source code in `src/pythiabns/inference/likelihood.py` (lines 4–20)
class PostMergerLikelihood(bilby.gw.likelihood.GravitationalWaveTransient):
    """
    Likelihood for post-merger parameter estimation.

    A thin subclass of bilby's GravitationalWaveTransient that reuses its
    standard detector-response logic unchanged. That is adequate as long as
    the waveform generator produces the right domain (FD/TD). The class exists
    as an extension point for later customisation: handling specific to
    third-generation detectors, noise models bilby lacks, or calibration
    marginalization.
    """

    def __init__(self, interferometers, waveform_generator):
        # Only the two required arguments are exposed; every other
        # GravitationalWaveTransient option keeps bilby's default.
        super().__init__(interferometers=interferometers, waveform_generator=waveform_generator)

Samplers

BilbyPocomcPrior

Wrapper to make bilby priors compatible with PocoMC Prior interface.

Source code in `src/pythiabns/inference/samplers/pocomc.py` (lines 14–68)
class BilbyPocomcPrior:
    """Adapter exposing a bilby ``PriorDict`` through the interface PocoMC expects.

    Attributes:
        priors: The wrapped bilby priors.
        keys: Sorted names of the sampled (stochastic) parameters.
        fixed_params: Values of ``DeltaFunction`` priors (their peak) and of
            plain numeric entries.
        dim: Number of sampled parameters.
        bounds: Array with one ``[minimum, maximum]`` row per sampled key,
            ``[-inf, inf]`` when the prior has no ``minimum``/``maximum``.
    ``Constraint`` priors are neither sampled nor fixed.
    """

    def __init__(self, bilby_priors):
        self.priors = bilby_priors
        self.keys = []
        self.fixed_params = {}

        for name in self.priors.keys():
            prior = self.priors[name]
            if isinstance(prior, bilby.core.prior.DeltaFunction):
                # A delta function pins the parameter to its peak.
                self.fixed_params[name] = prior.peak
            elif isinstance(prior, bilby.core.prior.Constraint):
                # Constraints are enforced by the PriorDict, never sampled.
                continue
            elif isinstance(prior, (float, int)):
                self.fixed_params[name] = prior
            else:
                self.keys.append(name)

        self.dim = len(self.keys)
        # A deterministic ordering shared by every array this class produces.
        self.keys.sort()

        def support(prior):
            if hasattr(prior, "minimum") and hasattr(prior, "maximum"):
                return [prior.minimum, prior.maximum]
            return [-np.inf, np.inf]

        self._bounds = [support(self.priors[name]) for name in self.keys]
        self.bounds = np.array(self._bounds)

    def logpdf(self, x):
        """Log prior density of each row of ``x`` (one point or a batch)."""
        batch = np.atleast_2d(x)
        log_p = np.zeros(batch.shape[0])
        for row_index, row in enumerate(batch):
            # Fixed values are merged in last (taking precedence) so that
            # constraints depending on them can be evaluated.
            point = {**dict(zip(self.keys, row)), **self.fixed_params}
            log_p[row_index] = self.priors.ln_prob(point)
        return log_p

    def rvs(self, size=1):
        """Draw ``size`` joint samples; one row per draw, columns in ``keys`` order."""
        draws = [self.priors.sample() for _ in range(size)]
        table = [[draw[name] for name in self.keys] for draw in draws]
        return np.array(table)

    def __call__(self, x):
        """Alias for :meth:`logpdf`."""
        return self.logpdf(x)

PocoMCWrapper

Wrapper for PocoMC sampler.

Source code in `src/pythiabns/inference/samplers/pocomc.py` (lines 71–159)
class PocoMCWrapper:
    """Wrapper for PocoMC sampler.

    PocoMC receives the priors (adapted by :class:`BilbyPocomcPrior`) and
    :meth:`log_likelihood` separately. Results are written to
    ``<outdir>/<label>_result.json`` (posterior table) and
    ``<outdir>/<label>_pocomc.pickle`` (the raw ``sampler.results``).
    """

    def __init__(
        self,
        likelihood: bilby.Likelihood,
        priors: bilby.core.prior.PriorDict,
        outdir: Path,
        label: str,
        settings: dict[str, Any] | None = None,
    ):
        self.likelihood = likelihood
        self.priors = priors
        self.outdir = Path(outdir)
        self.label = label
        self.settings = settings or {}

        self.outdir.mkdir(parents=True, exist_ok=True)
        self.wrapped_prior = BilbyPocomcPrior(self.priors)

        # Fixed parameters never change during sampling, so set them once.
        self.likelihood.parameters.update(self.wrapped_prior.fixed_params)

        # Indices (in wrapped_prior.keys order) of priors with a periodic boundary.
        self.periodicity = [
            i
            for i, key in enumerate(self.wrapped_prior.keys)
            if getattr(self.priors[key], "boundary", None) == "periodic"
        ]

    def log_likelihood(self, x):
        """Log-likelihood of one point (1-D ``x``) or of each row of a 2-D ``x``."""
        x = np.atleast_1d(x)
        if x.ndim > 1:
            return np.array([self._log_likelihood_single(xi) for xi in x])
        return self._log_likelihood_single(x)

    def _log_likelihood_single(self, x):
        """Evaluate the likelihood at one point given in ``wrapped_prior.keys`` order."""
        params = dict(zip(self.wrapped_prior.keys, x))
        # Only the stochastic keys are overwritten; the fixed ones set in
        # __init__ remain in likelihood.parameters.
        self.likelihood.parameters.update(params)
        return self.likelihood.log_likelihood()

    def run(self):
        """Run PocoMC and save the results.

        Settings read: ``npoints`` (default 1000; used for both ``n_effective``
        and ``n_active``) and ``n_cpus`` (default 1; worker-pool size).
        """
        # Imported locally, as in ZeusWrapper.run; harmless if the module also
        # imports it at top level.
        import multiprocess
        import pocomc as pc

        nwalkers = self.settings.get("npoints", 1000)

        # Using 'multiprocess' with dill support
        with multiprocess.Pool(self.settings.get("n_cpus", 1)) as pool:
            # NOTE(review): pocomc's defaults use n_active = n_effective // 2;
            # confirm that setting both to npoints is intended.
            sampler = pc.Sampler(
                prior=self.wrapped_prior,
                likelihood=self.log_likelihood,
                n_dim=self.wrapped_prior.dim,
                n_effective=nwalkers,
                n_active=nwalkers,
                periodic=self.periodicity,
                pool=pool,
                vectorize=False,
            )

            sampler.run(progress=True)

            self._save_results(sampler)

    @staticmethod
    def _extract_posterior(sampler, results):
        """Return ``(samples, log_prior, log_likelihood, weights)`` from a finished sampler.

        Each array is paired with the samples it belongs to. ``weights`` is None
        for unweighted samples; every entry is None if no samples are found.
        """
        # pocomc < 1.0 stores the final posterior particles in the results dict.
        if results.get("posterior_samples") is not None:
            return (
                results["posterior_samples"],
                results.get("posterior_logp"),
                results.get("posterior_logl"),
                None,
            )
        # pocomc >= 1.0 returns importance-weighted samples from Sampler.posterior().
        if callable(getattr(sampler, "posterior", None)):
            samples, weights, logl, logp = sampler.posterior()
            return samples, logp, logl, weights
        # Last resort: a generic "samples" entry with its own log-prob keys.
        if results.get("samples") is not None:
            return results["samples"], results.get("log_prior"), results.get("log_likelihood"), None
        return None, None, None, None

    def _save_results(self, sampler):
        """Write the posterior table (if found) as JSON and pickle the raw results."""
        import warnings

        results = sampler.results
        samples, log_prior, log_likelihood, weights = self._extract_posterior(sampler, results)

        if samples is None:
            warnings.warn(
                f"No posterior samples found in PocoMC results; {self.label}_result.json was not written.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            df = pd.DataFrame(samples, columns=self.wrapped_prior.keys)
            df["log_prior"] = log_prior
            df["log_likelihood"] = log_likelihood
            if weights is not None:
                # Weighted draws are meaningless without their weights.
                df["weights"] = weights
            df.to_json(self.outdir / f"{self.label}_result.json")

        with open(self.outdir / f"{self.label}_pocomc.pickle", "wb") as f:
            dill.dump(results, f)

ZeusWrapper

Wrapper for zeus sampler.

Source code in `src/pythiabns/inference/samplers/zeus.py` (lines 14–71)
class ZeusWrapper:
    """Wrapper for zeus sampler.

    Runs zeus' ensemble slice sampler on the log posterior:
    ``BilbyPocomcPrior.logpdf`` plus the bilby log-likelihood. This is the
    same ``lp + ll`` target the NumPyro and BlackJAX wrappers use. Results go
    to ``<outdir>/<label>_result.json`` and ``<outdir>/<label>_zeus.pickle``.
    """

    def __init__(
        self,
        likelihood: bilby.Likelihood,
        priors: bilby.core.prior.PriorDict,
        outdir: Path,
        label: str,
        settings: dict[str, Any] | None = None,
    ):
        self.likelihood = likelihood
        self.priors = priors
        self.outdir = Path(outdir)
        self.label = label
        self.settings = settings or {}

        self.outdir.mkdir(parents=True, exist_ok=True)
        self.wrapped_prior = BilbyPocomcPrior(self.priors)  # Zeus also needs prior samples/logpdf

        # Fixed parameters never change during sampling, so set them once.
        self.likelihood.parameters.update(self.wrapped_prior.fixed_params)

    def log_likelihood(self, x):
        """Bilby log-likelihood at ``x`` (values in ``wrapped_prior.keys`` order)."""
        params = dict(zip(self.wrapped_prior.keys, x))
        self.likelihood.parameters.update(params)
        return self.likelihood.log_likelihood()

    def log_posterior(self, x):
        """Unnormalized log posterior at ``x``; ``-inf`` where the prior is not finite."""
        log_prior = self.wrapped_prior.logpdf(x)[0]
        if not np.isfinite(log_prior):
            # Outside the prior support: skip the likelihood evaluation.
            return -np.inf
        return log_prior + self.log_likelihood(x)

    def run(self):
        """Run zeus and save the results.

        Settings read: ``nwalkers`` (default ``2 * dim + 2``), ``nsteps``
        (default 1000) and ``n_cpus`` (default 1; worker-pool size).
        """
        import multiprocess
        import zeus

        n_dim = self.wrapped_prior.dim
        nwalkers = self.settings.get("nwalkers", 2 * n_dim + 2)
        nsteps = self.settings.get("nsteps", 1000)
        n_cpus = self.settings.get("n_cpus", 1)

        # Initial positions drawn from the prior (one row per walker).
        start_pos = self.wrapped_prior.rvs(nwalkers)

        with multiprocess.Pool(n_cpus) as pool:
            # zeus signature: EnsembleSampler(nwalkers, ndim, logprob_fn, ..., pool=None)
            sampler = zeus.EnsembleSampler(nwalkers, n_dim, self.log_posterior, pool=pool)

            sampler.run_mcmc(start_pos, nsteps)

            self._save_results(sampler)

    def _save_results(self, sampler):
        """Write the flattened chain with its log-prior/likelihood/posterior columns."""
        samples = sampler.get_chain(flat=True)
        log_post = np.asarray(sampler.get_log_prob(flat=True))
        # zeus records only the total log probability; recover the split by
        # re-evaluating the prior on the stored samples.
        log_prior = self.wrapped_prior.logpdf(samples)

        df = pd.DataFrame(samples, columns=self.wrapped_prior.keys)
        df["log_prior"] = log_prior
        df["log_likelihood"] = log_post - log_prior
        df["log_posterior"] = log_post

        df.to_json(self.outdir / f"{self.label}_result.json")

        # NOTE(review): the sampler holds the worker pool passed in run();
        # multiprocessing-style pools normally refuse to be pickled — confirm
        # this dump succeeds.
        with open(self.outdir / f"{self.label}_zeus.pickle", "wb") as f:
            dill.dump(sampler, f)

NumPyroWrapper

Wrapper for NumPyro sampler.

Source code in `src/pythiabns/inference/samplers/numpyro_sampler.py` (lines 15–108)
class NumPyroWrapper:
    """Wrapper for NumPyro sampler.

    Uses NumPyro's gradient-free SA (Sample Adaptive MCMC) kernel. The bilby
    prior and likelihood are not JAX-traceable, so they are evaluated on the
    host through ``jax.pure_callback`` inside a ``potential_fn`` (the negative
    log posterior).
    """

    def __init__(
        self,
        likelihood: bilby.Likelihood,
        priors: bilby.core.prior.PriorDict,
        outdir: Path,
        label: str,
        settings: dict[str, Any] | None = None,
    ):
        self.likelihood = likelihood
        self.priors = priors
        self.outdir = Path(outdir)
        self.label = label
        self.settings = settings or {}

        self.outdir.mkdir(parents=True, exist_ok=True)
        self.wrapped_prior = BilbyPocomcPrior(self.priors)

        # Fixed parameters never change during sampling, so set them once.
        self.likelihood.parameters.update(self.wrapped_prior.fixed_params)

    def log_likelihood(self, x):
        """Bilby log-likelihood at ``x`` (values in ``wrapped_prior.keys`` order)."""
        params = dict(zip(self.wrapped_prior.keys, x))
        self.likelihood.parameters.update(params)
        return self.likelihood.log_likelihood()

    def run(self):
        """Run SA MCMC and save the samples.

        Settings read: ``n_samples`` (default 1000), ``n_warmup`` (default
        ``n_samples // 2``), ``n_chains`` (default 1) and ``seed`` (PRNG seed,
        default 0).
        """
        import jax
        import jax.numpy as jnp
        import numpyro
        from numpyro.infer import MCMC

        n_samples = self.settings.get("n_samples", 1000)
        n_warmup = self.settings.get("n_warmup", n_samples // 2)
        n_chains = self.settings.get("n_chains", 1)
        seed = self.settings.get("seed", 0)
        keys = self.wrapped_prior.keys

        # pure_callback requires the callback's result to match the declared
        # shape/dtype. Declare JAX's default float dtype (float32 unless x64 is
        # enabled) and cast the host result to it.
        result_spec = jax.ShapeDtypeStruct((), jnp.asarray(0.0).dtype)

        def host_potential(x_np):
            # Runs on the host with concrete NumPy values.
            x_np = np.asarray(x_np)
            lp = self.wrapped_prior.logpdf(np.atleast_2d(x_np))[0]
            ll = self.log_likelihood(x_np)
            return np.asarray(-(lp + ll), dtype=result_spec.dtype)

        def potential_fn(params):
            # Assemble the flat parameter vector in wrapped_prior.keys order.
            x = jnp.array([params[k] for k in keys])
            return jax.pure_callback(host_potential, result_spec, x)

        logger.info(f"NumPyro sampling started with {n_samples} samples and {n_chains} chains.")

        rng_key = jax.random.PRNGKey(seed)
        draws = self.wrapped_prior.rvs(n_chains)  # one row per chain, columns in keys order
        if n_chains == 1:
            # A single chain takes unbatched init values; shape-(1,) arrays would
            # turn every parameter into a length-1 vector.
            init_dict = {k: jnp.asarray(draws[0, i]) for i, k in enumerate(keys)}
        else:
            # Multiple chains: leading dimension is the chain index.
            init_dict = {k: jnp.asarray(draws[:, i]) for i, k in enumerate(keys)}

        # SA needs no gradients, which pure_callback cannot provide.
        kernel = numpyro.infer.SA(potential_fn=potential_fn)
        mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=n_samples, num_chains=n_chains)
        mcmc.run(rng_key, init_params=init_dict)

        self._save_results(mcmc)

    def _save_results(self, mcmc):
        samples = mcmc.get_samples()
        # Convert to pandas. Handle potential chain/sample dimension overlap
        data = {}
        for k, v in samples.items():
            v_np = np.array(v)
            if v_np.ndim > 1:
                # Flatten if it's (num_samples, 1) or similar, else keep but pandas might complain
                if v_np.shape[1] == 1:
                    data[k] = v_np.flatten()
                else:
                    # Multi-dim param, might need special handling but for now just take it
                    data[k] = list(v_np)
            else:
                data[k] = v_np

        df = pd.DataFrame(data)

        df.to_json(self.outdir / f"{self.label}_result.json")

        with open(self.outdir / f"{self.label}_numpyro.pickle", "wb") as f:
            dill.dump(samples, f)

BlackJAXWrapper

Wrapper for BlackJAX sampler.

Source code in `src/pythiabns/inference/samplers/blackjax_sampler.py` (lines 15–90)
class BlackJAXWrapper:
    """Run a BlackJAX random-walk Metropolis-Hastings chain on a bilby likelihood and prior.

    The bilby likelihood and the wrapped prior are evaluated on the host through
    ``jax.pure_callback`` (they are not JAX-traceable), so only a gradient-free
    kernel is used.

    Recognised ``settings`` keys:
        n_samples: number of random-walk steps recorded (default 1000).
        step_size: standard deviation of the Gaussian random-walk proposal (default 0.1).
        seed: integer seed for ``jax.random.PRNGKey`` (default 0).
    """

    def __init__(
        self,
        likelihood: bilby.Likelihood,
        priors: bilby.core.prior.PriorDict,
        outdir: Path,
        label: str,
        settings: dict[str, Any] | None = None,
    ):
        self.likelihood = likelihood
        self.priors = priors
        self.outdir = Path(outdir)
        self.label = label
        self.settings = settings or {}

        self.outdir.mkdir(parents=True, exist_ok=True)
        self.wrapped_prior = BilbyPocomcPrior(self.priors)

        # Values exposed as ``fixed_params`` by the wrapped prior are set on the
        # likelihood once; only ``wrapped_prior.keys`` change per evaluation.
        self.likelihood.parameters.update(self.wrapped_prior.fixed_params)

    def log_likelihood(self, x):
        """Evaluate the bilby log-likelihood at the parameter vector ``x``.

        ``x`` is matched positionally to ``self.wrapped_prior.keys``; the likelihood's
        parameter dict is updated in place before evaluation.
        """
        params = dict(zip(self.wrapped_prior.keys, x))
        self.likelihood.parameters.update(params)
        return self.likelihood.log_likelihood()

    def run(self):
        """Sample with random-walk Metropolis-Hastings and save the chain via ``_save_results``."""
        import blackjax
        import jax
        import jax.numpy as jnp

        n_samples = self.settings.get("n_samples", 1000)
        step_size = self.settings.get("step_size", 0.1)
        seed = self.settings.get("seed", 0)

        logger.info(f"BlackJAX sampling initialized with RMH, {n_samples} samples.")

        # Starting point: a single draw from the prior.
        rng_key = jax.random.PRNGKey(seed)
        init_params = jnp.array(self.wrapped_prior.rvs(1)[0])

        # pure_callback requires the host function to return exactly the declared
        # shape and dtype. Declare a scalar of the position dtype (float32 unless
        # jax_enable_x64 is set) and cast the host result to it.
        out_dtype = init_params.dtype
        out_spec = jax.ShapeDtypeStruct((), out_dtype)

        def host_log_density(x_np):
            # Runs outside JAX tracing with a concrete NumPy array.
            lp = self.wrapped_prior.logpdf(np.atleast_2d(x_np))[0]
            if np.isneginf(lp):
                # Zero prior density: the log-posterior is -inf, so skip the
                # (possibly expensive or failing) likelihood evaluation.
                return np.asarray(-np.inf, dtype=out_dtype)
            ll = self.log_likelihood(x_np)
            return np.asarray(lp + ll, dtype=out_dtype)

        def log_density(x):
            return jax.pure_callback(host_log_density, out_spec, x)

        # Gradient-free kernel: additive Gaussian random-walk proposals.
        def random_step(key, x):
            return x + jax.random.normal(key, x.shape) * step_size

        rw = blackjax.additive_step_random_walk(log_density, random_step)
        state = rw.init(init_params)

        def step(state, key):
            state, _ = rw.step(key, state)
            return state, state

        keys = jax.random.split(rng_key, n_samples)
        _, states = jax.lax.scan(step, state, keys)

        self._save_results(states.position)

    def _save_results(self, samples):
        """Write the chain to ``<label>_result.json`` (one column per sampled key) and dill-pickle the raw array."""
        df = pd.DataFrame(np.array(samples), columns=self.wrapped_prior.keys)
        df.to_json(self.outdir / f"{self.label}_result.json")

        with open(self.outdir / f"{self.label}_blackjax.pickle", "wb") as f:
            dill.dump(samples, f)

TempestWrapper

Wrapper for tempest sampler.

Source code in src/pythiabns/inference/samplers/tempest.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class TempestWrapper:
    """Adapter that runs the ``tempest`` sampler on a bilby likelihood/prior pair.

    Settings consumed here: ``n_samples`` (forwarded to ``Sampler.run`` as
    ``n_total``, default 1000) and ``n_cpus`` (size of the ``multiprocess`` pool,
    default 1). Every other key in ``settings`` is passed unchanged to
    ``tempest.Sampler``.
    """

    def __init__(
        self,
        likelihood: bilby.Likelihood,
        priors: bilby.core.prior.PriorDict,
        outdir: Path,
        label: str,
        settings: dict[str, Any] | None = None,
    ):
        self.likelihood = likelihood
        self.priors = priors
        self.outdir = Path(outdir)
        self.label = label
        self.settings = settings if settings else {}

        self.outdir.mkdir(parents=True, exist_ok=True)
        self.wrapped_prior = BilbyPocomcPrior(self.priors)

        # Values the wrapped prior reports as fixed are written to the likelihood once.
        fixed = self.wrapped_prior.fixed_params
        self.likelihood.parameters.update(fixed)

    def log_likelihood(self, x):
        """Return the bilby log-likelihood at ``x`` (ordered like ``wrapped_prior.keys``)."""
        self.likelihood.parameters.update(dict(zip(self.wrapped_prior.keys, x)))
        return self.likelihood.log_likelihood()

    def run(self):
        """Run tempest inside a process pool and hand the sampler to ``_save_results``."""
        import multiprocess
        import tempest as tp

        settings = self.settings
        n_samples = settings.get("n_samples", 1000)
        n_cpus = settings.get("n_cpus", 1)

        # Run/pool controls are consumed here; the rest configures tempest.Sampler.
        reserved = ("n_samples", "n_cpus")
        sampler_kwargs = {key: value for key, value in settings.items() if key not in reserved}

        def prior_transform(u):
            # Map unit-cube coordinates to physical values via bilby's rescale.
            return self.priors.rescale(self.wrapped_prior.keys, u)

        logger.info(f"Tempest sampling started with {n_samples} samples.")

        with multiprocess.Pool(n_cpus) as pool:
            sampler = tp.Sampler(
                prior_transform=prior_transform,
                log_likelihood=self.log_likelihood,
                n_dim=self.wrapped_prior.dim,
                pool=pool,
                **sampler_kwargs,
            )
            sampler.run(n_total=n_samples)
            self._save_results(sampler)

    def _save_results(self, sampler):
        """Write posterior samples (if any) to JSON and dill-pickle the full results dict."""
        results = sampler.results()
        samples = results.get("samples")

        if samples is not None:
            frame = pd.DataFrame(samples, columns=self.wrapped_prior.keys)
            for column in ("log_likelihood", "log_prior"):
                frame[column] = results.get(column)
            frame.to_json(self.outdir / f"{self.label}_result.json")

        pickle_path = self.outdir / f"{self.label}_tempest.pickle"
        with open(pickle_path, "wb") as handle:
            dill.dump(results, handle)

Detectors

DetectorNetwork

Wrapper around bilby InterferometerList.

Source code in src/pythiabns/detectors/network.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class DetectorNetwork:
    """Wrapper around bilby InterferometerList.

    Keeps the detector names together with the data-segment settings (sampling
    frequency, duration, start time) used when filling the detectors with strain.
    """

    def __init__(
        self,
        ifo_names: list[str] | None = None,
        sampling_frequency: float = 4096,
        duration: float = 1.0,
        start_time: float = 0.0,
        minimum_frequency: float = 20,
    ):
        """
        Args:
            ifo_names: Detector names passed to ``bilby.gw.detector.InterferometerList``;
                ``None`` selects ``["H1", "L1", "V1"]``.
            sampling_frequency: Sampling rate of the generated strain data.
            duration: Segment length.
            start_time: Segment start time.
            minimum_frequency: Lower frequency cut applied to every detector.
        """
        # Build the default list per instance: a list literal as the default value
        # would be one object shared (and mutable) across every instance.
        self.ifo_names = ["H1", "L1", "V1"] if ifo_names is None else ifo_names
        self.sampling_frequency = sampling_frequency
        self.duration = duration
        self.start_time = start_time
        self.minimum_frequency = minimum_frequency

        self.ifos = bilby.gw.detector.InterferometerList(self.ifo_names)
        self._configure_detectors()

    def _configure_detectors(self):
        """Set each detector's band to [minimum_frequency, sampling_frequency / 2]."""
        for ifo in self.ifos:
            ifo.minimum_frequency = self.minimum_frequency
            ifo.maximum_frequency = self.sampling_frequency / 2.0

    def set_data(self, noise: bool = False):
        """Fill every detector with strain data for the configured segment.

        Args:
            noise: If True, use ``set_strain_data_from_power_spectral_densities``
                (Gaussian noise coloured by each detector's PSD); otherwise use
                ``set_strain_data_from_zero_noise``.

        Raises:
            Exception: any error from the PSD-based generation is logged and
                re-raised unchanged; there is no fallback to zero noise.
        """
        if not noise:
            self.ifos.set_strain_data_from_zero_noise(
                sampling_frequency=self.sampling_frequency, duration=self.duration, start_time=self.start_time
            )
            return

        try:
            self.ifos.set_strain_data_from_power_spectral_densities(
                sampling_frequency=self.sampling_frequency, duration=self.duration, start_time=self.start_time
            )
        except Exception as e:
            logger.warning(f"Failed to set noise from PSD: {e}")
            raise

    def inject_signal(self, waveform_generator: bilby.gw.waveform_generator.WaveformGenerator, parameters: dict):
        """Inject signal into detectors.

        Delegates to ``InterferometerList.inject_signal`` with ``raise_error=False``.
        """
        self.ifos.inject_signal(waveform_generator=waveform_generator, parameters=parameters, raise_error=False)

    @property
    def meta_data(self):
        """Map each detector name to its bilby ``meta_data`` dict.

        NOTE: bilby usually records injection quantities such as the optimal SNR
        there after ``inject_signal`` - confirm for the bilby version in use.
        """
        return {ifo.name: ifo.meta_data for ifo in self.ifos}

inject_signal(waveform_generator, parameters)

Inject signal into detectors.

Source code in src/pythiabns/detectors/network.py
47
48
49
def inject_signal(self, waveform_generator: bilby.gw.waveform_generator.WaveformGenerator, parameters: dict):
    """Inject the signal from ``waveform_generator`` at ``parameters`` into every detector.

    Thin delegation to ``InterferometerList.inject_signal`` with ``raise_error=False``.
    """
    network = self.ifos
    network.inject_signal(
        parameters=parameters,
        waveform_generator=waveform_generator,
        raise_error=False,
    )

Data Utils

NumericalWaveform

Class to handle loading and processing of NR waveforms.

Source code in src/pythiabns/data_utils/nr.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class NumericalWaveform:
    """Class to handle loading and processing of NR waveforms."""

    def __init__(self, filename: str, sampling_frequency: float | None = None):
        self.filename = filename
        self.sampling_frequency = sampling_frequency

        # Determine path
        # Logic ported from NR_strains.py:
        # If 'Soultanis' in name, special handling. Else scan STRAIN_PATH.

        if os.path.isabs(filename) or os.path.exists(filename):
            self._load_from_path(Path(filename))
        elif filename.startswith("Soultanis"):
            self._load_soultanis(filename)
        else:
            self._load_standard(filename)

        # Common initialization
        self.Mtot = (self.m1 + self.m2) * constants.MSUN_SI

        # Convert to SI
        self._time_to_SI()
        self.hp, self.hc = self._set_to_1Mpc()

        # Resample if needed
        # In original, resampling was mandatory?
        # self.resample() # TODO: Verify if mandatory

    def _load_from_path(self, path: Path):
        self.datapath = path
        # Check if it's an NR directory
        if (path / "metadata.txt").exists():
            self.metadata_dict = self._load_metadata(self.datapath)
            self.m1 = float(self.metadata_dict.get("id_mass_starA", 1.4))
            self.m2 = float(self.metadata_dict.get("id_mass_starB", 1.4))
            self.rh_overmtot_p, self.rh_overmtot_c, self.time, self.extraction_radius = self._read_hdf5_data()
        else:
            # Simple file loading (txt/csv)
            self._load_simple_file(path)

    def _load_standard(self, filename: str):
        self.datapath = constants.STRAIN_PATH / filename
        self._load_from_path(self.datapath)

    def _load_simple_file(self, path: Path):
        # Assume 3 columns: time, hp, hc
        data = np.loadtxt(path)
        self.time = data[:, 0]
        self.rh_overmtot_p = data[:, 1]
        self.rh_overmtot_c = data[:, 2]
        self.metadata_dict = {"id_name": path.name}
        self.m1 = 1.4  # Defaults
        self.m2 = 1.4
        self.extraction_radius = 0

        # Flag to indicate its already SI and scaled to 1Mpc?
        # For simplicity, if loading a custom file, we assume it's h+ and hx at 1Mpc and in SI.
        # So we skip _time_to_SI and _set_to_1Mpc logic by setting special values.
        self._is_si = True

    def _load_soultanis(self, filename: str):
        # Soultanis/1.55
        mass = float(filename.split("/")[-1])
        base_dir = constants.STRAIN_PATH / filename.split("/")[0]

        # Find matching file?
        # Original: [i for i in os.listdir(...) if mass in i][0]
        # Assuming filename structure matches
        try:
            # This is a bit fragile but ports existing logic
            candidates = list(base_dir.glob(f"*{mass}*"))
            if not candidates:
                raise FileNotFoundError(f"No file found for {filename}")
            self.datapath = candidates[0]
        except Exception as e:
            raise FileNotFoundError(f"Error finding Soultanis file: {e}")

        self.metadata_dict = {"id_mass_starA": mass, "id_mass_starB": mass, "id_eos": "MPA1", "id_name": filename}
        self.m1 = mass
        self.m2 = mass
        self.extraction_radius = 0  # Not applicable/available?

        data = np.loadtxt(self.datapath)
        self.time = data.T[0] / 1000.0  # ms to s
        # Original scaling: data.T[1]/8.35898e+20*40
        # What is 8.358...? Likely unit conversion.
        # Original comment: #@ 1Mpc
        # Note: In _set_to_1Mpc we might re-scale.
        # But Soultanis load seems to return hp already scaled?
        # Original NumericalData inits rh_overmtot for standard, but hp for Soultanis.

        # Here I will populate rh_overmtot_p/c assuming they are NOT scaled to Mtot yet?
        # Actually Soultanis loader in original SETS hp directly.
        # So I should handle that difference.

        self.rh_overmtot_p = data.T[1] / 8.35898e20 * 40
        self.rh_overmtot_c = data.T[2] / 8.35898e20 * 40

        # Hack: set a flag to skip SI conversion if already in SI?
        # Original: time_msun...
        # Standard: load_NR_strains -> rh... then time_to_SI -> set_to_1Mpc
        # Soultanis: loads already converted time?
        # Original: time = data.T[0]/1000 (ms to s). So it IS in SI (seconds).
        # Standard loads geometric time?

    def _read_hdf5_data(self):
        h5_path = self.datapath / "data.h5"
        with h5py.File(h5_path, "r") as f:
            # List l=2, m=2 modes
            names = [x for x in f["/rh_22"] if "l2_m2" in x]
            # Select extraction at largest radius
            # Original logic: last one, check for Inf
            names.sort()  # Ensure order?
            # Original used list(f[]) which is unordered in some h5py versions?
            # Assuming sorted by string works for radii r100, r200 etc.
            selection = names[-1]
            if "Inf" in selection and len(names) > 1:
                selection = names[-2]

            dset = f[f"/rh_22/{selection}"]
            data = pd.DataFrame(dset[:])  # Read all

            time = data.iloc[:, 0].values
            rh_p = data.iloc[:, 1].values
            rh_c = data.iloc[:, 2].values

            # Extract radius from name "l2_m2_r400.txt" or similar
            # Original: float(selection.split(".")[0].split("r")[1])
            try:
                extraction_radius = float(selection.split("r")[-1].split(".")[0])
            except (ValueError, IndexError):
                extraction_radius = 0.0

            return rh_p, rh_c, time, extraction_radius

    def _load_metadata(self, path: Path) -> dict[str, Any]:
        meta_file = path / "metadata.txt"
        meta: dict[str, Any] = {}
        if not meta_file.exists():
            return meta

        with open(meta_file) as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                if "Evolution" in parts:
                    break  # Stop reading
                if "id_" in parts[0]:
                    key = parts[0]
                    raw_val = parts[-1]
                    val: Any = None
                    try:
                        val = float(raw_val)
                    except ValueError:
                        val = raw_val
                    meta[key] = val
        return meta

    def _time_to_SI(self):
        if hasattr(self, "_is_si") and self._is_si:
            return
        if self.filename.startswith("Soultanis"):
            return  # Already SI
        # Convert geometric time to seconds
        # time_SI = time_geom * G * M / c^3
        factor = constants.G_SI * self.Mtot / (constants.C_SI**3)
        self.time = self.time * factor

    def _set_to_1Mpc(self):
        if hasattr(self, "_is_si") and self._is_si:
            return self.rh_overmtot_p, self.rh_overmtot_c
        if self.filename.startswith("Soultanis"):
            # Already scaled? Original code set self.hp directly.
            # In my class I stored it in rh_overmtot for consistency of storage.
            # So just returns them.
            return self.rh_overmtot_p, self.rh_overmtot_c

        # Standard scaling
        # hp = (rh/M) * (G*M/c^2) * (1/dist)
        # rh_overmtot is actually r*h / Mtot ???
        # Original: rh_overmtot_p * mtot_geom / 1Mpc
        # mtot_geom = G * M / c^2 (Length)

        mtot_geom = constants.G_SI * self.Mtot / (constants.C_SI**2)
        one_mpc = 1e6 * 3.085677581e16  # Parsec to meters

        hp = self.rh_overmtot_p * mtot_geom / one_mpc
        hc = self.rh_overmtot_c * mtot_geom / one_mpc
        return hp, hc

    def get_post_merger(self, inplace=True):
        """Crop to post-merger signal."""
        # Find merger time (max amplitude)
        amp = np.sqrt(self.hp**2 + self.hc**2)
        idx = np.argmax(amp)

        if inplace:
            self.time = self.time[idx:]
            self.hp = self.hp[idx:]
            self.hc = self.hc[idx:]
            # Ensure odd/even length consistency? Original had some check
        else:
            return self.time[idx:], self.hp[idx:], self.hc[idx:]

    def resample(self, new_fs=None):
        if new_fs is None:
            if self.sampling_frequency:
                new_fs = self.sampling_frequency
            else:
                new_fs = 8192

        dt = 1.0 / new_fs
        new_time = np.arange(self.time[0], self.time[-1], dt)
        self.hp = processing.interpolate(self.time, self.hp, new_time)
        self.hc = processing.interpolate(self.time, self.hc, new_time)
        self.time = new_time

get_post_merger(inplace=True)

Crop to post-merger signal.

Source code in src/pythiabns/data_utils/nr.py
205
206
207
208
209
210
211
212
213
214
215
216
217
def get_post_merger(self, inplace=True):
    """Cut the waveform so it starts at the sample of peak amplitude (taken as merger).

    With ``inplace=True`` (default) ``time``, ``hp`` and ``hc`` are sliced on the
    instance and ``None`` is returned; with ``inplace=False`` the sliced
    ``(time, hp, hc)`` are returned and the instance is left untouched.
    """
    # sqrt is strictly increasing, so the peak of hp**2 + hc**2 is the amplitude peak.
    power = self.hp**2 + self.hc**2
    start = np.argmax(power)

    if not inplace:
        return self.time[start:], self.hp[start:], self.hc[start:]

    self.time = self.time[start:]
    self.hp = self.hp[start:]
    self.hc = self.hc[start:]