[WIP] Converts dataframe to/from named numpy arrays #4
base: master
The diff touches the Converter class, first adding an SQLContext import and an instance created in the constructor:

```diff
@@ -15,6 +15,7 @@
 from pyspark.ml.regression import LinearRegressionModel
 from pyspark.mllib.linalg import DenseVector, SparseVector, Vectors, VectorUDT
 from pyspark.sql.functions import udf
+from pyspark.sql import SQLContext

 from udt import CSRVectorUDT
 from util import _new_java_obj, _randomUID

@@ -36,6 +37,7 @@ def __init__(self, sc):
         :param sc: SparkContext
         """
         self.sc = sc
+        self.sqlContext = SQLContext(self.sc)
         # For conversions sklearn -> Spark
         self._skl2spark_classes = {
             SKL_LogisticRegression :
```
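With this field in place, a caller only needs a SparkContext; the Converter builds its own SQLContext. A minimal usage sketch (the app name and local master string are assumptions for illustration, not part of the PR):

```python
from pyspark import SparkContext
from pdspark import Converter

# Assumed local setup for illustration; any SparkContext will do.
sc = SparkContext("local[2]", "converter-example")
conv = Converter(sc)   # conv.sqlContext is created internally from sc
```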
The bulk of the change adds the dataframe/numpy conversion helpers at the end of the class:

```diff
@@ -161,3 +163,84 @@ def toScipy(self, X):
         else:
             raise TypeError("Converter.toScipy expected numpy.ndarray of"
                             " scipy.sparse.csr.csr_matrix instances, but found: %s" % type(X))
+
+    @staticmethod
+    def _analyze_element(x):
+        if type(x) is float:
+            return (x, np.double)
+        if type(x) is int:
+            return (x, np.int)
+        if type(x) is long:
+            return (x, np.long)
+        if type(x) is DenseVector:
+            return (x.toArray(), (np.double, len(x.toArray())))
+        # TODO(tjh) support sparse arrays
+        raise ValueError("The type %s could not be understood. Element was %s" % (type(x), x))
+
+    @staticmethod
+    def _analyze_df(df):
+        """ Converts a dataframe into a numpy array.
+        """
+        rows = df.collect()
+        conversions = [[Converter._analyze_element(x) for x in row] for row in rows]
+        types = [t for d, t in conversions[0]]
+        data = [tuple([d for d, t in labeled_elts]) for labeled_elts in conversions]
+        names = list(df.columns)
+        dt = np.dtype({'names': names, 'formats': types})
+        arr = np.array(data, dtype=dt)
+        return arr
```
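The interesting part of `_analyze_df` is the structured dtype it builds: one named field per DataFrame column, where dense vector columns become fixed-size sub-array fields. A standalone numpy sketch of that construction (column names and values below are made up for illustration):

```python
import numpy as np

# One field per column: 'y' holds scalars, 'x' holds length-4 vectors.
names = ['y', 'x']
formats = [np.double, (np.double, 4)]
dt = np.dtype({'names': names, 'formats': formats})

rows = [(1.0, [0.1, 0.2, 0.3, 0.4]),
        (2.0, [0.5, 0.6, 0.7, 0.8])]
arr = np.array(rows, dtype=dt)

print(arr['y'].shape)   # (2,)    -- scalar field
print(arr['x'].shape)   # (2, 4)  -- sub-array field, one row per record
```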
The same hunk continues with the numpy-to-DataFrame direction:

```diff
+    def pack_DataFrame(self, **args):
+        """ Converts a set of numpy arrays into a single dataframe.
+
+        The argument name is used to infer the name of the column. The columns may not be in
+        the order they are provided. Each column needs to have the same number of elements.
+
+        Example:
+
+        >>> X = np.zeros((10,4))
+        >>> y = np.ones(10)
+        >>> df = conv.pack_DataFrame(x=X, y=y)
+        >>> df.printSchema()
+        root
+         |-- y: double (nullable = true)
+         |-- x: vector (nullable = true)
+        """
+        def convert(z):
+            if len(z.shape) == 1:
+                return z.tolist()
+            if len(z.shape) == 2:
+                return [Vectors.dense(row) for row in z.tolist()]
+            assert False, (z.shape)
+        pairs = [(name, convert(data)) for (name, data) in args.items()]
+        vecs = zip(*[data for (_, data) in pairs])
+        names = [name for (name, _) in pairs]
+        return self.sqlContext.createDataFrame(vecs, names)
```
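The row-assembly step may be easier to see in isolation: each 2-D array becomes a list of DenseVectors, 1-D arrays stay plain lists, and zip turns the per-column lists into per-row tuples for createDataFrame. A hedged sketch with made-up values (the final line assumes a SQLContext named `sqlContext` and is left commented out):

```python
import numpy as np
from pyspark.mllib.linalg import Vectors

X = np.arange(4.0).reshape(2, 2)   # 2-D -> vector column
y = np.array([0.5, 0.6])           # 1-D -> double column

x_col = [Vectors.dense(row) for row in X.tolist()]
y_col = y.tolist()
rows = list(zip(x_col, y_col))     # [(DenseVector([0.0, 1.0]), 0.5), ...]

# df = sqlContext.createDataFrame(rows, ['x', 'y'])   # assumed SQLContext
```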
The hunk closes with the DataFrame-to-numpy entry point:

```diff
+    @staticmethod
+    def df_to_numpy(df, *args):
+        """ Converts a dataframe into a (local) numpy array. Each field is named after the
+        corresponding column of the dataframe.
+
+        The varargs provide (in order) the list of columns to extract from the dataframe.
+        If none are provided, all the columns from the dataframe are extracted.
+
+        This method only handles basic numerical types, or dense vectors with the same length.
+
+        Note: it is not particularly optimized, do not push it too hard.
+
+        Example:
+        >>> z = Converter.df_to_numpy(df)
+        >>> z['x'].dtype, z['x'].shape
+        >>> z = Converter.df_to_numpy(df, 'y')
+        >>> z['y'].dtype, z['y'].shape
+        """
+        column_names = df.columns
+        if not args:
+            args = column_names
+        column_nameset = set(column_names)
+        for name in args:
+            assert name in column_nameset, (name, column_names)
+        # Just get the interesting columns
+        projected = df.select(*args)
+        return Converter._analyze_df(projected)
```
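Putting the two directions together, a round trip looks roughly like the tests below. Shapes and values here are illustrative; `conv` and `Converter` are the objects from the earlier construction sketch:

```python
import numpy as np

A = np.random.rand(5, 4)
B = np.random.rand(5)

df = conv.pack_DataFrame(a=A, b=B)      # numpy -> DataFrame
Z = Converter.df_to_numpy(df, 'a')      # DataFrame -> named numpy array

assert Z['a'].shape == (5, 4)           # 'a' comes back as a (rows, 4) field
assert np.allclose(Z['a'], A)
```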
A new test file exercises the round trip:

```diff
@@ -0,0 +1,45 @@
+import sklearn
+import numpy as np
+import numpy.random as rd
+import unittest
+
+from .test_common import get_context
+from pdspark import Converter
+
+sc = get_context()
+
+n = 5
+A = rd.rand(n,4)
+B = rd.rand(n)
+C = rd.randint(10, size=n)
+
+
+class MLlibTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.sc = sc
+        self.conv = Converter(self.sc)
+        self.n = 5
+
+    def test_pack(self):
+        df = self.conv.pack_DataFrame(a=A, b=B, c=C)
+        dt = dict(df.dtypes)
+        assert dt == {'a':'vector', 'b': 'double', 'c': 'bigint'}, dt
+        df.collect() # Force creation
+
+    def test_unpack(self):
+        df = self.conv.pack_DataFrame(a=A, b=B, c=C)
+        Z = Converter.df_to_numpy(df)
+        assert np.all(Z['a'] == A), (Z['a'], A)
+        assert np.all(Z['b'] == B), (Z['b'], B)
+        assert np.all(Z['c'] == C), (Z['c'], C)
+        assert Z['c'].dtype == C.dtype
+
+    def test_unpack_select(self):
+        df = self.conv.pack_DataFrame(a=A, b=B, c=C)
+        Z = Converter.df_to_numpy(df, 'a', 'c')
+        assert np.all(Z['a'] == A), (Z['a'], A)
+        assert np.all(Z['c'] == C), (Z['c'], C)
+        assert 'b' not in Z.dtype.fields
```
A small helper module provides a shared SparkContext for the tests:

```diff
@@ -0,0 +1,15 @@
+"""
+Common variables for all tests.
+"""
+from pyspark import SparkContext
+
+__all__ = ['get_context']
+
+_sc = None
+
+
+def get_context():
+    global _sc
+    if not _sc:
+        _sc = SparkContext('local[4]', "spark-sklearn tests")
+    return _sc
```
|
Review comment: I assume this will be very slow for larger data? That's OK for now.
Reply: Yes it will; we can always improve it later.