diff --git a/seqlearn/_inference/__init__.py b/seqlearn/_inference/__init__.py
new file mode 100644
index 0000000..151b1e2
--- /dev/null
+++ b/seqlearn/_inference/__init__.py
@@ -0,0 +1 @@
+from . import forward_backward
\ No newline at end of file
diff --git a/seqlearn/_inference/forward_backward.pyx b/seqlearn/_inference/forward_backward.pyx
new file mode 100644
index 0000000..6016b5f
--- /dev/null
+++ b/seqlearn/_inference/forward_backward.pyx
@@ -0,0 +1,237 @@
+# cython: profile=True
+# Author: Chyi-Kwei Yau
+
+"""Forward-Backward algorithm for CRF training & posterior calculation"""
+
+cimport cython
+cimport numpy as np
+import numpy as np
+
+from libc.math cimport exp, log
+
+np.import_array()
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef np.float64_t _logsumexp(np.ndarray[ndim=1, dtype=np.float64_t] arr):
+    """
+    Simple 1-D logsumexp: log(sum(exp(arr))) for a non-empty array.
+
+    The maximum is subtracted before exponentiating to avoid overflow.
+    """
+    cdef np.npy_intp i, j, arr_length
+    cdef np.float64_t v_max, v_sum
+
+    arr_length = arr.shape[0]
+
+    # find max
+    v_max = arr[0]
+    for i from 1 <= i < arr_length:
+        if arr[i] > v_max:
+            v_max = arr[i]
+
+    # sum of exp value
+    v_sum = 0.0
+    for j from 0 <= j < arr_length:
+        v_sum += exp(arr[j] - v_max)
+
+    # logsumexp value
+    v_sum = log(v_sum) + v_max
+    return v_sum
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef _forward(np.ndarray[ndim=2, dtype=np.float64_t] score,
+               np.ndarray[ndim=3, dtype=np.float64_t] trans_score,
+               np.ndarray[ndim=2, dtype=np.float64_t] b_trans,
+               np.ndarray[ndim=1, dtype=np.float64_t] init,
+               np.ndarray[ndim=1, dtype=np.float64_t] final):
+    """
+    Forward Algorithm
+
+    Parameters
+    ----------
+    score : array, shape = (n_samples, n_states)
+        Scores per sample/class combination; in a linear model, X * w.T.
+        May be overwritten.
+    trans_score : array, shape = (n_samples, n_states, n_states), optional
+        Scores per sample/transition combination.
+    b_trans : array, shape = (n_states, n_states)
+        Transition weights.
+    init : array, shape = (n_states,)
+    final : array, shape = (n_states,)
+        Initial and final state weights.
+
+    Returns
+    -------
+    forward : array, shape = (n_samples, n_states)
+        Log-scale forward scores; the last row includes the final weights.
+
+    References
+    ----------
+    L. R. Rabiner (1989). A tutorial on hidden Markov models and selected
+    applications in speech recognition. Proc. IEEE 77(2):257-286.
+
+    """
+
+    cdef np.ndarray[ndim=2, dtype=np.float64_t] forward
+    cdef np.ndarray[ndim=1, dtype=np.float64_t] temp_array
+    cdef np.npy_intp i, j, k, n_samples, n_states, last_index
+
+    if trans_score is not None:
+        raise NotImplementedError("No transition scores for forward algorithm yet.")
+
+    n_samples, n_states = score.shape[0], score.shape[1]
+    if n_samples == 0:
+        raise ValueError("Forward algorithm needs at least one sample.")
+    last_index = n_samples - 1
+    forward = np.empty((n_samples, n_states), dtype=np.float64)
+
+    # initialize
+    for j in range(n_states):
+        forward[0, j] = init[j] + score[0, j]
+        # a length-1 sequence ends where it starts, so it also gets `final`
+        if last_index == 0:
+            forward[0, j] += final[j]
+
+    for i in range(1, n_samples):
+        for k in range(n_states):
+            temp_array = forward[i-1, :] + b_trans[:, k] + score[i, k]
+            #if trans_score is not None:
+            #    temp_array += trans_score[i-1, k, :]
+            if i == last_index:
+                temp_array += final[k]
+            forward[i, k] = _logsumexp(temp_array)
+
+    return forward
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef _backward(np.ndarray[ndim=2, dtype=np.float64_t] score,
+                np.ndarray[ndim=3, dtype=np.float64_t] trans_score,
+                np.ndarray[ndim=2, dtype=np.float64_t] b_trans,
+                np.ndarray[ndim=1, dtype=np.float64_t] init,
+                np.ndarray[ndim=1, dtype=np.float64_t] final):
+    """
+    Backward Algorithm (similar to forward Algorithm)
+
+    Parameters
+    ----------
+    Same as Forward function
+
+
+    Returns
+    -------
+    backward : array, shape = (n_samples, n_states)
+        Log-scale backward scores; the last row is all zeros.
+
+    """
+
+    cdef np.ndarray[ndim=2, dtype=np.float64_t] backward
+    cdef np.ndarray[ndim=1, dtype=np.float64_t] temp_array
+    cdef np.npy_intp i, j, k, n_samples, n_states, last_index
+
+    if trans_score is not None:
+        raise NotImplementedError("No transition scores for backward yet.")
+
+    n_samples, n_states = score.shape[0], score.shape[1]
+    if n_samples == 0:
+        raise ValueError("Backward algorithm needs at least one sample.")
+    last_index = n_samples - 1
+
+    backward = np.empty((n_samples, n_states), dtype=np.float64)
+
+    # initialize
+    for j in range(n_states):
+        # initial backward value = 1.0 = exp(0.0)
+        backward[last_index, j] = 0.0
+
+    for i in range(last_index-1, -1, -1):
+        for k in range(n_states):
+            temp_array = backward[i+1, :] + b_trans[k, :] + score[i+1, :]
+            #if trans_score is not None:
+            #    temp_array += trans_score[i, :, k]
+            if i == last_index-1:
+                temp_array += final
+
+            backward[i, k] = _logsumexp(temp_array)
+
+
+    return backward
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def _posterior(np.ndarray[ndim=2, dtype=np.float64_t] score,
+               np.ndarray[ndim=3, dtype=np.float64_t] trans_score,
+               np.ndarray[ndim=2, dtype=np.float64_t] b_trans,
+               np.ndarray[ndim=1, dtype=np.float64_t] init,
+               np.ndarray[ndim=1, dtype=np.float64_t] final):
+
+    """
+    Calculate posterior distribution based on Forward-Backward algorithm
+
+    Parameters
+    ----------
+    Same as Forward function
+
+    Returns
+    -------
+    state_posterior : array, shape = (n_samples, n_states)
+        Posterior probability of each state at each position.
+    trans_posterior : array, shape = (n_states, n_states)
+        Posterior transition probabilities summed over all positions.
+    ll : float
+        Log-normalizer: logsumexp of the last row of the forward scores.
+
+    References
+    ----------
+    C. Sutton (2006) An Introduction to Conditional Random Fields for
+    Relational Learning
+
+    """
+
+    cdef np.ndarray[ndim=2, dtype=np.float64_t] forward, backward, state_posterior, trans_posterior
+    cdef np.npy_intp i, j, k, n_samples, n_states
+    # log likelihood value
+    cdef np.float64_t ll, temp_trans_val
+
+    if trans_score is not None:
+        raise NotImplementedError("No transition scores for posterior func yet.")
+
+    n_samples, n_states = score.shape[0], score.shape[1]
+
+    # initialize
+    state_posterior = np.empty((n_samples, n_states), dtype=np.float64)
+    trans_posterior = np.zeros((n_states, n_states), dtype=np.float64)
+
+    # get forward-backward values
+    forward = _forward(score, trans_score, b_trans, init, final)
+    backward = _backward(score, trans_score, b_trans, init, final)
+
+    # get log likelihood
+    ll = _logsumexp(forward[n_samples-1, :])
+
+    # states posterior
+    for i in range(n_samples):
+        for j in range(n_states):
+            state_posterior[i, j] = forward[i, j] + backward[i, j] - ll
+    np.exp(state_posterior, out=state_posterior)
+
+    # transition posterior
+    for i in range(n_samples-1):
+        for j in range(n_states):
+            for k in range(n_states):
+                temp_trans_val = forward[i, j] + b_trans[j, k] + score[i+1, k] + backward[i+1, k] - ll
+                # add final feature
+                if i == n_samples-2:
+                    temp_trans_val += final[k]
+                # Note: get transition posterior from log scale and sum up
+                # from position 1 to (n_samples-1)
+                trans_posterior[j, k] += exp(temp_trans_val)
+
+    return state_posterior, trans_posterior, ll
+
diff --git a/seqlearn/_inference/tests/test_forward_backward.py b/seqlearn/_inference/tests/test_forward_backward.py
new file mode 100644
index 0000000..21058db
--- /dev/null
+++ b/seqlearn/_inference/tests/test_forward_backward.py
@@ -0,0 +1,69 @@
+import numpy as np
+from numpy.testing import assert_almost_equal
+from seqlearn._inference.forward_backward import _forward, _backward, _posterior
+
+init = np.array([.2, .1])
+final = np.array([.1, .2])
+
+trans = np.array([[.1, .2],
+                  [.4, .3]])
+
+score = np.array([[.3, .2],
+                  [.4, .1],
+                  [.5, .4]])
+
+
+def test_forward():
+    forward = _forward(score, None, trans, init, final)
+
+    true_forward = np.array([[0.5, 0.3],
+                             [1.7443, 1.4444],
+                             [3.1375, 3.1425]])
+
+    # assert equal
+    assert_almost_equal(true_forward, forward, decimal=3)
+
+
+def test_backward():
+    backward = _backward(score, None, trans, init, final)
+
+    # true value
+    true_backward = np.array([[2.6375, 2.8425],
+                              [1.4444, 1.6444],
+                              [0.0, 0.0]])
+    # assert equal
+    assert_almost_equal(true_backward, backward, decimal=3)
+
+
+def test_forward_backward():
+    forward = _forward(score, None, trans, init, final)
+    backward = _backward(score, None, trans, init, final)
+
+    # Both passes must agree on the log-normalizer.  Comparing the rows
+    # element by element only holds by coincidence for this fixture.
+    log_z_forward = np.logaddexp.reduce(forward[-1, :])
+    log_z_backward = np.logaddexp.reduce(backward[0, :] + score[0, :] + init)
+    assert_almost_equal(log_z_forward, log_z_backward)
+
+
+def test_single_sample():
+    # A length-1 sequence is both first and last: init and final both apply.
+    forward = _forward(score[:1], None, trans, init, final)
+    assert_almost_equal(forward, score[:1] + init + final)
+
+
+def test_posterior():
+    state_posterior, trans_posterior, ll = _posterior(score, None, trans, init, final)
+
+    state_posterior_true = np.array([[0.4987, 0.5012],
+                                     [0.5249, 0.4750],
+                                     [0.4987, 0.5012]])
+
+    trans_posterior_true = np.array([[0.4987, 0.5249],
+                                     [0.5249, 0.4512]])
+
+    assert_almost_equal(state_posterior_true, state_posterior, decimal=3)
+    assert_almost_equal(trans_posterior_true, trans_posterior, decimal=3)
+
+    # sum of transition posterior should sum up to (n_samples-1)
+    assert_almost_equal(np.sum(trans_posterior), state_posterior.shape[0]-1)
diff --git a/seqlearn/crf.py b/seqlearn/crf.py
new file mode 100644
index 0000000..0ea2155
--- /dev/null
+++ b/seqlearn/crf.py
@@ -0,0 +1,201 @@
+# Linear Chain CRF with Stochastic Gradient Descent.
+# author: Chyi-Kwei Yau, 2014
+
+from __future__ import division, print_function
+
+import numpy as np
+
+from .base import BaseSequenceClassifier
+from ._utils import (atleast2d_or_csr, check_random_state, count_trans,
+                     safe_add, safe_sparse_dot)
+
+from ._inference.forward_backward import _posterior
+
+
+class LinearChainCRF(BaseSequenceClassifier):
+    """Linear Chain Conditional Random Field for sequence classification.
+
+    This implements a linear chain CRF with Stochastic Gradient Descent.
+
+    Parameters
+    ----------
+    decode : string, optional
+        Decoding algorithm, either "bestfirst" or "viterbi" (default).
+
+    lr : float, optional
+        Initial learning rate
+
+    lr_exponent : float, optional
+        Exponent for inverse scaling learning rate. The effective learning
+        rate is lr / (t ** lr_exponent), where t is the iteration number.
+
+    max_iter : integer, optional
+        Number of iterations (aka. epochs). Each sequence is visited once in
+        each iteration.
+
+    random_state : {integer, np.random.RandomState}, optional
+        Random state or seed used for shuffling sequences within each
+        iteration.
+
+    reg : float, optional
+        L2 regularization strength. It is divided by the number of
+        sequences before being applied to each per-sequence update.
+
+    compute_obj_val : boolean, optional
+        Compute the objective value during training. Set this to True to
+        check whether the objective value converges.
+
+    verbose : integer, optional
+        Verbosity level. Defaults to zero (quiet mode).
+
+    References
+    ----------
+    J. Lafferty (2001). Conditional random fields: Probabilistic models
+    for segmenting and labeling sequence data
+
+    C. Sutton (2006) An Introduction to Conditional Random Fields for
+    Relational Learning
+
+    N. Schraudolph (2006). Accelerated Training of Conditional Random
+    Fields with Stochastic Gradient Methods
+
+    """
+
+    def __init__(self, decode="viterbi", lr=1.0, lr_exponent=.1, max_iter=10,
+                 random_state=None, reg=.01, compute_obj_val=False, verbose=0):
+        self.decode = decode
+        self.lr = lr
+        self.lr_exponent = lr_exponent
+        self.max_iter = max_iter
+        self.random_state = random_state
+        self.reg = reg
+        self.compute_obj_val = compute_obj_val
+        self.verbose = verbose
+
+    def fit(self, X, y, lengths):
+        """Fit to a set of sequences.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            Feature matrix of individual samples.
+
+        y : array-like, shape (n_samples,)
+            Target labels.
+
+        lengths : array-like of integers, shape (n_sequences,)
+            Lengths of the individual sequences in X, y. The sum of these
+            should be n_samples.
+
+        Returns
+        -------
+        self : LinearChainCRF
+        """
+
+        X = atleast2d_or_csr(X)
+
+        classes, y = np.unique(y, return_inverse=True)
+        n_classes = len(classes)
+
+        class_range = np.arange(n_classes)
+        Y_true = y.reshape(-1, 1) == class_range
+
+        lengths = np.asarray(lengths)
+        n_samples, n_features = X.shape
+        n_sequence = lengths.shape[0]
+
+        end = np.cumsum(lengths)
+        start = end - lengths
+
+        # initialize parameters
+        w = np.zeros((n_classes, n_features), order='F')
+        b_trans = np.zeros((n_classes, n_classes))
+        b_init = np.zeros(n_classes)
+        b_final = np.zeros(n_classes)
+
+        w_avg = np.zeros_like(w)
+        b_trans_avg = np.zeros_like(b_trans)
+        b_init_avg = np.zeros_like(b_init)
+        b_final_avg = np.zeros_like(b_final)
+
+        sequence_ids = np.arange(n_sequence)
+        rng = check_random_state(self.random_state)
+
+        avg_count = 1.
+        for it in range(1, self.max_iter + 1):
+            if self.verbose:
+                print("Iteration {0:2d}...".format(it))
+
+            if self.compute_obj_val:
+                sample_count = 0
+                sum_obj_val = 0.0
+
+            rng.shuffle(sequence_ids)
+
+            lr = self.lr / (it ** self.lr_exponent)
+
+            reg = self.reg / n_sequence
+
+            for i in sequence_ids:
+                X_i = X[start[i]:end[i]]
+                y_t_i = Y_true[start[i]:end[i]]
+                t_trans = count_trans(y[start[i]:end[i]], n_classes)
+
+                score = safe_sparse_dot(X_i, w.T)
+
+                # posterior distribution for states & transition
+                post_state, post_trans, ll = _posterior(score, None, b_trans, b_init, b_final)
+
+                if self.compute_obj_val:
+                    w_true = safe_sparse_dot(y_t_i.T, X_i)
+                    feature_val = np.sum(w_true * w)
+                    trans_val = np.sum(t_trans * b_trans)
+                    init_val = np.sum(y_t_i[0] * b_init)
+                    final_val = np.sum(y_t_i[-1] * b_final)
+                    sum_obj_val += feature_val + trans_val + init_val + final_val - ll - (0.5 * reg * np.sum(w * w))
+
+                    sample_count += 1
+                    if sample_count % 1000 == 0:
+                        avg_obj_val = sum_obj_val / sample_count
+                        print("iter: {0:d}, sample: {1:d}, avg. objective value {2:.4f}".format(
+                            it, sample_count, avg_obj_val))
+
+                # update feature w
+                w_update = safe_sparse_dot(lr * (y_t_i - post_state).T, X_i) - ((lr * reg) * w)
+
+                # update init & final matrix
+                b_init_update = lr * (post_state[0, :] - y_t_i[0] + reg * b_init)
+                b_final_update = lr * (post_state[-1, :] - y_t_i[-1] + reg * b_final)
+
+                # update transition matrix
+                b_trans_update = lr * (post_trans - t_trans + reg * b_trans)
+
+                safe_add(w, w_update)
+                b_init -= b_init_update
+                b_final -= b_final_update
+                b_trans -= b_trans_update
+
+                w_update *= avg_count
+                b_trans_update *= avg_count
+                b_init_update *= avg_count
+                b_final_update *= avg_count
+
+                safe_add(w_avg, w_update)
+                b_trans_avg -= b_trans_update
+                b_init_avg -= b_init_update
+                b_final_avg -= b_final_update
+
+                avg_count += 1.
+
+        w -= w_avg / avg_count
+        b_init -= b_init_avg / avg_count
+        b_trans -= b_trans_avg / avg_count
+        b_final -= b_final_avg / avg_count
+
+        self.coef_ = w
+        self.intercept_init_ = b_init
+        self.intercept_trans_ = b_trans
+        self.intercept_final_ = b_final
+        self.classes_ = classes
+
+        return self
diff --git a/seqlearn/tests/test_crf.py b/seqlearn/tests/test_crf.py
new file mode 100644
index 0000000..525a396
--- /dev/null
+++ b/seqlearn/tests/test_crf.py
@@ -0,0 +1,26 @@
+import numpy as np
+from numpy.testing import assert_array_equal
+from scipy.sparse import csc_matrix
+
+from seqlearn.crf import LinearChainCRF
+
+
+def test_crf():
+    X = np.array([[1, 0, 0],
+                  [1, 0, 0],
+                  [0, 1, 0],
+                  [0, 0, 1],
+                  [1, 0, 0],
+                  [0, 1, 0],
+                  [0, 0, 1]])
+
+    X = csc_matrix(X)
+    y = np.array(['0', '0', '1', '1', '0', '1', '1'])
+    lengths = np.array([4, 3])
+
+    for it in [1, 5, 10]:
+        clf = LinearChainCRF(max_iter=it)
+        clf.fit(X, y, lengths)
+
+        y_pred = clf.predict(X, lengths)
+        assert_array_equal(y, y_pred)
diff --git a/setup.py b/setup.py
index f382328..3b56ba6 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
     maintainer_email="L.J.Buitinck@uva.nl",
     license="MIT",
     url="https://github.com/larsmans/seqlearn",
-    packages=["seqlearn", "seqlearn._utils", "seqlearn._decode", "seqlearn.datasets"],
+    packages=["seqlearn", "seqlearn._utils", "seqlearn._decode", "seqlearn.datasets", "seqlearn._inference"],
     classifiers=[
         "Intended Audience :: Developers",
         "Intended Audience :: Science/Research",
@@ -27,6 +27,9 @@
         Extension("seqlearn._decode.viterbi", ["seqlearn/_decode/viterbi.pyx"]),
         Extension("seqlearn._utils.ctrans", ["seqlearn/_utils/ctrans.pyx"]),
         Extension("seqlearn._utils.safeadd", ["seqlearn/_utils/safeadd.pyx"]),
+        Extension("seqlearn._inference.forward_backward",
+                  ["seqlearn/_inference/forward_backward.pyx"],
+                  libraries=["m"]),
     ],
 )
 