97 changes: 97 additions & 0 deletions python/spark_sklearn/converter.py
@@ -15,6 +15,7 @@
from pyspark.ml.regression import LinearRegressionModel
from pyspark.mllib.linalg import DenseVector, SparseVector, Vectors, VectorUDT
from pyspark.sql.functions import udf
from pyspark.sql import SQLContext

from udt import CSRVectorUDT
from util import _new_java_obj, _randomUID
@@ -36,6 +37,7 @@ def __init__(self, sc):
        :param sc: SparkContext
        """
        self.sc = sc
        self.sqlContext = SQLContext(self.sc)
        # For conversions sklearn -> Spark
        self._skl2spark_classes = {
            SKL_LogisticRegression :
@@ -161,3 +163,98 @@ def toScipy(self, X):
        else:
            raise TypeError("Converter.toScipy expected numpy.ndarray of"
                            " scipy.sparse.csr.csr_matrix instances, but found: %s" % type(X))

    @staticmethod
    def _analyze_element(x):
        """ Maps a single element of a row to a (value, numpy dtype) pair. """
        if type(x) is float:
            return (x, np.double)
        if type(x) is int:
            return (x, np.int)
        if type(x) is long:
            return (x, np.long)
        if type(x) is DenseVector:
            return (x.toArray(), (np.double, len(x.toArray())))
        # TODO(tjh) support sparse arrays
        raise ValueError("The type %s could not be understood. Element was %s" % (type(x), x))

    @staticmethod
    def _analyze_df(df):
        """ Collects a dataframe and converts it into a structured numpy array (one field per
        column).
        """
        rows = df.collect()
        # Pair each element with its numpy dtype; the first row determines the column types.
        conversions = [[Converter._analyze_element(x) for x in row] for row in rows]
        types = [t for d, t in conversions[0]]
        data = [tuple([d for d, t in labeled_elts]) for labeled_elts in conversions]
        names = list(df.columns)
        dt = np.dtype({'names': names, 'formats': types})
        arr = np.array(data, dtype=dt)
        return arr

    def numpy_to_df(self, **kwargs):
        """ Converts a set of numpy arrays into a single dataframe.

        Each keyword argument name is used as the name of a column. The value of the argument is a
        numpy array with a shape of length 1 or 2. The dtype is one of the data types supported
        by Spark SQL; this includes np.double, np.float, np.int and np.long.
        See the whole list of supported types in the Spark SQL documentation:
        http://spark.apache.org/docs/latest/sql-programming-guide.html#data-types

        The columns may not appear in the order they are provided. Each column needs to have the
        same number of elements.

        :return: a pyspark.sql.DataFrame object
        :raise ValueError: if a data type could not be understood

        Example:

        >>> X = np.zeros((10, 4))
        >>> y = np.ones(10)
        >>> df = conv.numpy_to_df(x=X, y=y)
        >>> df.printSchema()
        root
         |-- y: double (nullable = true)
         |-- x: vector (nullable = true)
        """
        def convert(z):
            if len(z.shape) == 1:
                return z.tolist()
            if len(z.shape) == 2:
                # Each row of a 2-d array becomes a dense MLlib vector.
                return [Vectors.dense(row) for row in z.tolist()]
            raise ValueError("Cannot convert a numpy array with more than 2 dimensions")
        pairs = [(name, convert(data)) for (name, data) in kwargs.items()]
        # zip(*...) transposes the per-column lists into per-row tuples for createDataFrame.
        vecs = zip(*[data for (_, data) in pairs])
        names = [name for (name, _) in pairs]
        return self.sqlContext.createDataFrame(vecs, names)

    @staticmethod
    def df_to_numpy(df, *args):
        """ Converts a dataframe into a (local) structured numpy array. Each field of the array
        is named after the corresponding column of the dataframe.

        The varargs provide (in order) the list of columns to extract from the dataframe.
        If none are provided, all the columns of the dataframe are extracted.

        This method only handles basic numerical types and dense vectors of the same length.

        Note: it is not particularly optimized; do not push it too hard.

        :param df: a pyspark.sql.DataFrame object
        :param args: a list of strings that are column names in the dataframe
        :return: a structured numpy array with the content of the dataframe

        Example:

        >>> z = conv.df_to_numpy(df)
        >>> z['x'].dtype, z['x'].shape
        >>> z = conv.df_to_numpy(df, 'y')
        >>> z['y'].dtype, z['y'].shape
        """
        column_names = df.columns
        if not args:
            args = column_names
        column_nameset = set(column_names)
        for name in args:
            assert name in column_nameset, (name, column_names)
        # Just get the interesting columns
        projected = df.select(*args)

        return Converter._analyze_df(projected)
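
Taken together, the two public methods give a small round trip between local numpy data and Spark DataFrames. A minimal usage sketch, assuming an existing SparkContext named sc; the variable names X, y and Z are illustrative and not part of the diff:

import numpy as np
from spark_sklearn.converter import Converter

conv = Converter(sc)                       # sc: an already-running SparkContext
X = np.random.rand(10, 4)                  # 2-d array -> 'vector' column
y = np.arange(10, dtype=np.double)         # 1-d array -> 'double' column
df = conv.numpy_to_df(x=X, y=y)            # local numpy data -> Spark DataFrame

Z = Converter.df_to_numpy(df, 'x', 'y')    # collect back into a structured array
assert Z['x'].shape == (10, 4)
assert Z['y'].shape == (10,)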
50 changes: 50 additions & 0 deletions python/spark_sklearn/converter_np_test.py
@@ -0,0 +1,50 @@
import numpy as np
import numpy.random as rd
import unittest

from .converter import Converter
from .test_utils import create_sc

sc = create_sc()

n = 5
A = rd.rand(n,4)
B = rd.rand(n)
C = rd.randint(10, size=n)
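# In the tests below, A becomes a 'vector' column, B a 'double' column, and C a 'bigint' column.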

class NumpyConverterTestCase(unittest.TestCase):

    def setUp(self):
        super(NumpyConverterTestCase, self).setUp()
        self.sc = sc
        self.conv = Converter(self.sc)

    def test_pack(self):
        df = self.conv.numpy_to_df(a=A, b=B, c=C)
        dt = dict(df.dtypes)
        assert dt == {'a': 'vector', 'b': 'double', 'c': 'bigint'}, dt
        z = df.collect()
        assert len(z) == n
        for row in z:
            assert len(row) == 3, row
            assert row['a'] is not None, row
            assert row['b'] is not None, row
            assert row['c'] is not None, row

    def test_unpack(self):
        df = self.conv.numpy_to_df(a=A, b=B, c=C)
        Z = Converter.df_to_numpy(df)

        assert np.all(Z['a'] == A), (Z['a'], A)
        assert np.all(Z['b'] == B), (Z['b'], B)
        assert np.all(Z['c'] == C), (Z['c'], C)
        assert Z['c'].dtype == C.dtype

    def test_unpack_select(self):
        df = self.conv.numpy_to_df(a=A, b=B, c=C)
        Z = Converter.df_to_numpy(df, 'a', 'c')

        assert np.all(Z['a'] == A), (Z['a'], A)
        assert np.all(Z['c'] == C), (Z['c'], C)
        assert 'b' not in Z.dtype.fields
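
One way to exercise just this new module locally, assuming pyspark is importable and pytest is installed (the project's own test runner, if any, may differ):

python -m pytest python/spark_sklearn/converter_np_test.py -v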