# Source code for viscube.grid_cube

import warnings
from typing import Callable, Tuple, Sequence, Optional, Union
import numpy as np
from scipy.spatial import cKDTree
import inspect
from tqdm import tqdm
from functools import wraps

# Use your existing implementations
from .gridder import bin_data
from .windows import (
    kaiser_bessel_window,
    casa_pswf_window,
    pillbox_window,
    sinc_window,
)

# -----------------------
# Low-level utilities
# -----------------------


def load_and_mask(
    frequencies: np.ndarray,
    uu: np.ndarray,
    vv: np.ndarray,
    vis: np.ndarray,
    weight: np.ndarray,
    sigma_re: np.ndarray,
    sigma_im: np.ndarray,
    mask: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Apply a per-channel mask and compact the per-channel arrays.

    Parameters
    ----------
    frequencies : ndarray, shape (F,)
        Channel frequencies; returned unchanged. Its length sets the number
        of channels processed.
    uu, vv, vis, weight, sigma_re, sigma_im : ndarray
        Per-channel data; ``x[i]`` is selected with the channel mask ``mask[i]``.
    mask : ndarray
        Per-channel selection mask. Each row is interpreted as booleans
        (nonzero = keep).

    Returns
    -------
    frequencies, u0, v0, vis0, w0, sigma_re0, sigma_im0
        All but ``frequencies`` have shape (F, Nmasked); ``vis0`` is
        complex128, the others float64. An empty cube (F == 0) yields
        (0, 0) arrays.

    Raises
    ------
    ValueError
        If the number of selected entries differs between channels. The
        compact (F, Nmasked) layout needs the same count in every channel;
        otherwise a ragged representation is required.
    """
    F = len(frequencies)

    # Coerce each channel's mask to bool up front. An integer 0/1 row would
    # otherwise be used as *integer* (fancy) indices and pick the wrong samples.
    channel_masks = [np.asarray(mask[i], dtype=bool) for i in range(F)]

    # The compact layout takes its width from channel 0 (0 for an empty cube).
    Nmasked = int(channel_masks[0].sum()) if F > 0 else 0

    u0 = np.zeros((F, Nmasked), dtype=np.float64)
    v0 = np.zeros((F, Nmasked), dtype=np.float64)
    vis0 = np.zeros((F, Nmasked), dtype=np.complex128)
    w0 = np.zeros((F, Nmasked), dtype=np.float64)
    s_re0 = np.zeros((F, Nmasked), dtype=np.float64)
    s_im0 = np.zeros((F, Nmasked), dtype=np.float64)

    for i, mi in enumerate(channel_masks):
        n_valid = int(mi.sum())
        if n_valid != Nmasked:
            raise ValueError(
                f"mask has variable valid count across channels; channel {i} has {n_valid}, "
                f"expected {Nmasked}. Use a ragged representation instead."
            )
        u0[i] = uu[i][mi]
        v0[i] = vv[i][mi]
        vis0[i] = vis[i][mi]
        w0[i] = weight[i][mi]
        s_re0[i] = sigma_re[i][mi]
        s_im0[i] = sigma_im[i][mi]

    return frequencies, u0, v0, vis0, w0, s_re0, s_im0
def hermitian_augment(
    u0: np.ndarray,
    v0: np.ndarray,
    vis0: np.ndarray,
    w0: np.ndarray,
    sigma_re0: np.ndarray,
    sigma_im0: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Append a Hermitian-conjugate copy of every visibility sample.

    Each sample (u, v, Re, Im) gains a partner (-u, -v, Re, -Im). Weights
    and per-component sigmas are duplicated unchanged: negating the
    imaginary part does not alter its variance. Copies are appended along
    axis 1, so every output is twice as long as its input on that axis.

    Returns
    -------
    uu, vv, vis_re, vis_imag, w, sigma_re_aug, sigma_im_aug
    """

    def _with_mirror(original, mirrored):
        # Original samples first, their Hermitian partners after them.
        return np.concatenate((original, mirrored), axis=1)

    real_part, imag_part = vis0.real, vis0.imag
    return (
        _with_mirror(u0, -u0),
        _with_mirror(v0, -v0),
        _with_mirror(real_part, real_part),
        _with_mirror(imag_part, -imag_part),
        _with_mirror(w0, w0),
        _with_mirror(sigma_re0, sigma_re0),
        _with_mirror(sigma_im0, sigma_im0),
    )
# Arcseconds per radian: (180 / pi) degrees per radian times 3600 arcsec per
# degree, about 206264.806.
_ARCSEC_PER_RAD = (180 * 3600) / np.pi
def make_uv_grid(
    uu: np.ndarray,
    vv: np.ndarray,
    npix: int,
    pad_uv: float,
    *,
    fov_arcsec: Optional[float] = None,
    warn_crop: bool = True,
) -> Tuple[np.ndarray, np.ndarray, float, float]:
    """
    Build a symmetric, square uv grid of ``npix`` x ``npix`` cells.

    Parameters
    ----------
    uu, vv : ndarray
        u and v coordinates (assumed to be in wavelengths).
    npix : int
        Number of cells along each uv axis; must be positive.
    pad_uv : float
        Legacy mode only: fractional padding of the data extent. The grid
        spans ``±max(|u|, |v|) * (1 + pad_uv)``. It is ignored when
        ``fov_arcsec`` is given; a nonzero value then triggers a RuntimeWarning.
    fov_arcsec : float, optional
        Image-plane field of view in arcseconds. If provided, the uv cell size
        is ``delta_u = 1 / fov_rad``, with ``fov_rad = fov_arcsec / _ARCSEC_PER_RAD``
        (about ``fov_arcsec / 206265``). The grid then spans
        ``±npix * delta_u / 2``.
    warn_crop : bool
        Explicit-FOV mode only: emit a RuntimeWarning when some data fall
        outside the grid.

    Returns
    -------
    u_edges, v_edges : ndarray, shape (npix + 1,)
        Cell edges (same values for u and v).
    delta_u : float
        Cell size in wavelengths.
    truncation_radius : float
        Neighbour-search radius for the gridder; always equal to ``delta_u``.

    Raises
    ------
    ValueError
        In any of these cases:

        - ``npix`` is not positive.
        - ``fov_arcsec`` is not a positive finite number.
        - Legacy mode: the padded data extent is zero, negative or non-finite,
          which would give a zero-width grid.

    Notes
    -----
    With u, v in wavelengths, the image-plane angle is in radians and the
    Fourier-dual spacing satisfies FOV ≈ 1 / delta_u.
    """
    if npix <= 0:
        raise ValueError(f"npix must be positive; got {npix}.")

    # Legacy mode: infer the grid extent from the data (with pad_uv).
    if fov_arcsec is None:
        maxuv = max(np.abs(uu).max(), np.abs(vv).max())
        half_range = maxuv * (1.0 + pad_uv)
        # All-zero uv, NaNs, or pad_uv <= -1 would give zero-width (or inverted)
        # cells and a zero truncation radius. Fail here with a clear message
        # rather than deep inside the gridder.
        if not np.isfinite(half_range) or half_range <= 0.0:
            raise ValueError(
                "Cannot infer a uv grid from the data: max(|u|,|v|) = "
                f"{float(maxuv):.6g} with pad_uv = {float(pad_uv):.6g} gives a "
                "non-positive or non-finite half-range. Pass fov_arcsec explicitly "
                "or check uu/vv."
            )
        u_edges = np.linspace(-half_range, +half_range, npix + 1, dtype=float)
        v_edges = u_edges.copy()
        delta_u = float(u_edges[1] - u_edges[0])
        truncation_radius = delta_u
        return u_edges, v_edges, delta_u, truncation_radius

    # Explicit-FOV mode (arcsec -> rad)
    fov_arcsec = float(fov_arcsec)
    if not np.isfinite(fov_arcsec) or fov_arcsec <= 0.0:
        raise ValueError(f"fov_arcsec must be a positive finite float; got {fov_arcsec!r}.")
    fov_rad = fov_arcsec / _ARCSEC_PER_RAD  # radians

    # In this mode, fov_rad and npix fully determine the uv grid.
    if pad_uv != 0.0:
        warnings.warn(
            "pad_uv is ignored when fov_arcsec is specified (because fov_arcsec and npix "
            "fully determine the uv grid). "
            "To change oversampling/resolution, adjust fov_arcsec and/or npix.",
            RuntimeWarning,
            stacklevel=2,
        )

    delta_u = 1.0 / fov_rad
    half_range = 0.5 * npix * delta_u  # uv half-extent
    u_edges = np.linspace(-half_range, +half_range, npix + 1, dtype=float)
    v_edges = u_edges.copy()
    truncation_radius = delta_u

    if warn_crop:
        maxuv_data = max(np.abs(uu).max(), np.abs(vv).max())
        if maxuv_data > half_range:
            outside = (np.abs(uu) > half_range) | (np.abs(vv) > half_range)
            frac = float(np.count_nonzero(outside)) / float(outside.size)
            warnings.warn(
                "Requested (fov_arcsec, npix) implies a uv grid smaller than the data extent:\n"
                f"  data max(|u|,|v|) = {maxuv_data:.6g} wavelengths\n"
                f"  grid half-range   = {half_range:.6g} wavelengths\n"
                f"  -> uv-space will be cropped; approx fraction outside grid: {frac:.3%}\n"
                "Consider increasing npix and/or decreasing fov_arcsec.",
                RuntimeWarning,
                stacklevel=2,
            )

    return u_edges, v_edges, delta_u, truncation_radius
def build_grid_centers(u_edges: np.ndarray, v_edges: np.ndarray) -> np.ndarray:
    """
    Return the (u, v) coordinates of every uv-grid cell center.

    Parameters
    ----------
    u_edges : ndarray, shape (Nu + 1,)
        Cell edges along u.
    v_edges : ndarray, shape (Nv + 1,)
        Cell edges along v.

    Returns
    -------
    centers : ndarray, shape (Nu * Nv, 2)
        Row ``k * Nv + j`` holds ``(u_center[k], v_center[j])``. The u index
        varies slowest and the v index fastest, i.e. the C-order ravel of a
        (Nu, Nv) grid.
    """
    u_edges = np.asarray(u_edges, dtype=float)
    v_edges = np.asarray(v_edges, dtype=float)
    u_mid = (u_edges[:-1] + u_edges[1:]) / 2.0
    v_mid = (v_edges[:-1] + v_edges[1:]) / 2.0
    # indexing="ij" keeps u on the first (slow) axis, so ravel() reproduces the
    # k-outer / j-inner ordering expected for a (Nu, Nv) grid.
    u_grid, v_grid = np.meshgrid(u_mid, v_mid, indexing="ij")
    return np.column_stack((u_grid.ravel(), v_grid.ravel()))
def precompute_pairs(
    uu_i: np.ndarray,
    vv_i: np.ndarray,
    centers: np.ndarray,
    truncation_radius: float,
    *,
    p_metric: int = 1,
) -> Tuple[cKDTree, cKDTree, Sequence[Sequence[int]]]:
    """
    Build the KD-trees for one channel and find each cell's neighbour samples.

    Returns
    -------
    uv_tree : cKDTree
        Tree over the channel's flattened (u, v) samples.
    grid_tree : cKDTree
        Tree over ``centers``.
    pairs : list of lists
        ``pairs[g]`` lists indices (into the flattened samples) within
        ``truncation_radius`` of ``centers[g]``, measured with the Minkowski
        p-norm ``p_metric`` (1 = Manhattan by default).
    """
    grid_tree = cKDTree(centers)
    samples = np.column_stack((uu_i.ravel(), vv_i.ravel()))
    uv_tree = cKDTree(samples)
    neighbours = grid_tree.query_ball_tree(uv_tree, truncation_radius, p=p_metric)
    return uv_tree, grid_tree, neighbours
def grid_channel(
    uu_i: np.ndarray,
    vv_i: np.ndarray,
    vis_re_i: np.ndarray,
    vis_imag_i: np.ndarray,
    w_i: np.ndarray,
    u_edges: np.ndarray,
    v_edges: np.ndarray,
    window_fn,
    truncation_radius,
    uv_tree: cKDTree,
    grid_tree: cKDTree,
    pairs: Sequence[Sequence[int]],
    *,
    verbose_mean: int = 1,
    verbose_std: int = 2,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Grid one frequency channel with ``bin_data``.

    Returns
    -------
    vis_bin_re, std_bin_re, vis_bin_imag, std_bin_imag, counts
        Per-cell mean and std of the real and imaginary parts, plus the
        "count" statistic (computed on the real part).

    Notes
    -----
    ``bin_data`` is called with the same positional layout that
    ``grid_cube_all_stats`` uses:
    ``(u, v, values, weights, invvar, bins, window, truncation_radius,
    uv_tree, grid_tree, pairs)``. No per-datum inverse variances are
    available here, so ``invvar`` is ``None`` for every statistic.
    NOTE(review): confirm that ``bin_data``'s "std" path accepts
    ``invvar=None`` when ``collect_stats`` is not requested.
    """
    bins = (u_edges, v_edges)
    # Everything after the inverse-variance slot is shared by all five calls.
    geometry = (bins, window_fn, truncation_radius, uv_tree, grid_tree, pairs)

    def _stat(values, statistic, verbose):
        # The explicit None fills bin_data's per-datum inverse-variance slot.
        # Omitting it would shift `bins` and every later argument by one position.
        return bin_data(
            uu_i, vv_i, values, w_i, None, *geometry,
            statistics_fn=statistic, verbose=verbose,
        )

    vis_bin_re = _stat(vis_re_i, "mean", verbose_mean)
    std_bin_re = _stat(vis_re_i, "std", verbose_std)
    vis_bin_imag = _stat(vis_imag_i, "mean", verbose_mean)
    std_bin_imag = _stat(vis_imag_i, "std", verbose_std)
    counts = _stat(vis_re_i, "count", verbose_mean)
    return vis_bin_re, std_bin_re, vis_bin_imag, std_bin_imag, counts
def uv_grid_to_fft_image_convention(arr_uv: np.ndarray) -> np.ndarray:
    """
    Reorient a uv grid from ``[..., u, v]`` indexing to image-style ``[..., row, col]``.

    Two steps are applied:

    1. The last two axes are swapped, so rows index v and columns index u.
    2. The row (v) axis is reversed, so row 0 holds the *last* v index.

    Works for 2-D grids and for cubes whose last two axes are (Nu, Nv).
    For ndarray input, the result is a view (no data copy).
    """
    swapped = np.swapaxes(arr_uv, -2, -1)  # (..., u, v) -> (..., v, u)
    return np.flip(swapped, axis=-2)       # reverse the v (row) axis
# -----------------------
# User-facing helpers
# -----------------------


def _bind_window(fn, pixel_size, window_kwargs):
    """
    Return a callable ``window(u, center)`` with keyword arguments pre-bound.

    Only keyword arguments that ``fn`` accepts are forwarded:

    - If ``fn`` takes ``**kwargs``, every entry of ``window_kwargs`` passes through.
    - Otherwise, entries that are not keyword-capable parameters of ``fn`` are
      dropped with a RuntimeWarning, so typos are not silent.
    - If ``fn``'s signature cannot be introspected, ``window_kwargs`` is
      forwarded unchanged.

    ``pixel_size`` is supplied automatically when ``fn`` has a parameter of
    that name and ``window_kwargs`` does not already set it.

    The returned callable exposes ``_window_base`` (``fn``) and
    ``_window_kwargs`` (the bound keyword arguments).
    """
    kw = dict(window_kwargs or {})

    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Not introspectable (e.g. some C-implemented callables): forward the
        # user's kwargs unchanged and let fn validate them at call time.
        params = None

    if params is not None:
        takes_var_kw = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if not takes_var_kw:
            keywordable = {
                name
                for name, p in params.items()
                if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                              inspect.Parameter.KEYWORD_ONLY)
            }
            rejected = sorted(k for k in kw if k not in keywordable)
            if rejected:
                warnings.warn(
                    f"Ignoring window_kwargs not accepted by "
                    f"{getattr(fn, '__name__', fn)!r}: {rejected}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                for k in rejected:
                    del kw[k]
        if "pixel_size" in params and "pixel_size" not in kw:
            kw["pixel_size"] = pixel_size

    @wraps(fn)
    def bound(u, c):
        return fn(u, c, **kw)

    bound._window_base = fn
    bound._window_kwargs = kw
    return bound


def _window_from_name(name: str, *, pixel_size: float, window_kwargs: Optional[dict] = None):
    """
    Build a ``window(u, center)`` callable from a window name and a kwargs dict.

    The name is matched case-insensitively, ignoring surrounding whitespace:

    - "kb", "kaiser", "kaiser_bessel", "kaiser-bessel": Kaiser-Bessel window
    - "pswf", "casa", "spheroidal": CASA prolate-spheroidal window
    - "pillbox", "boxcar": pillbox window
    - "sinc": sinc window

    Keyword arguments are bound via ``_bind_window``, which forwards only
    those the chosen window accepts.

    Raises
    ------
    ValueError
        If ``name`` matches none of the aliases above.
    """
    key = name.strip().lower()
    for aliases, base in (
        (("kb", "kaiser", "kaiser_bessel", "kaiser-bessel"), kaiser_bessel_window),
        (("pswf", "casa", "spheroidal"), casa_pswf_window),
        (("pillbox", "boxcar"), pillbox_window),
        (("sinc",), sinc_window),
    ):
        if key in aliases:
            return _bind_window(base, pixel_size=pixel_size, window_kwargs=window_kwargs)
    raise ValueError(f"Unknown window name: {name!r}")
def grid_cube_all_stats(
    *,
    frequencies: np.ndarray,
    uu: np.ndarray,
    vv: np.ndarray,
    vis_re: np.ndarray,
    vis_imag: np.ndarray,
    weight: np.ndarray,
    invvar_group_re: np.ndarray,  # same shape as vis_re
    invvar_group_im: np.ndarray,  # same shape as vis_imag
    npix: int = 501,
    fov_arcsec: Optional[float] = None,
    pad_uv: float = 0.0,
    window_name: Optional[str] = "kaiser_bessel",
    window_kwargs: Optional[dict] = None,
    window_fn: Optional[Callable] = None,
    p_metric: int = 1,
    std_p: int = 1,
    std_workers: int = 6,
    std_min_effective: int = 5,
    n_eff_mode: str = "both",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Grid a visibility cube, channel by channel, onto a common uv grid.

    For each channel, ``bin_data`` computes three things on the grid built by
    ``make_uv_grid``:

    - the "mean" of the real and imaginary parts;
    - the "std" (using the per-datum inverse variances, with ``collect_stats=True``);
    - the "count" statistic.

    Parameters
    ----------
    frequencies : ndarray
        Channel frequencies. Accepted for interface symmetry; not used here.
    uu, vv : ndarray, shape (F, ...)
        uv coordinates per channel (see ``make_uv_grid`` for units).
    vis_re, vis_imag, weight : ndarray
        Real/imaginary visibility parts and weights; same shape as ``uu``.
    invvar_group_re, invvar_group_im : ndarray
        Per-datum inverse variances for the Re/Im "std" statistic; same shape
        as ``uu``.
    npix, fov_arcsec, pad_uv
        Grid definition, forwarded to ``make_uv_grid``.
    window_name, window_kwargs, window_fn
        Gridding window: a ready-made ``window_fn`` or a name resolved by
        ``_window_from_name``. ``window_kwargs`` is bound in either case.
    p_metric : int
        Minkowski p-norm for the neighbour search (``precompute_pairs``).
    std_p, std_workers, std_min_effective, n_eff_mode
        Forwarded to ``bin_data`` for the "std" statistic.

    Returns
    -------
    mean_re, mean_im, std_re, std_im, counts : ndarray, shape (F, npix, npix)
        Reoriented by ``uv_grid_to_fft_image_convention``: rows index v
        (reversed), columns index u.
    u_edges, v_edges : ndarray, shape (npix + 1,)

    Raises
    ------
    ValueError
        On mismatched input shapes, inputs with fewer than 2 dimensions, or
        an invalid grid/window specification.
    """
    # -----------------------
    # Basic validation (kept consistent with grid_cube_all_stats_wbinned)
    # -----------------------
    if np.shape(uu) != np.shape(vv):
        raise ValueError(
            f"uu and vv must have the same shape. Got uu={np.shape(uu)}, vv={np.shape(vv)}."
        )
    if (np.shape(vis_re) != np.shape(uu) or np.shape(vis_imag) != np.shape(uu)
            or np.shape(weight) != np.shape(uu)):
        raise ValueError("vis_re, vis_imag, weight must match uu/vv shape.")
    if np.shape(invvar_group_re) != np.shape(uu) or np.shape(invvar_group_im) != np.shape(uu):
        raise ValueError("invvar_group_re and invvar_group_im must match uu/vv shape.")
    if np.ndim(uu) < 2:
        raise ValueError("Expected uu/vv to be shaped (F, ...). Got ndim < 2.")

    # -----------------------
    # UV grid + window binding
    # -----------------------
    u_edges, v_edges, delta_u, trunc_r = make_uv_grid(
        uu, vv, npix=npix, pad_uv=pad_uv, fov_arcsec=fov_arcsec, warn_crop=True
    )
    centers = build_grid_centers(u_edges, v_edges)

    if window_fn is not None:
        window = _bind_window(window_fn, pixel_size=delta_u, window_kwargs=window_kwargs)
    else:
        if window_name is None:
            raise ValueError("Provide either window_name or a ready-made window_fn.")
        window = _window_from_name(window_name, pixel_size=delta_u, window_kwargs=window_kwargs)

    F = uu.shape[0]
    Nu = len(u_edges) - 1
    Nv = len(v_edges) - 1
    mean_re = np.zeros((F, Nu, Nv), dtype=np.float64)
    std_re = np.zeros((F, Nu, Nv), dtype=np.float64)
    mean_im = np.zeros((F, Nu, Nv), dtype=np.float64)
    std_im = np.zeros((F, Nu, Nv), dtype=np.float64)
    counts = np.zeros((F, Nu, Nv), dtype=np.float64)

    bins = (u_edges, v_edges)
    std_kwargs = dict(
        statistics_fn="std",
        verbose=0,
        std_min_effective=std_min_effective,
        std_workers=std_workers,
        std_p=std_p,
        collect_stats=True,
        n_eff_mode=n_eff_mode,
    )

    pbar = tqdm(range(F), unit="channel")
    for i in pbar:
        uv_tree, grid_tree, pairs = precompute_pairs(
            uu[i], vv[i], centers, trunc_r, p_metric=p_metric
        )
        # Arguments after bin_data's per-datum inverse-variance slot.
        geometry = (bins, window, trunc_r, uv_tree, grid_tree, pairs)

        # Means
        mean_re[i] = bin_data(
            uu[i], vv[i], vis_re[i], weight[i], None, *geometry,
            statistics_fn="mean", verbose=0,
        )
        mean_im[i] = bin_data(
            uu[i], vv[i], vis_imag[i], weight[i], None, *geometry,
            statistics_fn="mean", verbose=0,
        )

        # Hybrid std (empirical for normal pixels, propagated fallback on low-info)
        sb_re, stats_re = bin_data(
            uu[i], vv[i], vis_re[i], weight[i], invvar_group_re[i], *geometry, **std_kwargs
        )
        sb_im, stats_im = bin_data(
            uu[i], vv[i], vis_imag[i], weight[i], invvar_group_im[i], *geometry, **std_kwargs
        )
        std_re[i] = sb_re
        std_im[i] = sb_im

        # Counts
        counts[i] = bin_data(
            uu[i], vv[i], vis_re[i], weight[i], None, *geometry,
            statistics_fn="count", verbose=0,
        )

        pbar.set_postfix(
            fallback_pix_re=stats_re,
            fallback_pix_im=stats_im,
        )

    return (
        uv_grid_to_fft_image_convention(mean_re),
        uv_grid_to_fft_image_convention(mean_im),
        uv_grid_to_fft_image_convention(std_re),
        uv_grid_to_fft_image_convention(std_im),
        uv_grid_to_fft_image_convention(counts),
        u_edges,
        v_edges,
    )
def _make_w_edges(
    ww: np.ndarray,
    w_bins: Union[int, np.ndarray],
    *,
    w_range: Optional[Tuple[float, float]] = None,
    w_abs: bool = False,
) -> np.ndarray:
    """
    Create w bin edges.

    Parameters
    ----------
    ww : ndarray
        Full w array; used only to determine the default range.
    w_bins : int or 1-D array-like
        If a scalar (including a 0-d array), the number of uniform bins in w.
        Otherwise, explicit bin edges (ndarray, list or tuple), which must be
        finite and strictly increasing; they are returned as float64.
    w_range : (min, max), optional
        Range for uniform bins. If None, uses the data's nan-min / nan-max
        (after ``abs`` if ``w_abs`` is True).
    w_abs : bool
        If True, the default range is computed from |w| instead of w.

    Returns
    -------
    w_edges : ndarray, shape (Nw + 1,)

    Raises
    ------
    ValueError
        For malformed explicit edges, a bin count below 1, or a non-finite or
        empty range.
    """
    # Any non-scalar (ndarray, list, tuple) is treated as explicit edges;
    # scalars, including 0-d arrays, are treated as a bin count.
    if np.ndim(w_bins) > 0:
        w_edges = np.asarray(w_bins, dtype=float)
        if w_edges.ndim != 1 or w_edges.size < 2:
            raise ValueError("If w_bins is an array, it must be 1D with length >= 2 (bin edges).")
        if not np.all(np.isfinite(w_edges)):
            raise ValueError("w bin edges contain non-finite values.")
        if np.any(np.diff(w_edges) <= 0):
            raise ValueError("w bin edges must be strictly increasing.")
        return w_edges

    # Integer number of uniform bins.
    n_w = int(w_bins)
    if n_w <= 0:
        raise ValueError("If w_bins is an int, it must be >= 1.")

    if w_range is None:
        wvals = np.asarray(ww, dtype=float)
        if w_abs:
            wvals = np.abs(wvals)
        wmin = float(np.nanmin(wvals))
        wmax = float(np.nanmax(wvals))
    else:
        wmin, wmax = map(float, w_range)

    if not np.isfinite(wmin) or not np.isfinite(wmax):
        raise ValueError("w range contains non-finite values.")
    if wmax <= wmin:
        raise ValueError(f"Invalid w range: max ({wmax}) must be > min ({wmin}).")

    return np.linspace(wmin, wmax, n_w + 1, dtype=float)
def grid_cube_all_stats_wbinned(
    *,
    frequencies: np.ndarray,
    uu: np.ndarray,
    vv: np.ndarray,
    ww: np.ndarray,
    vis_re: np.ndarray,
    vis_imag: np.ndarray,
    weight: np.ndarray,
    invvar_group_re: np.ndarray,  # same shape as vis_re
    invvar_group_im: np.ndarray,  # same shape as vis_imag
    npix: int = 501,
    fov_arcsec: Optional[float] = None,
    pad_uv: float = 0.0,
    w_bins: Union[int, np.ndarray] = 8,
    w_range: Optional[Tuple[float, float]] = None,
    w_abs: bool = False,
    window_name: Optional[str] = "kaiser_bessel",
    window_kwargs: Optional[dict] = None,
    window_fn: Optional[Callable] = None,
    p_metric: int = 1,
    # Std controls (kept aligned with bin_data)
    std_p: int = 1,
    std_workers: int = 6,
    std_min_effective: int = 5,
    tqdm_ncols: int = 200,
    n_eff_mode: str = "both",
) -> Tuple[
    np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
    np.ndarray, np.ndarray, np.ndarray
]:
    """
    Grid complex visibilities into w-binned uv pixels using ``bin_data``.

    Each channel's samples are split into w bins (edges from
    ``_make_w_edges``). Each non-empty bin is gridded separately onto the
    common uv grid from ``make_uv_grid``. Per bin, ``bin_data`` computes
    three things:

    - the "mean" of the real and imaginary parts;
    - the "std" (with per-datum inverse variances, ``collect_stats=True``);
    - the "count" statistic.

    Parameters
    ----------
    frequencies : ndarray
        Channel frequencies. Accepted for interface symmetry; not used here.
    uu, vv, ww : ndarray, shape (F, ...)
        uvw coordinates per channel; all three must share one shape.
    vis_re, vis_imag, weight : ndarray
        Real/imaginary visibility parts and weights; same shape as ``uu``.
    invvar_group_re, invvar_group_im : ndarray
        Per-datum inverse variances for the Re/Im "std"; same shape as ``uu``.
    npix, fov_arcsec, pad_uv
        uv grid definition, forwarded to ``make_uv_grid``.
    w_bins, w_range, w_abs
        w binning, forwarded to ``_make_w_edges``. With ``w_abs=True``, samples
        are binned by |w|. Bins are half-open ``[e_k, e_{k+1})`` except the last,
        which also includes its upper edge. Samples outside the edges are
        skipped; their per-channel total is reported as ``w_dropped``.
    window_name, window_kwargs, window_fn
        Gridding window: a ready-made ``window_fn`` or a name resolved by
        ``_window_from_name``. ``window_kwargs`` is bound in either case.
    p_metric : int
        Minkowski p-norm for the neighbour search (``precompute_pairs``).
    std_p, std_workers, std_min_effective, n_eff_mode
        Forwarded to ``bin_data`` for the "std" statistic.
    tqdm_ncols : int
        Width of the progress bars.

    Returns
    -------
    mean_re, mean_im, std_re, std_im, counts : ndarray, shape (F, Nw, npix, npix)
        Reoriented by ``uv_grid_to_fft_image_convention``: rows index v
        (reversed), columns index u. For empty w bins, means and counts are
        0 and stds are NaN.
    u_edges, v_edges : ndarray, shape (npix + 1,)
    w_edges : ndarray, shape (Nw + 1,)

    Raises
    ------
    ValueError
        On mismatched input shapes, inputs with fewer than 2 dimensions, or
        an invalid grid/w-binning/window specification.
    """
    # -----------------------
    # Basic validation
    # -----------------------
    if uu.shape != vv.shape or uu.shape != ww.shape:
        raise ValueError(
            f"uu, vv, ww must have the same shape. "
            f"Got uu={uu.shape}, vv={vv.shape}, ww={ww.shape}."
        )
    if vis_re.shape != uu.shape or vis_imag.shape != uu.shape or weight.shape != uu.shape:
        raise ValueError("vis_re, vis_imag, weight must match uu/vv/ww shape.")
    if invvar_group_re.shape != uu.shape or invvar_group_im.shape != uu.shape:
        raise ValueError("invvar_group_re and invvar_group_im must match uu/vv/ww shape.")
    if uu.ndim < 2:
        raise ValueError("Expected uu/vv/ww to be shaped (F, ...). Got ndim < 2.")

    # -----------------------
    # UV grid + window binding
    # -----------------------
    u_edges, v_edges, delta_u, trunc_r = make_uv_grid(
        uu, vv, npix=npix, pad_uv=pad_uv, fov_arcsec=fov_arcsec, warn_crop=True
    )
    centers = build_grid_centers(u_edges, v_edges)

    w_edges = _make_w_edges(ww, w_bins, w_range=w_range, w_abs=w_abs)
    Nw = len(w_edges) - 1

    if window_fn is not None:
        window = _bind_window(window_fn, pixel_size=delta_u, window_kwargs=window_kwargs)
    else:
        if window_name is None:
            raise ValueError("Provide either window_name or a ready-made window_fn.")
        window = _window_from_name(window_name, pixel_size=delta_u, window_kwargs=window_kwargs)

    # Dimensions
    F = uu.shape[0]
    Nu = len(u_edges) - 1
    Nv = len(v_edges) - 1

    mean_re = np.zeros((F, Nw, Nu, Nv), dtype=np.float64)
    std_re = np.full((F, Nw, Nu, Nv), np.nan, dtype=np.float64)
    mean_im = np.zeros((F, Nw, Nu, Nv), dtype=np.float64)
    std_im = np.full((F, Nw, Nu, Nv), np.nan, dtype=np.float64)
    counts = np.zeros((F, Nw, Nu, Nv), dtype=np.float64)

    bins = (u_edges, v_edges)
    std_kwargs = dict(
        statistics_fn="std",
        verbose=0,
        std_min_effective=std_min_effective,
        std_workers=std_workers,
        std_p=std_p,
        collect_stats=True,
        n_eff_mode=n_eff_mode,
    )

    # -----------------------
    # Main loop over channels
    # -----------------------
    pbar = tqdm(range(F), unit="channel", desc="Channels", ncols=tqdm_ncols)
    for i in pbar:
        # Flatten channel data
        u_all = uu[i].ravel()
        v_all = vv[i].ravel()
        wvals = ww[i].ravel().astype(float)
        re_all = vis_re[i].ravel()
        im_all = vis_imag[i].ravel()
        wgt_all = weight[i].ravel()
        invv_re_all = invvar_group_re[i].ravel()
        invv_im_all = invvar_group_im[i].ravel()

        if w_abs:
            wvals = np.abs(wvals)

        # Assign each datum to a w bin. With right=False, bins are
        # [edge_k, edge_{k+1}). np.digitize would send values equal to the
        # final edge to index Nw (out of range). When the edges come from the
        # data range, that would silently drop the sample(s) with the largest
        # w. Close the last bin on the right instead, as np.histogram does.
        wbin = np.digitize(wvals, w_edges, right=False) - 1
        wbin[wvals == w_edges[-1]] = Nw - 1
        valid = (wbin >= 0) & (wbin < Nw)
        ch_dropped = int(np.count_nonzero(~valid))  # outside the w edges (or NaN)

        # Channel diagnostics
        ch_fallback_re = 0  # UV pixels that triggered low-info fallback (Re)
        ch_fallback_im = 0  # UV pixels that triggered low-info fallback (Im)
        ch_nan_re = 0       # NaN pixels remaining (Re)
        ch_nan_im = 0       # NaN pixels remaining (Im)

        wbar = tqdm(
            range(Nw),
            unit="wbin",
            desc=f"w-bins (ch {i+1}/{F})",
            leave=False,
            ncols=tqdm_ncols,
        )
        for b in wbar:
            sel = valid & (wbin == b)
            if not np.any(sel):
                wbar.set_postfix_str("empty")
                continue

            u_b = u_all[sel]
            v_b = v_all[sel]
            re_b = re_all[sel]
            im_b = im_all[sel]
            wgt_b = wgt_all[sel]
            invv_re_b = invv_re_all[sel]
            invv_im_b = invv_im_all[sel]

            # Precompute geometry for this w-bin subset
            uv_tree, grid_tree, pairs = precompute_pairs(
                u_b, v_b, centers, trunc_r, p_metric=p_metric
            )
            # Arguments after bin_data's per-datum inverse-variance slot.
            geometry = (bins, window, trunc_r, uv_tree, grid_tree, pairs)

            # Means and counts
            vb_re = bin_data(
                u_b, v_b, re_b, wgt_b, None, *geometry,
                statistics_fn="mean", verbose=0,
            )
            vb_im = bin_data(
                u_b, v_b, im_b, wgt_b, None, *geometry,
                statistics_fn="mean", verbose=0,
            )
            cnt = bin_data(
                u_b, v_b, re_b, wgt_b, None, *geometry,
                statistics_fn="count", verbose=0,
            )

            # Hybrid std (empirical for normal pixels, propagated fallback on low-info)
            sb_re, stats_re = bin_data(
                u_b, v_b, re_b, wgt_b, invv_re_b, *geometry, **std_kwargs
            )
            sb_im, stats_im = bin_data(
                u_b, v_b, im_b, wgt_b, invv_im_b, *geometry, **std_kwargs
            )

            # Store
            mean_re[i, b] = vb_re
            mean_im[i, b] = vb_im
            std_re[i, b] = sb_re
            std_im[i, b] = sb_im
            counts[i, b] = cnt

            # Diagnostics
            fallback_re_bin = int(stats_re)  # n_fallback returned by bin_data
            fallback_im_bin = int(stats_im)
            nan_re_bin = int(np.isnan(sb_re).sum())
            nan_im_bin = int(np.isnan(sb_im).sum())

            ch_fallback_re += fallback_re_bin
            ch_fallback_im += fallback_im_bin
            ch_nan_re += nan_re_bin
            ch_nan_im += nan_im_bin

            wbar.set_postfix(
                fallback_re=fallback_re_bin,
                fallback_im=fallback_im_bin,
                nan_re=nan_re_bin,
                nan_im=nan_im_bin,
            )

        pbar.set_postfix(
            w_bins=Nw,
            w_dropped=ch_dropped,
            fallback_re=ch_fallback_re,
            fallback_im=ch_fallback_im,
            nan_re=ch_nan_re,
            nan_im=ch_nan_im,
        )

    # Same final reorientation as grid_cube_all_stats.
    return (
        uv_grid_to_fft_image_convention(mean_re),
        uv_grid_to_fft_image_convention(mean_im),
        uv_grid_to_fft_image_convention(std_re),
        uv_grid_to_fft_image_convention(std_im),
        uv_grid_to_fft_image_convention(counts),
        u_edges,
        v_edges,
        w_edges,
    )