Skip to content

Commit

Permalink
Adapt io.py to pyfact
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnoe committed Mar 24, 2017
1 parent f37bd88 commit f914ad3
Showing 1 changed file with 11 additions and 25 deletions.
36 changes: 11 additions & 25 deletions fact/io.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from os import path
import pandas as pd
import json
from sklearn_pandas import DataFrameMapper
from sklearn.externals import joblib
from sklearn2pmml import sklearn2pmml
import h5py
import sys
import logging
import numpy as np
from copy import copy

__all__ = [
'write_data', 'to_native_byteorder', 'read_h5py', 'read_h5py_chunked',
'read_pandas_hdf5', 'pickle_model', 'check_extension', 'read_data'
'write_data',
'to_native_byteorder',
'read_data',
'read_h5py',
'read_h5py_chunked',
'read_pandas_hdf5',
'check_extension',
]

log = logging.getLogger(__name__)
Expand All @@ -22,12 +24,12 @@
native_byteorder = native_byteorder = {'little': '<', 'big': '>'}[sys.byteorder]


def write_data(df, file_path, hdf_key='table'):
def write_data(df, file_path, key='table'):

name, extension = path.splitext(file_path)

if extension in ['.hdf', '.hdf5', '.h5']:
df.to_hdf(file_path, key=hdf_key)
df.to_hdf(file_path, key=key, format='table')

elif extension == '.json':
df.to_json(file_path)
Expand Down Expand Up @@ -176,6 +178,8 @@ def read_data(file_path, key=None, columns=None):
with open(file_path, 'r') as j:
d = json.load(j)
df = pd.DataFrame(d)
elif extension in ('.jsonl', '.jsonlines'):
df = pd.read_json(file_path, lines=True)
else:
raise NotImplementedError('Unknown data file extension {}'.format(extension))

Expand All @@ -186,21 +190,3 @@ def check_extension(file_path, allowed_extensions=allowed_extensions):
p, extension = path.splitext(file_path)
if extension not in allowed_extensions:
raise IOError('Allowed formats: {}'.format(allowed_extensions))


def pickle_model(classifier, feature_names, model_path, label_text='label'):
p, extension = path.splitext(model_path)
classifier.feature_names = feature_names
if (extension == '.pmml'):
print("Pickling model to {} ...".format(model_path))

mapper = DataFrameMapper([
(feature_names, None),
(label_text, None),
])

joblib.dump(classifier, p + '.pkl', compress=4)
sklearn2pmml(classifier, mapper, model_path)

else:
joblib.dump(classifier, model_path, compress=4)

0 comments on commit f914ad3

Please sign in to comment.