# Source code for xpandas.transformers.transformer

from sklearn.base import BaseEstimator, TransformerMixin

from ..data_container import XDataFrame, XSeries


class XSeriesTransformer(BaseEstimator, TransformerMixin):
    '''
    XSeriesTransformer is a base class for all custom transformers.

    XSeriesTransformer is a high level abstraction to transform an XSeries
    of specific data_types into another XSeries or an XDataFrame.

    XSeriesTransformer encapsulates a transformation and is based on
    scikit-learn BaseEstimator
    http://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html
    '''

    # Attribute name of the user supplied callable. It is looked up with a
    # default so that subclasses which never call ``__init__`` get a clear
    # error instead of an AttributeError.
    _TRANSFORM_ARG_FUNCTION_NAME = 'transform_function'

    def __init__(self, transform_function=None, data_types=None, name=None, **kwargs):
        '''
        :param transform_function: Callable that performs the actual transform.
            May be None when a subclass overrides ``_transform_series``.
        :param data_types: list of data_type values this transformer can work
            with. If None, every data_type is accepted and errors might be
            raised at run time.
        :param name: name for the transformer. If None, the class name is used.
        :param kwargs: additional arguments (accepted but not used here)
        :raises ValueError: if transform_function is given but is not callable
        '''
        if transform_function is not None and not callable(transform_function):
            raise ValueError('transform_function must be callable')

        self.transform_function = transform_function
        self.data_types = data_types
        self.name = self.__class__.__name__ if name is None else name

    def _check_input(self, input_data):
        '''
        Check that input is valid: input_data is an XSeries (or subclass) and
        the transformer "knows" how to work with input_data.data_type.

        :param input_data: object to validate
        :raises ValueError: if input_data is not an XSeries, or its data_type
            is not listed in self.data_types (when data_types is set)
        '''
        if not isinstance(input_data, XSeries):
            raise ValueError('X must be XSeries type')
        if self.data_types is not None and input_data.data_type not in self.data_types:
            raise ValueError('Estimator does not support {} type'.format(input_data.data_type))

    def fit(self, X=None, y=None, **kwargs):
        '''
        Fit transformer for given data. Should be overridden in child classes
        that need to learn state; the base implementation only validates X.

        :param X: XSeries to fit transformer on
        :param y: Labels column for X
        :param kwargs: additional arguments for transformer
        :return: fitted self object
        '''
        if X is not None:
            self._check_input(X)
        return self

    def _transform_series(self, custom_series):
        '''
        Helper method to transform an XSeries with ``transform_function``.
        Subclasses may override it to transform without a transform_function.

        :param custom_series: XSeries object
        :return: transformed data. It could be an XSeries or XDataFrame object
        :raises ValueError: if no transform_function was provided
        '''
        transform_function = getattr(self, self._TRANSFORM_ARG_FUNCTION_NAME, None)
        # __init__ always sets the attribute (possibly to None), so a plain
        # hasattr() check can never detect a missing function.
        if transform_function is None:
            raise ValueError('You must pass transform_function argument with a function')
        return custom_series.apply(func=transform_function, prefix=self.name)

    def transform(self, X):
        '''
        Apply transformation to X with current transformer.

        :param X: input XSeries
        :return: transformed data indexed like X. It could be an XSeries or
            XDataFrame object
        :raises ValueError: if X is invalid (see ``_check_input``) or no
            transform_function is available
        '''
        self._check_input(X)
        transformed = self._transform_series(X)
        # Keep row alignment with the input regardless of what apply() returns.
        transformed.index = X.index
        return transformed
class XDataFrameTransformer(BaseEstimator, TransformerMixin):
    '''
    XDataFrameTransformer is a set of XSeriesTransformer instances.

    XDataFrameTransformer can transform an XDataFrame object into another
    XDataFrame based on a set of XSeriesTransformer transformers, each bound
    to a column of the input.
    '''

    def __init__(self, transformations):
        '''
        Init XDataFrameTransformer with a dict of transformations.
        Each transformation specifies a column and transformer object(s).

        :param transformations: dict {column_name: Transformer object or [Transformer object]}
        :raises TypeError: if a key is not a string or a value is not an
            XSeriesTransformer / list of XSeriesTransformer
        '''
        self._validate_transformations(transformations)
        self.transformations = self._wrap_transformers_in_list(transformations)

    def _validate_transformations(self, transformations):
        '''
        Validate the {column_name: transformer(s)} mapping.

        :raises TypeError: on a non-string key or a non-transformer value
        '''
        for column, value in transformations.items():
            if not isinstance(column, str):
                raise TypeError('Key must be a string {}'.format(column))
            if isinstance(value, list):
                for t in value:
                    if not isinstance(t, XSeriesTransformer):
                        raise TypeError('All objects of {} must be a Transformer object. Issue with {}'.format(value, t))
            elif not isinstance(value, XSeriesTransformer):
                raise TypeError('Value must be a Transformer object {}'.format(value))

    def _wrap_transformers_in_list(self, transformations):
        '''
        Normalize every value to a list of transformers; existing lists are
        kept as-is (same list object), single transformers are wrapped.
        '''
        return {
            column: value if isinstance(value, list) else [value]
            for column, value in transformations.items()
        }

    def fit(self, X=None, y=None, **kwargs):
        '''
        Fit each transformer in the self.transformations dictionary on its column.

        :param X: XDataFrame to fit on
        :return: fitted self object
        :raises TypeError: if X is not an XDataFrame
        '''
        if not isinstance(X, XDataFrame):
            raise TypeError('X must be a XDataFrame type. Not {}'.format(type(X)))

        for col_name, transformers in self.transformations.items():
            for t in transformers:
                t.fit(X[col_name])
        return self

    def transform(self, X, columns_mapping=None):
        '''
        Transform X with the fitted dictionary self.transformations.

        Transformers returning an XSeries replace (and rename) their source
        column; transformers returning an XDataFrame drop the source column
        and append the new columns. X itself is not modified.

        :param X: XDataFrame to transform
        :param columns_mapping: {old_col: new_col} mapping between columns in
            the fit data set and the current X
        :return: new XDataFrame with the transformed columns
        '''
        if columns_mapping is None:
            columns_mapping = {}

        transformed_df = X.copy()
        for col_name, transformers in self.transformations.items():
            new_col_name = columns_mapping.get(col_name, col_name)
            for t in transformers:
                transformed_column = t.transform(X[new_col_name])
                if isinstance(transformed_column, XSeries):
                    transformed_df.rename(columns={
                        new_col_name: transformed_column.name
                    }, inplace=True)
                    transformed_df[transformed_column.name] = transformed_column
                else:
                    # An earlier transformer for the same column may already
                    # have renamed or dropped it; dropping again would raise
                    # KeyError.
                    if new_col_name in transformed_df.columns:
                        transformed_df.drop(new_col_name, inplace=True, axis=1)
                    transformed_df = XDataFrame.concat_dataframes(
                        [transformed_df, transformed_column]
                    )
        return transformed_df