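"""Tests for reading partitioned parquet datasets from S3 with dask.

A moto server stands in for S3: s3fs, pyarrow and dask are all pointed at the
same local endpoint, so the write/read round-trip runs without real AWS.
"""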
import sys
import time

import dask.dataframe as dd
import moto.server
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
import s3fs
from pytest import fixture


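# Synthetic session-scoped dataset: 500k rows with a four-value partition key.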
@fixture(scope="session")
def partitioned_dataset() -> dict:
    rows = 500000
    cds_df = pd.DataFrame(
        {
            "id": range(rows),
            "part_key": np.random.choice(["A", "B", "C", "D"], rows),
            "timestamp": np.random.randint(1051638817, 1551638817, rows),
            "int_value": np.random.randint(0, 60000, rows),
        }
    )
    return {"dataframe": cds_df, "partitioning_column": "part_key"}


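# Ask the OS for an unused TCP port by binding to port 0.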
def free_port():
    import socketserver

    with socketserver.TCPServer(("localhost", 0), None) as s:
        return s.server_address[1]


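# Run a standalone moto server in a subprocess so that s3fs, pyarrow and dask
# all talk to the same fake-S3 endpoint; yields the s3fs connection kwargs.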
@fixture(scope="session")
def moto_server():
    import subprocess

    port = free_port()
    process = subprocess.Popen([
        sys.executable,
        moto.server.__file__,
        '--port', str(port),
        '--host', 'localhost',
        's3'
    ])

    s3fs_kwargs = dict(
        client_kwargs={"endpoint_url": f'http://localhost:{port}'},
    )

    # Poll until the server answers; bail out after 30 seconds.
    start = time.time()
    while True:
        try:
            fs = s3fs.S3FileSystem(skip_instance_cache=True, **s3fs_kwargs)
            fs.ls("/")
            break
        except Exception:
            if time.time() - start > 30:
                raise TimeoutError("Could not get a working moto server in time")
            time.sleep(0.1)

    yield s3fs_kwargs

    process.terminate()


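# An s3fs filesystem wired to the moto endpoint.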
@fixture(scope="session")
def moto_s3fs(moto_server):
    return s3fs.S3FileSystem(**moto_server)


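# Create the test bucket once per session.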
@fixture(scope="session")
def s3_bucket(moto_server):
    test_bucket_name = 'test'
    from botocore.session import Session
    # NB: we use the sync botocore client for setup
    session = Session()
    client = session.create_client('s3', **moto_server['client_kwargs'])
    client.create_bucket(Bucket=test_bucket_name, ACL='public-read')
    return test_bucket_name


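# Write the dataset to fake S3 as a hive-partitioned parquet directory.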
@fixture(scope="session")
def partitioned_parquet_path(partitioned_dataset, moto_s3fs, s3_bucket):
    import pyarrow.parquet

    cds_df = partitioned_dataset["dataframe"]
    table = pa.Table.from_pandas(cds_df, preserve_index=False)
    path = s3_bucket + "/partitioned/dataset"

    pyarrow.parquet.write_to_dataset(
        table,
        path,
        filesystem=moto_s3fs,
        partition_cols=[partitioned_dataset["partitioning_column"]],
    )
    return path


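# Every test runs against both parquet engines and both gather_statistics modes.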
@pytest.fixture(scope='session', params=[
    pytest.param("pyarrow"),
    pytest.param("fastparquet"),
])
def parquet_engine(request):
    return request.param


@pytest.fixture(scope='session', params=[
    pytest.param(False, id='gather_statistics=F'),
    pytest.param(True, id='gather_statistics=T'),
])
def gather_statistics(request):
    return request.param


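# Note: a bare `assert df1 == df2` is ambiguous for DataFrames and raises, so
# the tests below compare with pandas' assert_frame_equal instead.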
def test_partitioned_read(partitioned_dataset, partitioned_parquet_path, moto_server, parquet_engine, gather_statistics):
    """The directory-based reading is quite finicky."""
    storage_options = moto_server.copy()
    ddf = dd.read_parquet(
        f"s3://{partitioned_parquet_path}",
        storage_options=storage_options,
        gather_statistics=gather_statistics,
        engine=parquet_engine,
    )

    assert 'part_key' in ddf.columns

    actual = ddf.compute()
    # The partition column may come back as a categorical; normalise it first.
    actual['part_key'] = actual['part_key'].astype(str)
    actual = actual.sort_values('id').reset_index(drop=True)
    expected = partitioned_dataset["dataframe"].reset_index(drop=True)

    pd.testing.assert_frame_equal(actual[list(expected.columns)], expected)


def test_non_partitioned_read(partitioned_dataset, partitioned_parquet_path, moto_server, parquet_engine, gather_statistics):
    """The directory-based reading is quite finicky."""
    storage_options = moto_server.copy()
    ddf = dd.read_parquet(
        f"s3://{partitioned_parquet_path}/part_key=A",
        storage_options=storage_options,
        gather_statistics=gather_statistics,
        engine=parquet_engine,
    )

    if parquet_engine == 'pyarrow':
        assert 'part_key' in ddf.columns

    actual = ddf.compute().sort_values('id').reset_index(drop=True)
    expected = partitioned_dataset["dataframe"]
    # Only partition A was read; part_key itself is implied by the path.
    expected = expected.loc[expected.part_key == "A"].drop(columns="part_key").reset_index(drop=True)

    pd.testing.assert_frame_equal(actual[list(expected.columns)], expected)