Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit b786e96

Browse files
author
Frédéric Branchaud-Charron
authored
Merge pull request #37 from juharris/pad_str
pad_sequences: Add support for string value.
2 parents dad7fcc + bfb7297 commit b786e96

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

keras_preprocessing/sequence.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import random
1010
import json
1111
from six.moves import range
12+
import six
1213

1314

1415
def pad_sequences(sequences, maxlen=None, dtype='int32',
@@ -35,12 +36,13 @@ def pad_sequences(sequences, maxlen=None, dtype='int32',
3536
sequences: List of lists, where each element is a sequence.
3637
maxlen: Int, maximum length of all sequences.
3738
dtype: Type of the output sequences.
39+
To pad sequences with variable length strings, you can use `object`.
3840
padding: String, 'pre' or 'post':
3941
pad either before or after each sequence.
4042
truncating: String, 'pre' or 'post':
4143
remove values from sequences larger than
4244
`maxlen`, either at the beginning or at the end of the sequences.
43-
value: Float, padding value.
45+
value: Float or String, padding value.
4446
4547
# Returns
4648
x: Numpy array with shape `(len(sequences), maxlen)`
@@ -70,7 +72,13 @@ def pad_sequences(sequences, maxlen=None, dtype='int32',
7072
sample_shape = np.asarray(s).shape[1:]
7173
break
7274

73-
x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
75+
is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
76+
if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
77+
raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
78+
"You should set `dtype=object` for variable length strings."
79+
.format(dtype, type(value)))
80+
81+
x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
7482
for idx, s in enumerate(sequences):
7583
if not len(s):
7684
continue # empty list/array was found

tests/sequence_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
import numpy as np
44
from numpy.testing import assert_allclose
5+
from numpy.testing import assert_equal
56
from numpy.testing import assert_raises
67

78
import keras
@@ -35,6 +36,27 @@ def test_pad_sequences():
3536
assert_allclose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]])
3637

3738

39+
def test_pad_sequences_str():
40+
a = [['1'], ['1', '2'], ['1', '2', '3']]
41+
42+
# test padding
43+
b = sequence.pad_sequences(a, maxlen=3, padding='pre', value='pad', dtype=object)
44+
assert_equal(b, [['pad', 'pad', '1'], ['pad', '1', '2'], ['1', '2', '3']])
45+
b = sequence.pad_sequences(a, maxlen=3, padding='post', value='pad', dtype='<U3')
46+
assert_equal(b, [['1', 'pad', 'pad'], ['1', '2', 'pad'], ['1', '2', '3']])
47+
48+
# test truncating
49+
b = sequence.pad_sequences(a, maxlen=2, truncating='pre', value='pad',
50+
dtype=object)
51+
assert_equal(b, [['pad', '1'], ['1', '2'], ['2', '3']])
52+
b = sequence.pad_sequences(a, maxlen=2, truncating='post', value='pad',
53+
dtype='<U3')
54+
assert_equal(b, [['pad', '1'], ['1', '2'], ['1', '2']])
55+
56+
with pytest.raises(ValueError, match="`dtype` int32 is not compatible with "):
57+
sequence.pad_sequences(a, maxlen=2, truncating='post', value='pad')
58+
59+
3860
def test_pad_sequences_vector():
3961
a = [[[1, 1]],
4062
[[2, 1], [2, 2]],

0 commit comments

Comments
 (0)