CASA I/O: Recommended pipeline for extracting data from very large measurement sets
Sometimes a measurement set is too large for casatools.ms to hold in memory all at once. In that case, use the following code as inspiration for working around the problem: it reads the visibilities a few channels at a time. Be warned: the I/O loop is very slow! To speed it up, you can modify this code to use casacore.tables instead, as done in this tutorial.
import os
import gc
import numpy as np
from tqdm import tqdm
import casatools
from astropy.constants import c
def _get_any_case(d, key: str):
"""Return d[key] with CASA's sometimes-annoying key casing."""
if key in d:
return d[key]
kl = key.lower()
if kl in d:
return d[kl]
ku = key.upper()
if ku in d:
return d[ku]
raise KeyError(f"Key {key!r} not found. Available keys: {list(d.keys())}")
# Path to the (continuum-subtracted) measurement set to extract from.
ms_path = "/home/darthbarth/Big_red/SPT_GALAXIES/0346/SPT0346/SPT0346_bary_20kms.ms.split.cal.contsub"
# Explicit check instead of `assert`: asserts are stripped when Python runs
# with -O, which would let a bad path through to the CASA calls below.
if not os.path.exists(ms_path):
    raise FileNotFoundError(f"MS not found: {ms_path}")
datadescid = 0  # data description (SPW + polarization setup) to extract
data_column = "DATA"  # or "CORRECTED_DATA"
out_npz = "cube_extracted.npz"  # output file written at the end of the script
# ---- Figure out total number of channels for this datadescid ----
msmd = casatools.msmetadata()
msmd.open(ms_path)
try:
    spw_id = msmd.spwfordatadesc(int(datadescid))
    nchan_total = int(msmd.nchan(spw_id))
finally:
    # Release the metadata tool even if one of the queries raises
    # (e.g. an invalid datadescid); otherwise the MS stays open.
    msmd.done()
print(f"datadescid={datadescid} -> spw={spw_id} -> nchan_total={nchan_total}")
datadescid=0 -> spw=0 -> nchan_total=80
# ---- Open MS ----
ms = casatools.ms()
ms.open(ms_path)
# selectinit returns a boolean (the notebook echoed True for it). Check it so
# a failed data-description selection stops the script here, before any of
# the getdata() reads below run against an unintended selection.
if not ms.selectinit(datadescid=int(datadescid)):
    ms.close()
    raise RuntimeError(f"ms.selectinit failed for datadescid={datadescid}")
True
# ---- Read ONE channel to learn npol/nrow + sanity-check column names ----
ms.selectchannel(nchan=1, start=0, width=1, inc=1)
q0 = ms.getdata(["FLAG", "AXIS_INFO", data_column])
# _get_any_case raises KeyError when a column is missing, which is the
# column-name sanity check.
flag0 = _get_any_case(q0, "FLAG")  # (npol, 1, nrow)
data0 = _get_any_case(q0, data_column)  # (npol, 1, nrow) complex
# The chunked loop later assumes FLAG and the data column share the same
# (npol, nch, nrow) layout; fail now rather than hours into the loop.
if data0.shape != flag0.shape:
    raise ValueError(
        f"{data_column} shape {data0.shape} does not match FLAG shape {flag0.shape}"
    )
npol, _, nrow = flag0.shape
del q0, flag0, data0
gc.collect()
print(f"Inferred shapes: npol={npol}, nrow={nrow}")
Inferred shapes: npol=2, nrow=4557168
# ---- Read STATIC per-row columns ONCE (do not include FLAG/DATA here) ----
# These columns are per-row, so whatever channel selection is still active
# from the single-channel probe does not change them.
q_static = ms.getdata(["UVW", "ANTENNA1", "ANTENNA2", "WEIGHT", "FLAG_ROW"])
# Shapes as returned by CASA:
#   ant1, ant2      -> (nrow,)
#   uvw_m_full      -> (3, nrow), meters
#   weight_raw      -> typically (npol, nrow), (nrow,) or (1, nrow)
#   flag_row_full   -> (nrow,)
ant1, ant2, uvw_m_full, weight_raw, flag_row_full = (
    _get_any_case(q_static, column)
    for column in ("ANTENNA1", "ANTENNA2", "UVW", "WEIGHT", "FLAG_ROW")
)
del q_static
gc.collect()
0
# ---- Normalize WEIGHT to shape (npol, nrow) ----
# Accepted layouts:
#   (nrow,)      -> the same weight repeated for every pol
#   (1, nrow)    -> repeated for every pol when npol > 1
#   (npol, nrow) -> used as-is
# Any other rank or pol count raises ValueError rather than broadcasting.
weight = np.asarray(weight_raw)
if weight.ndim not in (1, 2):
    raise ValueError(f"Unexpected WEIGHT ndim={weight.ndim}, shape={weight.shape}")

if weight.ndim == 1:
    if weight.shape[0] != nrow:
        raise ValueError(f"Unexpected WEIGHT shape {weight.shape}, expected (nrow,) with nrow={nrow}")
    weight = np.repeat(weight[None, :], npol, axis=0)
else:
    if weight.shape[1] != nrow:
        raise ValueError(f"Unexpected WEIGHT shape {weight.shape}, expected (*, nrow) with nrow={nrow}")
    if weight.shape[0] != npol:
        # The only mismatch we accept is a single weight row shared by
        # several pols. Depending on the correlation setup CASA can return
        # e.g. (2, nrow) while DATA has 4 pols; bail loudly in that case
        # instead of silently broadcasting the wrong weights.
        if weight.shape[0] != 1 or npol <= 1:
            raise ValueError(f"WEIGHT has {weight.shape[0]} pols but FLAG/DATA has npol={npol}.")
        weight = np.repeat(weight, npol, axis=0)
# ---- Remove autocorrelations ONCE (static selection) ----
# `xc` holds the row indices of cross-correlations (ANTENNA1 != ANTENNA2).
# It subsets every per-row array here and every data chunk in the main loop.
xc = np.nonzero(ant1 != ant2)[0]
uvw_m = uvw_m_full[:, xc]  # (3, nvis)
weight = weight[:, xc]  # (npol, nvis)
flag_row = flag_row_full[xc]  # (nvis,)
ant1, ant2 = ant1[xc], ant2[xc]
nvis = uvw_m.shape[1]
print(f"After autocorr removal: nvis={nvis}, uvw_m={uvw_m.shape}, weight={weight.shape}")
After autocorr removal: nvis=4361040, uvw_m=(3, 4361040), weight=(2, 4361040)
# Collapse per-pol weight to a single per-row weight AFTER pol-avg (same as
# your script): the per-row weight is the sum of the per-pol weights.
weight_row = np.sum(weight, axis=0).astype(np.float32)  # (nvis,)

# ---- Preallocate OUTPUT ARRAYS (this is the only big memory you keep) ----
# Each (nchan_total, nvis) float32 array costs 4 * nchan_total * nvis bytes
# (complex64: 8 bytes per element, bool: 1 byte).
#
# The arrays are given sentinel values instead of np.empty's uninitialized
# memory. If the channel loop is interrupted part-way, the unfilled channels
# are then recognizable (frequency NaN, mask False = "not valid", zeros
# elsewhere) rather than holding garbage that would be reordered and saved.
chan_freq_all = np.full((nchan_total,), np.nan, dtype=np.float64)
vis_all = np.zeros((nchan_total, nvis), dtype=np.complex64)
mask_all = np.zeros((nchan_total, nvis), dtype=bool)
u_all = np.zeros((nchan_total, nvis), dtype=np.float32)
v_all = np.zeros((nchan_total, nvis), dtype=np.float32)
w_all = np.zeros((nchan_total, nvis), dtype=np.float32)

# The weight cube is identical for every channel (a tile of weight_row).
# If you do not need it stored per channel, you can save just weight_row.
weight_all = np.tile(weight_row[None, :], (nchan_total, 1))  # float32 already

# Convenience for the meters -> wavelengths conversion in the main loop.
u_m, v_m, w_m = uvw_m  # each (nvis,)
# ---- Main loop: process `chan_chunk` channels per getdata() call ----
# Each getdata() call returns every row of the selection (npol, nch, nrow),
# so the loop makes ceil(nchan_total / chan_chunk) full passes over the MS.
# A larger chunk means fewer passes (faster) at the cost of a bigger
# temporary (npol, chan_chunk, nrow) complex array per iteration.
chan_chunk = 2

# The pol weights are identical for every channel, so the weighted-average
# denominators are computed once here instead of once per chunk.
w_b = weight[:, None, :]  # (npol, 1, nvis), broadcasts over channels
wsum = np.sum(w_b, axis=0)  # (1, nvis)
# Guard against division by zero for rows whose pol weights sum to 0.
# NOTE(review): such rows are not added to the mask; only FLAG/FLAG_ROW
# drive the mask below.
wsum_safe = np.where(wsum > 0, wsum, 1.0)

try:
    for ch0 in tqdm(
        range(0, nchan_total, chan_chunk),
        desc="Processing channels",
        unit="chanpair" if chan_chunk == 2 else "chunk",
    ):
        nch = min(chan_chunk, nchan_total - ch0)
        ms.selectchannel(nchan=nch, start=ch0, width=1, inc=1)
        q = ms.getdata(["FLAG", "AXIS_INFO", data_column])

        # pull data/flag with casing safety
        flag = _get_any_case(q, "FLAG")  # (npol, nch, nrow)
        data = _get_any_case(q, data_column)  # (npol, nch, nrow) complex
        info = _get_any_case(q, "AXIS_INFO")
        chan_freq_hz = info["freq_axis"]["chan_freq"].reshape(-1)  # (nch,)

        # Keep only cross-correlations, then apply FLAG_ROW to every pol+chan.
        flag = flag[:, :, xc]  # (npol, nch, nvis)
        data = data[:, :, xc]  # (npol, nch, nvis)
        flag = np.logical_or(flag, flag_row[None, None, :])

        # Weighted pol-average (same math as your script). A visibility is
        # valid only if NO pol is flagged.
        data_pc = np.sum(data * w_b, axis=0) / wsum_safe  # (nch, nvis)
        mask_pc = ~np.any(flag, axis=0)  # (nch, nvis)

        # UVW meters -> wavelengths for these channels
        nu = chan_freq_hz[:, None]  # (nch, 1)
        u_lam = (u_m[None, :] * nu / c.value).astype(np.float32)  # (nch, nvis)
        v_lam = (v_m[None, :] * nu / c.value).astype(np.float32)
        w_lam = (w_m[None, :] * nu / c.value).astype(np.float32)

        # write into preallocated arrays
        sl = slice(ch0, ch0 + nch)
        chan_freq_all[sl] = chan_freq_hz
        vis_all[sl] = data_pc.astype(np.complex64)
        mask_all[sl] = mask_pc
        u_all[sl] = u_lam
        v_all[sl] = v_lam
        w_all[sl] = w_lam

        # free chunk memory before the next getdata() allocates again
        del q, flag, data, info, chan_freq_hz, data_pc, mask_pc, u_lam, v_lam, w_lam, nu
        gc.collect()
finally:
    # ---- cleanup CASA selection + close ----
    # Runs even when the multi-hour loop is interrupted or fails, so the MS
    # is never left open with a stale selection.
    ms.selectinit(reset=True)
    ms.close()
    del w_b, wsum, wsum_safe
Processing channels: 22%|██▎ | 9/40 [1:21:03<4:39:15, 540.51s/chanpair]
# ---- Ensure increasing frequency (same behavior as your Step 8) ----
# If the channel frequencies are not strictly increasing, reorder every
# per-channel array by ascending frequency. For a strictly decreasing SPW
# this is exactly a reversal (the previous behavior); unlike a plain
# reversal it also yields ascending order for non-monotonic channel orders.
is_increasing = np.all(np.diff(chan_freq_all) > 0)
print("Frequency increasing?", is_increasing)
if not is_increasing:
    # kind="stable" keeps the relative order of any equal frequencies.
    order = np.argsort(chan_freq_all, kind="stable")
    # Integer-array indexing returns a new array, so no explicit .copy().
    chan_freq_all = chan_freq_all[order]
    vis_all = vis_all[order]
    mask_all = mask_all[order]
    u_all = u_all[order]
    v_all = v_all[order]
    w_all = w_all[order]
    weight_all = weight_all[order]
    del order
# ---- Save exactly like before ----
# Archive member names (and their order) match the original np.savez call.
cube = {
    "chan_freq_hz": chan_freq_all,
    "u": u_all,
    "v": v_all,
    "w": w_all,
    "vis": vis_all,
    "weight": weight_all,
    "mask": mask_all,
}
np.savez(out_npz, **cube)
print(f"Saved: {out_npz}")

# Summary of what was written: one line per quantity, label then fields.
print("Final cube shapes:")
shape_report = (
    (" chan_freq_hz:", (chan_freq_all.shape,)),
    (" vis:", (vis_all.shape, vis_all.dtype)),
    (" mask:", (mask_all.shape, mask_all.dtype)),
    (" u/v/w:", (u_all.shape, v_all.shape, w_all.shape, u_all.dtype)),
    (" weight:", (weight_all.shape, weight_all.dtype)),
)
for label, fields in shape_report:
    print(label, *fields)
del cube, shape_report, label, fields