Skip to content

Commit 88ee3a3

Browse files
authored
Merge pull request #25 from CyberAgentAI/feature/load-data-in-c
Feature: Add data path to train method
2 parents 719fa62 + 4cf62b5 commit 88ee3a3

3 files changed

Lines changed: 24 additions & 8 deletions

File tree

ffm/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ def read_ffm_model(cls, model_path: str) -> "Model":
9393

9494

9595
def train(
96-
train_data: Dataset,
96+
train_data: Optional[Dataset] = None,
97+
train_path: Optional[str] = None,
9798
valid_data: Optional[Dataset] = None,
99+
valid_path: Optional[str] = None,
98100
eta: float = 0.2,
99101
lam: float = 0.00002,
100102
nr_iters: int = 15,
@@ -107,17 +109,21 @@ def train(
107109
random: bool = True,
108110
nds_rate: float = 1.0,
109111
) -> Model:
110-
tr = (train_data.data, train_data.labels)
111-
iw = train_data.importance_weights
112+
tr, iw = None, None
113+
if train_data is not None:
114+
tr = (train_data.data, train_data.labels)
115+
iw = train_data.importance_weights
112116

113117
va, iwv = None, None
114118
if valid_data is not None:
115119
va = (valid_data.data, valid_data.labels)
116120
iwv = valid_data.importance_weights
117121

118122
weights, best_iteration, normalization, best_va_loss = libffm_train(
119-
tr,
123+
tr=tr,
124+
tr_path=train_path,
120125
va=va,
126+
va_path=valid_path,
121127
iw=iw,
122128
iwv=iwv,
123129
eta=eta,

ffm/libffm.pyx

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cdef extern from "ffm.h" namespace "ffm" nogil:
5656
ffm_float best_va_loss
5757

5858
ffm_model *ffm_train_with_validation(ffm_problem *Tr, ffm_problem *Va, ffm_importance_weights *iws, ffm_importance_weights *iwvs, ffm_parameter param);
59+
ffm_problem *ffm_read_problem(char *path);
5960

6061

6162
cdef ffm_problem* make_ffm_prob(X, y):
@@ -174,8 +175,10 @@ cdef object _train(
174175

175176

176177
def train(
177-
tr,
178+
tr=None,
179+
tr_path=None,
178180
va=None,
181+
va_path=None,
179182
iw=None,
180183
iwv=None,
181184
eta=0.2,
@@ -205,11 +208,18 @@ def train(
205208
param.nds_rate = nds_rate
206209

207210
cdef:
208-
ffm_problem* tr_ptr = make_ffm_prob(tr[0], tr[1])
211+
ffm_problem* tr_ptr
209212
ffm_problem* va_ptr
210213
ffm_importance_weights *iw_ptr, *iwv_ptr
211214

212-
if va is not None:
215+
if tr_path is not None:
216+
tr_ptr = ffm_read_problem(tr_path.encode("utf-8"))
217+
else:
218+
tr_ptr = make_ffm_prob(tr[0], tr[1])
219+
220+
if va_path is not None:
221+
va_ptr = ffm_read_problem(va_path.encode("utf-8"))
222+
elif va is not None:
213223
va_ptr = make_ffm_prob(va[0], va[1])
214224
else:
215225
va_ptr = NULL

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
setup(
2828
name="ffm",
29-
version="0.3.1",
29+
version="0.4.0",
3030
description="LibFFM Python Package",
3131
long_description="LibFFM Python Package",
3232
install_requires=["numpy"],

0 commit comments

Comments
 (0)