Decode S3 chunks through a bytes carry-over buffer

cat() in papermill/s3.py used to decode every chunk `s` on its own
(`undecoded + s.decode(encoding)`) and initialised the carry-over buffer
`undecoded` as a str. This patch keeps `undecoded` as bytes and decodes the
concatenation `undecoded + s`, so bytes left over from a chunk that failed to
decode are retried together with the next chunk.

Test changes: the s3_client fixture uploads (and later deletes) one extra empty
object under test_unicode_file_path, so test_s3_listdir now expects 3 entries
instead of 2. The new test_s3_read_multibyte_chunks writes a multi-byte string
with cp_string() and reads it back through cat(..., buffersize=1).

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Record boundaries were checked against the
@@ hunk line counts; the leading indentation of context lines was rebuilt from
Python block structure -- confirm against the target files before applying.

diff --git a/papermill/s3.py b/papermill/s3.py
index c756e7c7..d54a0898 100644
--- a/papermill/s3.py
+++ b/papermill/s3.py
@@ -282,7 +282,7 @@ def cat(
         size = 0
         bytes_read = 0
         err = None
-        undecoded = ''
+        undecoded = b''
         if key:
             # try to read the file multiple times
             for i in range(100):
@@ -317,8 +317,8 @@ def cat(
 
                         if encoding and not raw:
                             try:
-                                decoded = undecoded + s.decode(encoding)
-                                undecoded = ''
+                                decoded = (undecoded + s).decode(encoding)
+                                undecoded = b''
                                 yield decoded
                             except UnicodeDecodeError:
                                 undecoded += s
diff --git a/papermill/tests/test_s3.py b/papermill/tests/test_s3.py
index bf006830..ebc15cce 100644
--- a/papermill/tests/test_s3.py
+++ b/papermill/tests/test_s3.py
@@ -152,6 +152,7 @@ def test_s3_defaults():
 test_string = 'Hello'
 test_file_path = 'notebooks/s3/s3_in/s3-simple_notebook.ipynb'
 test_empty_file_path = 'notebooks/s3/s3_in/s3-empty.ipynb'
+test_unicode_file_path = 'notebooks/s3/s3_in/s3-unicode.txt'
 
 with open(os.path.join(local_dir, test_file_path)) as f:
     test_nb_content = f.read()
@@ -171,11 +172,13 @@ def s3_client():
     client.create_bucket(Bucket=test_bucket_name, CreateBucketConfiguration={'LocationConstraint': 'us-west-2'})
     client.put_object(Bucket=test_bucket_name, Key=test_file_path, Body=test_nb_content)
     client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body='')
+    client.put_object(Bucket=test_bucket_name, Key=test_unicode_file_path, Body='')
     yield S3()
     try:
         client.delete_object(Bucket=test_bucket_name, Key=test_file_path)
         client.delete_object(Bucket=test_bucket_name, Key=f"{test_file_path}.txt")
         client.delete_object(Bucket=test_bucket_name, Key=test_empty_file_path)
+        client.delete_object(Bucket=test_bucket_name, Key=test_unicode_file_path)
     except Exception:
         pass
     mock_aws.stop()
@@ -214,5 +217,14 @@ def test_s3_listdir(s3_client):
     s3_dir = f"s3://{test_bucket_name}/{dir_name}"
     s3_path = f"s3://{test_bucket_name}/{test_file_path}"
     dir_listings = s3_client.listdir(s3_dir)
-    assert len(dir_listings) == 2
+    assert len(dir_listings) == 3
     assert s3_path in dir_listings
+
+
+def test_s3_read_multibyte_chunks(s3_client):
+    s3_path = f"s3://{test_bucket_name}/{test_unicode_file_path}"
+    multibyte_content = "第一行\n第二行"
+    s3_client.cp_string(multibyte_content, s3_path)
+
+    cat_data = ''.join(s3_client.cat(s3_path, buffersize=1))
+    assert cat_data == multibyte_content