Skip to content

Commit

Permalink
get_data groups on non-independent observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Mann authored and Michael Mann committed Dec 3, 2018
1 parent 2e10782 commit 3d76f4c
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions tsraster/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split as tts
from sklearn.model_selection import GroupShuffleSplit

from sklearn.linear_model import ElasticNet
from sklearn.metrics import accuracy_score,confusion_matrix,cohen_kappa_score
#from sklearn.preprocessing import StandardScaler as scaler
Expand All @@ -9,16 +11,17 @@
from os.path import isfile


def get_data(obj, test_size=0.33,scale=False,stratify=True):
def get_data(obj, test_size=0.33,scale=False,stratify=None,groups=None):
'''
:param obj: path to csv or name of pandas dataframe with yX, or list holding dataframes [y,X]
:param test_size: percentage to hold out for testing (default 0.33)
:param scale: should data be centered and scaled True or False (default)
:param stratify: should the sample be stratified by the dependent valueTrue or False (default)
:param scale: should data be centered and scaled True or False
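:param stratify: array-like of labels to stratify the split by (passed to train_test_split's stratify argument), or None for no stratification (default None); not used when groups is given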
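:param groups: group label for each sample, identifying non-independent observations that must stay together on one side of the split, e.g. pixel_id or df.index.get_level_values('index') (default None); when given, the split is made with GroupShuffleSplit using its default settings, so test_size is not applied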
:return: X_train, X_test, y_train, y_test splits
'''

# read in inputs
print("input should be csv or pandas dataframe with yX, or [y,X]")
if str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>":
Expand All @@ -31,29 +34,32 @@ def get_data(obj, test_size=0.33,scale=False,stratify=True):
df = pd.read_csv(obj)
else:
print("input format not dataframe, csv, or list")


df = df.drop(['Unnamed: 0'], axis=1,errors ='ignore') # clear out unknown columns

# optionally rescale every column with MinMaxScaler (min-max scaling, not centering)
if scale == True:
min_max_scaler = preprocessing.MinMaxScaler()
np_scaled = min_max_scaler.fit_transform(df)
df = pd.DataFrame(np_scaled)

y = df.iloc[:,0]
X = df.iloc[:,1:]

if stratify==True:
X_train, X_test, y_train, y_test = tts(X, y,
test_size=test_size,
stratify=y,
random_state=42)
if groups is not None:
# grouped train/test split: each group falls entirely in train or entirely in test
train_inds, test_inds = next(GroupShuffleSplit().split(X, groups=groups))
X_train, X_test, y_train, y_test = X.iloc[train_inds,:], X.iloc[test_inds,:], y.iloc[train_inds], y.iloc[test_inds]

else:
# ungrouped test train split
X_train, X_test, y_train, y_test = tts(X, y,
test_size=test_size,
stratify=stratify,
random_state=42)



return X_train, X_test, y_train, y_test


Expand Down

0 comments on commit 3d76f4c

Please sign in to comment.