diff --git a/CHANGES.rst b/CHANGES.rst index 1879f11851..565c5a4574 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,11 @@ API changes Service fixes and enhancements ------------------------------ +mast +^^^^ + +- Filtering by file extension or by a string column is now case-insensitive in ``MastMissions.filter_products`` + and ``Observations.filter_products``. [#3427] Infrastructure, Utility and Other Changes and Additions diff --git a/astroquery/mast/missions.py b/astroquery/mast/missions.py index 82ed5c3218..ba87307d30 100644 --- a/astroquery/mast/missions.py +++ b/astroquery/mast/missions.py @@ -517,12 +517,7 @@ def filter_products(self, products, *, extension=None, **filters): # Filter by file extension, if provided if extension: - extensions = [extension] if isinstance(extension, str) else extension - ext_mask = np.array( - [not isinstance(x, np.ma.core.MaskedConstant) and any(x.endswith(ext) for ext in extensions) - for x in products["filename"]], - dtype=bool - ) + ext_mask = utils.apply_extension_filter(products, extension, 'filename') filter_mask &= ext_mask # Apply column-based filters diff --git a/astroquery/mast/observations.py b/astroquery/mast/observations.py index 0c73a5c124..320af9ed78 100644 --- a/astroquery/mast/observations.py +++ b/astroquery/mast/observations.py @@ -594,12 +594,7 @@ def filter_products(self, products, *, mrp_only=False, extension=None, **filters # Filter by file extension, if provided if extension: - extensions = [extension] if isinstance(extension, str) else extension - ext_mask = np.array( - [not isinstance(x, np.ma.core.MaskedConstant) and any(x.endswith(ext) for ext in extensions) - for x in products["productFilename"]], - dtype=bool - ) + ext_mask = utils.apply_extension_filter(products, extension, 'productFilename') filter_mask &= ext_mask # Apply column-based filters diff --git a/astroquery/mast/tests/test_mast.py b/astroquery/mast/tests/test_mast.py index 8c1f07147f..4b1613f4fb 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -428,7 +428,7 @@ def test_missions_filter_products(patch_post): assert all(~((filtered['size'] >= 14400) & (filtered['size'] <= 17280))) # Negate a string match - filtered = mast.MastMissions.filter_products(products, category='!CALIBRATED') + filtered = mast.MastMissions.filter_products(products, category='!calibrated') assert all(filtered['category'] != 'CALIBRATED') # Negate one string in a list @@ -798,7 +798,7 @@ def test_observations_get_product_list(patch_post): def test_observations_filter_products(patch_post): products = mast.Observations.get_product_list('2003738726') filtered = mast.Observations.filter_products(products, - productType=["SCIENCE"], + productType=["sCiEnCE"], mrp_only=False) assert isinstance(filtered, Table) assert len(filtered) == 7 @@ -808,8 +808,10 @@ def test_observations_filter_products(patch_post): assert all(filtered['productGroupDescription'] == 'Minimum Recommended Products') # Filter by extension - filtered = mast.Observations.filter_products(products, extension='fits') + filtered = mast.Observations.filter_products(products, extension='FITS') assert len(filtered) > 0 + filtered = mast.Observations.filter_products(products, extension=['png']) + assert len(filtered) == 0 # Numeric filtering filtered = mast.Observations.filter_products(products, size='<50000') diff --git a/astroquery/mast/utils.py b/astroquery/mast/utils.py index fa3ff17f71..e74973e2b7 100644 --- a/astroquery/mast/utils.py +++ b/astroquery/mast/utils.py @@ -445,6 +445,38 @@ def remove_duplicate_products(data_products, uri_key): return unique_products +def apply_extension_filter(products, extension, filename_key): + """ + Applies an extension filter to a product table. + + Parameters + ---------- + products : `~astropy.table.Table` + The product table to filter. + extension : str + The extension to filter by (e.g., 'fits', 'csv'). + filename_key : str + The column name representing the filename of a product. + + Returns + ------- + ext_mask : `numpy.ndarray` + A boolean mask indicating which rows of the product table have the specified extension. + """ + # Normalize extensions to lowercase + extensions = [extension] if isinstance(extension, str) else extension + extensions = tuple(ext.lower() for ext in extensions) + + # Build mask + ext_mask = np.array( + [not isinstance(x, np.ma.core.MaskedConstant) + and str(x).lower().endswith(extensions) + for x in products[filename_key]], + dtype=bool + ) + return ext_mask + + def _combine_positive_negative_masks(mask_funcs): """ Combines a list of mask functions into a single mask according to: @@ -571,7 +603,9 @@ def apply_column_filters(products, filters): v = val[1:] if is_negated else val def func(col, v=v): - return np.isin(col, [v]) + # Normalize both column values and filter to lowercase strings for case-insensitive comparison + col_lower = np.char.lower(col.astype(str)) + return np.isin(col_lower, [v.lower()]) mask_funcs.append((func, is_negated)) this_mask = _combine_positive_negative_masks(mask_funcs)(col_data)