diff --git a/alipy/query_strategy/LAL_RL/Agent.py b/alipy/query_strategy/LAL_RL/Agent.py
new file mode 100644
index 0000000..17c439b
--- /dev/null
+++ b/alipy/query_strategy/LAL_RL/Agent.py
@@ -0,0 +1,122 @@
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+
+class Net(nn.Module):
+    def __init__(self, candidate_size, bias_initialization=None):
+        super().__init__()
+        self.fc1 = nn.Linear(candidate_size, 10)
+        self.fc2 = nn.Linear(13, 5)
+        self.fc3 = nn.Linear(5, 1)
+
+        if bias_initialization is not None:
+            self.fc3.bias = torch.nn.Parameter(torch.tensor(bias_initialization, dtype=torch.float))
+
+    def forward(self, t):
+        state = t[:, :-3]
+        action = t[:, -3:]
+        t = torch.sigmoid(self.fc1(state))
+        t = torch.cat((t, action), dim=1)
+        t = torch.sigmoid(self.fc2(t))
+        t = self.fc3(t)
+        return t
+
+
+class Agent:
+    def __init__(self, n_state_estimation=30, learning_rate=1e-3, batch_size=32, bias_average=0,
+                 target_copy_factor=0.01, gamma=0.999, device=None):
+        self.net = Net(n_state_estimation, bias_average).to(device)
+        self.target_net = Net(n_state_estimation, bias_average).to(device)
+        self.target_net.eval()
+        self.device = device
+
+        # copy weights from the training net to the target net
+        self.target_net.load_state_dict(self.net.state_dict())
+
+        # create loss function and optimizer
+        self.loss = nn.MSELoss(reduction='sum')
+        self.optimizer = optim.Adam(self.net.parameters(), lr=learning_rate)
+        self.batch_size = batch_size
+        self.target_copy_factor = target_copy_factor
+        self.gamma = gamma
+
+    def train(self, minibatch):
+        max_prediction_batch = []
+
+        for i, next_classifier_state in enumerate(minibatch.next_classifier_state):
+            # Predict q-value function value for all available actions
+            n_next_actions = np.shape(minibatch.next_action_state[i])[1]
+            next_classifier_state = np.repeat([next_classifier_state], n_next_actions, axis=0)
+            next_classifier_state = np.concatenate((next_classifier_state,
+                                                    minibatch.next_action_state[i].transpose()), axis=1)
+            input_tensor = torch.tensor(next_classifier_state, dtype=torch.float, device=self.device)
+
+            # Use target_estimator
+            target_predictions = self.target_net(input_tensor)
+
+            # Use estimator
+            predictions = self.net(input_tensor)
+
+            target_predictions = np.ravel(target_predictions.detach().cpu().numpy())
+            predictions = np.ravel(predictions.detach().cpu().numpy())
+
+            # Follow the Double Q-learning idea of van Hasselt, Guez, and Silver 2016:
+            # select the best action according to the predictions of the estimator
+            best_action_by_estimator = np.random.choice(np.where(predictions == np.amax(predictions))[0])
+            # As the estimate of the q-value of the best action,
+            # take the prediction of the target estimator for the selected action
+            max_target_prediction_i = target_predictions[best_action_by_estimator]
+            max_prediction_batch.append(max_target_prediction_i)
+
+        max_prediction_batch = torch.tensor(max_prediction_batch, dtype=torch.float, device=self.device)
+        terminal_mask = torch.where(torch.tensor(minibatch.terminal, device=self.device),
+                                    torch.zeros(self.batch_size, device=self.device),
+                                    torch.ones(self.batch_size, device=self.device))
+        masked_target_predictions = max_prediction_batch * terminal_mask
+        expected_state_action_values = torch.tensor(minibatch.reward, dtype=torch.float, device=self.device) + self.gamma*masked_target_predictions
+
+        input_tensor = np.concatenate((minibatch.classifier_state, minibatch.action_state), axis=1)
+        input_tensor = torch.from_numpy(input_tensor).to(self.device).float()
+        net_output = self.net(input_tensor)
+        net_output = net_output.flatten()
+
+        td_errors = expected_state_action_values - net_output
+
+        # actually train the network
+        loss = self.loss(net_output, expected_state_action_values)
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        # Copy parameter values (partially) to the target estimator
+        new_state_dict = OrderedDict()
+        for var_name in self.net.state_dict():
+            target_var = self.target_net.state_dict()[var_name]
+            policy_var = self.net.state_dict()[var_name]
+            target_var = target_var*(1-self.target_copy_factor) + policy_var*self.target_copy_factor
+            new_state_dict[var_name] = target_var
+        self.target_net.load_state_dict(new_state_dict)
+
+        return td_errors.detach().cpu().numpy()
+
+    def get_action(self, classifier_state, action_state):
+        input_tensor = np.concatenate((np.repeat(classifier_state[None, :], action_state.shape[1], axis=0),
+                                       action_state.transpose()), axis=1)
+        input_tensor = torch.tensor(input_tensor, dtype=torch.float, device=self.device)
+        predictions = self.net(input_tensor)
+        predictions = predictions.flatten()
+
+        predictions = predictions.detach().cpu().numpy()
+        max_action = np.random.choice(np.where(predictions == predictions.max())[0])
+
+        return max_action
+
+    def update_target_net(self):
+        self.target_net.load_state_dict(self.net.state_dict())
+
+    def save_net(self, path):
+        torch.save(self.net.state_dict(), path + ".pt")
+
+    def save_target_net(self, path):
+        torch.save(self.target_net.state_dict(), path + "_target_net.pt")
\ No newline at end of file
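The sketch below is illustrative only and not part of the patch: it shows how Net expects each input row to be laid out, i.e. the n_state_estimation sorted classifier scores followed by the three action features that forward() slices off again. The shapes follow the constructor defaults and the random tensors are stand-ins.

    import torch
    from alipy.query_strategy.LAL_RL import Net   # exported by the package __init__ added below

    net = Net(candidate_size=30)                  # 30 state features, as in Agent's default
    state = torch.rand(5, 30)                     # 5 hypothetical states: sorted class-0 probabilities
    actions = torch.rand(5, 3)                    # 3 action features per row (see _get_state in envs.py)
    q_values = net(torch.cat((state, actions), dim=1))
    print(q_values.shape)                         # torch.Size([5, 1]) -- one Q-value per state/action pair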
diff --git a/alipy/query_strategy/LAL_RL/LAL_RL.py b/alipy/query_strategy/LAL_RL/LAL_RL.py
new file mode 100644
index 0000000..b0aa771
--- /dev/null
+++ b/alipy/query_strategy/LAL_RL/LAL_RL.py
@@ -0,0 +1,390 @@
+import copy
+from .envs import LalEnvTargetAccuracy
+from .datasets import DatasetUCI
+from .helpers import ReplayBuffer
+from .Agent import Agent, Net
+from sklearn import metrics
+from sklearn.linear_model import LogisticRegression
+import numpy as np
+import scipy.spatial
+import collections
+from ..base import BaseIndexQuery
+
+from tqdm import tqdm
+
+import torch
+
+
+class LAL_RL_StrategyLearner:
+    """Reference Paper:
+    Ksenia Konyushkova, Raphael Sznitman, Pascal Fua, 2018.
+    Discovering General-Purpose Active Learning Strategies.
+    https://arxiv.org/abs/1810.04114
+
+    The implementation is referred to
+    https://github.com/ksenia-konyushkova/LAL-RL
+    """
+    def __init__(self, path, possible_dataset_names, n_state_estimation=30, size=-1, subset=-1,
+                 quality_method=metrics.accuracy_score, tolerance_level=0.98, model=None,
+                 replay_buffer_size=1e4, prioritized_replay_exponent=3):
+        """
+        path: the directory that contains the datasets
+        possible_dataset_names: the names of the datasets that will be used for training
+        n_state_estimation: how many datapoints are used to represent a state
+        size: An integer indicating the size of the training dataset to sample, if -1 use all data
+        subset: An integer indicating what subset of data to use. 0: even, 1: odd, -1: all datapoints
+        quality_method: the measure that will be used for the target quality
+        tolerance_level: the ratio of the target quality that the agent has to achieve to end an episode
+        model: the model that is used to make predictions on the datasets, should implement fit, predict and predict_proba
+        replay_buffer_size: An integer indicating the maximum number of transitions to be stored in the replay buffer
+        prioritized_replay_exponent: A float that is used for turning the td error into a probability to be sampled
+        """
+        dataset = DatasetUCI(possible_dataset_names, n_state_estimation=n_state_estimation, subset=subset,
+                             size=size, path=path)
+        if model is None:
+            model = LogisticRegression()
+        self.env = LalEnvTargetAccuracy(dataset, model, quality_method=quality_method,
+                                        tolerance_level=tolerance_level)
+        self.n_state_estimation = n_state_estimation
+        self.replay_buffer = ReplayBuffer(buffer_size=replay_buffer_size, prior_exp=prioritized_replay_exponent)
+
+
+    def train_query_strategy(self, saving_path, file_name, warm_start_episodes=100, nn_updates_per_warm_start=100,
+                             learning_rate=1e-4, batch_size=32, gamma=1, target_copy_factor=0.01,
+                             training_iterations=1000, episodes_per_iteration=10, updates_per_iteration=60,
+                             epsilon_start=1, epsilon_end=0, epsilon_step=1000, device=None, verbose=2):
+        """
+        saving_path: the directory where the learnt strategy will be saved
+        file_name: the file name for the learnt strategy
+        warm_start_episodes: the number of warm start episodes that will be performed before the actual training
+        nn_updates_per_warm_start: the number of q-network updates after the warm start episodes
+        learning_rate: the learning rate of the deep q-network
+        batch_size: the size of the batches that will be sampled from replay memory for one q-network update
+        gamma: the discount factor in q-learning
+        target_copy_factor: the factor for copying the weights of the estimator to the target estimator
+        training_iterations: the number of training iterations
+        episodes_per_iteration: the number of episodes in one training iteration
+        updates_per_iteration: the number of q-network updates that are performed at the end of an iteration
+        epsilon_start: the start value of epsilon for the epsilon-greedy strategy
+        epsilon_end: the end value of epsilon for the epsilon-greedy strategy
+        epsilon_step: the number of iterations it takes for the epsilon value to decay from start value to end value
+        device: pytorch device that will be used for the computations
+        verbose: 3 - progress bars for warm start episodes, iterations, episodes and steps in the environment
+                 2 - progress bars for warm start episodes, iterations and episodes
+                 1 - just one progress bar for the iterations
+                 0 - no progress bars
+        """
+        if verbose not in [0, 1, 2, 3]:
+            raise ValueError("Verbose must be 0, 1, 2 or 3")
+        self.saving_path = saving_path + "/" + file_name
+        self.batch_size = batch_size
+        self.training_iterations = training_iterations
+        self.episodes_per_iteration = episodes_per_iteration
+        self.updates_per_iteration = updates_per_iteration
+        self.epsilon_start = epsilon_start
+        self.epsilon_end = epsilon_end
+        self.epsilon_step = epsilon_step
+        self.verbose = verbose
+
+        bias_average = self.run_warm_start_episodes(warm_start_episodes)
+        self.agent = Agent(self.n_state_estimation, learning_rate, batch_size, bias_average,
+                           target_copy_factor, gamma, device)
+        self.train_agent(nn_updates_per_warm_start)
+        self.run_training_iterations()
+
+
+    def continue_training(self, saving_path, file_name, net_path, target_net_path=None, learning_rate=1e-4, batch_size=32,
+                          gamma=1, target_copy_factor=0.01, training_iterations=1000, episodes_per_iteration=10, updates_per_iteration=60,
+                          epsilon_start=1, epsilon_end=0, epsilon_step=1000, device=None, verbose=2):
+        """
+        net_path: the path to the q-network that has already been trained and shall now be trained further
+        target_net_path: the corresponding target net; if None, a copy of the net will be used as the target net
+
+        the other parameters are exactly the same as in train_query_strategy
+        """
+        if verbose not in [0, 1, 2, 3]:
+            raise ValueError("Verbose must be 0, 1, 2 or 3")
+        state_dict = torch.load(net_path, map_location=device)
+
+        if target_net_path is not None:
+            target_state_dict = torch.load(target_net_path, map_location=device)
+        else:
+            target_state_dict = copy.deepcopy(state_dict)
+
+        # test if the given n_state_estimation matches the one of the loaded state_dict
+        if self.n_state_estimation != state_dict[list(state_dict.keys())[0]].size(1):
+            raise ValueError("given n_state_estimation doesn't match the one of the loaded state_dict")
+        # test if n_state_estimation of net and target net are the same
+        if state_dict[list(state_dict.keys())[0]].size(1) != target_state_dict[list(target_state_dict.keys())[0]].size(1):
+            raise ValueError("n_state_estimation of net and target net are not the same")
+
+        self.saving_path = saving_path + "/" + file_name
+        self.batch_size = batch_size
+        self.training_iterations = training_iterations
+        self.episodes_per_iteration = episodes_per_iteration
+        self.updates_per_iteration = updates_per_iteration
+        self.epsilon_start = epsilon_start
+        self.epsilon_end = epsilon_end
+        self.epsilon_step = epsilon_step
+        self.verbose = verbose
+
+        self.agent = Agent(self.n_state_estimation, learning_rate, batch_size, 0,
+                           target_copy_factor, gamma, device)
+        self.agent.net.load_state_dict(state_dict)
+        self.agent.target_net.load_state_dict(target_state_dict)
+        self.run_training_iterations()
+
+
+    def run_warm_start_episodes(self, n_episodes):
+        # create progress-bar callbacks depending on the verbose level
+        if self.verbose >= 2:
+            p_bar = tqdm(total=n_episodes, desc="Warmstart episodes", leave=False)
+            def update():
+                p_bar.update()
+            def close():
+                p_bar.close()
+        else:
+            update = lambda *x: None
+            close = lambda *x: None
+
+        # Keep track of episode durations to compute the average
+        episode_durations = []
+        for _ in range(n_episodes):
+            # Reset the environment to start a new episode
+            # classifier_state contains a vector representation of the state of the environment (depends on the classifier)
+            # next_action_state contains vector representations of all actions available to be taken at the next step
+            classifier_state, next_action_state = self.env.reset()
+            terminal = False
+            episode_duration = 0
+            # before we reach a terminal state, make steps
+            while not terminal:
+                # Choose a random action
+                action = np.random.randint(0, self.env.n_actions)
+                # taken_action_state is a vector corresponding to the taken action
+                taken_action_state = next_action_state[:, action]
+                next_classifier_state, next_action_state, reward, terminal = self.env.step(action)
+                # Store the transition in the replay buffer
+                self.replay_buffer.store_transition(classifier_state,
+                                                    taken_action_state,
+                                                    reward, next_classifier_state,
+                                                    next_action_state, terminal)
+                # Get ready for the next step
+                classifier_state = next_classifier_state
+                episode_duration += 1
+            episode_durations.append(episode_duration)
+            update()
+        # compute the average episode duration of the episodes generated during the warm start procedure
+        av_episode_duration = np.mean(episode_durations)
+        close()
+
+        return -av_episode_duration/2
+
+
+    def train_agent(self, n_of_updates):
+        # check if there are enough experiences in the replay memory
+        if self.replay_buffer.n < self.batch_size:
+            return
+        # create progress-bar callbacks depending on the verbose level
+        if self.verbose >= 2:
+            p_bar = tqdm(total=n_of_updates, desc="Train q-net", leave=False)
+            def update():
+                p_bar.update()
+            def close():
+                p_bar.close()
+        else:
+            update = lambda *x: None
+            close = lambda *x: None
+        for _ in range(n_of_updates):
+            # Sample a batch from the replay buffer proportionally to the probability of sampling.
+            minibatch = self.replay_buffer.sample_minibatch(self.batch_size)
+            # Use the batch to train the agent. Keep track of temporal difference errors during training.
+            td_error = self.agent.train(minibatch)
+            # Update the probability of sampling each datapoint proportionally to the error.
+            self.replay_buffer.update_td_errors(td_error, minibatch.indeces)
+            update()
+        close()
+
+
+    def run_training_iterations(self):
+        # create progress-bar callbacks depending on the verbose level
+        if self.verbose >= 1:
+            p_bar_iter = tqdm(total=self.training_iterations, desc="Train iterations", leave=(self.verbose > 2))
+            def update_iter():
+                p_bar_iter.update()
+            def close_iter():
+                p_bar_iter.close()
+        else:
+            update_iter = lambda *x: None
+            close_iter = lambda *x: None
+
+        for iteration in range(self.training_iterations):
+            # GENERATE NEW EPISODES
+            # Compute the epsilon value according to the schedule.
+            epsilon = max(self.epsilon_end, self.epsilon_start - iteration*(self.epsilon_start - self.epsilon_end)/self.epsilon_step)
+
+            # create progress-bar callbacks depending on the verbose level
+            if self.verbose >= 2:
+                p_bar_episode = tqdm(total=self.episodes_per_iteration, desc="Episodes", leave=False)
+                def update_episode():
+                    p_bar_episode.update()
+                def close_episode():
+                    p_bar_episode.close()
+            else:
+                update_episode = lambda *x: None
+                close_episode = lambda *x: None
+
+            # Simulate training episodes.
+            for _ in range(self.episodes_per_iteration):
+                # Reset the environment to start a new episode.
+                classifier_state, next_action_state = self.env.reset()
+                terminal = False
+                max_steps = len(self.env.dataset.train_labels)
+
+                # create progress-bar callbacks depending on the verbose level
+                if self.verbose >= 3:
+                    p_bar_steps = tqdm(total=max_steps, desc=self.env.dataset.dataset_name, leave=False)
+                    def update_steps():
+                        p_bar_steps.update()
+                    def close_steps():
+                        p_bar_steps.close()
+                else:
+                    update_steps = lambda *x: None
+                    close_steps = lambda *x: None
+                # Run an episode.
+                while not terminal:
+                    # Let the agent choose an action or, with probability epsilon, take a random action.
+                    if np.random.ranf() < epsilon:
+                        action = np.random.randint(0, self.env.n_actions)
+                    else:
+                        action = self.agent.get_action(classifier_state, next_action_state)
+
+                    # taken_action_state is a vector that corresponds to the taken action
+                    taken_action_state = next_action_state[:, action]
+                    # Make another step.
+                    next_classifier_state, next_action_state, reward, terminal = self.env.step(action)
+                    # Store the step in the replay buffer
+                    self.replay_buffer.store_transition(classifier_state,
+                                                        taken_action_state,
+                                                        reward,
+                                                        next_classifier_state,
+                                                        next_action_state,
+                                                        terminal)
+                    # Change the state of the environment.
+                    classifier_state = next_classifier_state
+                    update_steps()
+                close_steps()
+                update_episode()
+            close_episode()
+            # NEURAL NETWORK UPDATES
+            self.train_agent(self.updates_per_iteration)
+            update_iter()
+
+        self.agent.save_net(self.saving_path)
+        self.agent.save_target_net(self.saving_path)
+        close_iter()
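A usage sketch for the learner (illustrative, not part of the patch): the dataset directory, dataset names, output paths and hyper-parameters below are placeholders, and the pickled UCI files must already exist under path (see DatasetUCI further down).

    import os
    import torch
    from alipy.query_strategy.LAL_RL import LAL_RL_StrategyLearner

    learner = LAL_RL_StrategyLearner(path="./dataUCI",
                                     possible_dataset_names=["australian", "heart", "german"],
                                     n_state_estimation=30, subset=-1)
    os.makedirs("./agents", exist_ok=True)
    learner.train_query_strategy(saving_path="./agents", file_name="lal_rl",
                                 warm_start_episodes=100, training_iterations=200,
                                 device=torch.device("cpu"), verbose=1)
    # writes ./agents/lal_rl.pt and ./agents/lal_rl_target_net.pt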
+
+
+class QueryInstanceLAL_RL(BaseIndexQuery):
+    """This class uses a strategy that was learnt by LAL_RL_StrategyLearner.
+
+    Parameters
+    ----------
+    X: 2D array,
+        Feature matrix of the whole dataset. It is a reference which will not use additional memory.
+
+    y: array-like,
+        Label matrix of the whole dataset. It is a reference which will not use additional memory.
+
+    model_path: file-like object or string or os.PathLike object,
+        state_dict of the trained strategy
+
+    n_state_estimation: int, optional (default=None)
+        number of datapoints used by the strategy to build the state; if None, it is inferred from the state_dict
+
+    pred_batch: int, optional (default=128)
+        batch size that is used when predicting with the learnt strategy
+
+    device: torch.device, optional (default=None)
+        the pytorch device used for the calculations
+    """
+    def __init__(self, X, y, model_path, n_state_estimation=None, pred_batch=128, device=None, **kwargs):
+        super(QueryInstanceLAL_RL, self).__init__(X, y)
+        state_dict = torch.load(model_path, map_location=device)
+        self.pred_batch = pred_batch
+        self.device = device
+        if n_state_estimation is None:
+            self.n_state_estimation = state_dict[list(state_dict.keys())[0]].size(1)
+        else:
+            self.n_state_estimation = n_state_estimation
+        self.net = Net(self.n_state_estimation, 0)
+        self.net.load_state_dict(state_dict)
+        self.net.to(device)
+        self.net.eval()
+
+    def select(self, label_index, unlabel_index, model=None, batch_size=1, **kwargs):
+        # copy label_index and unlabel_index
+        label_index_copy = copy.deepcopy(label_index)
+        unlabel_index_copy = copy.deepcopy(unlabel_index)
+        assert (batch_size > 0)
+        assert (isinstance(unlabel_index_copy, collections.abc.Iterable))
+        assert (isinstance(label_index_copy, collections.abc.Iterable))
+        if len(unlabel_index_copy) <= batch_size:
+            return unlabel_index_copy
+        assert len(unlabel_index_copy) + len(label_index_copy) >= self.n_state_estimation
+        unlabel_index_copy = np.asarray(unlabel_index_copy)
+
+        # initialize the model and train it if necessary
+        if model is None:
+            model = LogisticRegression()
+            model.fit(self.X[label_index_copy], self.y[label_index_copy])
+
+        # set aside some unlabeled data for the state representation; this data is removed from the unlabel_index
+        if len(unlabel_index_copy) >= self.n_state_estimation + batch_size:
+            chosen_indices = np.random.choice(len(unlabel_index_copy), size=self.n_state_estimation, replace=False)
+            state_indices = unlabel_index_copy[chosen_indices]
+            unlabel_index_copy = unlabel_index_copy[np.array([x for x in range(len(unlabel_index_copy)) if x not in chosen_indices])]
+
+        # if there isn't enough data, the label_index is used as well and the data is not removed from the indices
+        else:
+            state_indices = np.random.choice(np.concatenate((np.array(label_index_copy), np.array(unlabel_index_copy))),
+                                             size=self.n_state_estimation, replace=False)
+
+        # create the state
+        predictions = model.predict_proba(self.X[state_indices])[:, 0]
+        predictions = np.array(predictions)
+        idx = np.argsort(predictions)
+        state = predictions[idx]
+
+        # create the actions
+        a1 = model.predict_proba(self.X[unlabel_index_copy])[:, 0]
+
+        # calculate distances
+        data = self.X[np.concatenate((label_index_copy, unlabel_index_copy), axis=0)]
+        distances = scipy.spatial.distance.pdist(data, metric='cosine')
+        distances = scipy.spatial.distance.squareform(distances)
+        indeces_known = np.arange(len(label_index_copy))
+        indeces_unknown = np.arange(len(label_index_copy), len(label_index_copy) + len(unlabel_index_copy))
+        a2 = np.mean(distances[indeces_unknown, :][:, indeces_unknown], axis=0)
+        a3 = np.mean(distances[indeces_known, :][:, indeces_unknown], axis=0)
+
+        actions = np.concatenate(([a1], [a2], [a3]), axis=0).transpose()
+
+        # calculate the q-values according to the q-network
+        # first transform the state and actions for the network
+        state = np.repeat([state], actions.shape[0], axis=0)
+        state_actions = np.concatenate((state, actions), axis=1)
+        input_tensor = torch.tensor(state_actions, dtype=torch.float, device=self.device)
+
+        # get the predictions from the network in batches
+        pred = self.net(input_tensor[:self.pred_batch])
+        for i in range(self.pred_batch, input_tensor.size(0), self.pred_batch):
+            pred = torch.cat((pred, self.net(input_tensor[i:i+self.pred_batch])))
+        pred = pred.flatten()
+
+        # sort the actions with respect to their q-value
+        idx = pred.argsort(descending=True)
+        idx = idx[:batch_size].detach().cpu().numpy()
+
+        # return the corresponding indices from the unlabeled index
+        return unlabel_index_copy[idx]
\ No newline at end of file
diff --git a/alipy/query_strategy/LAL_RL/__init__.py b/alipy/query_strategy/LAL_RL/__init__.py
new file mode 100644
index 0000000..4e9dae2
--- /dev/null
+++ b/alipy/query_strategy/LAL_RL/__init__.py
@@ -0,0 +1,2 @@
+from .Agent import Net, Agent
+from .LAL_RL import LAL_RL_StrategyLearner, QueryInstanceLAL_RL
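An illustrative query loop with the learnt strategy (not part of the patch); the data, the initial index split and the model path are placeholders, and the .pt file is assumed to come from the learner sketched above.

    import numpy as np
    from sklearn.datasets import make_classification
    from alipy.query_strategy.LAL_RL import QueryInstanceLAL_RL

    X, y = make_classification(n_samples=300, n_features=20)     # stand-in data
    label_ind = list(range(40))                                   # 40 initially labelled points
    unlab_ind = list(range(40, 300))

    strategy = QueryInstanceLAL_RL(X, y, model_path="./agents/lal_rl.pt")
    for _ in range(10):
        select_ind = strategy.select(label_ind, unlab_ind, batch_size=1)
        label_ind.extend(select_ind)
        unlab_ind = [i for i in unlab_ind if i not in select_ind]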
diff --git a/alipy/query_strategy/LAL_RL/datasets.py b/alipy/query_strategy/LAL_RL/datasets.py
new file mode 100644
index 0000000..d77d889
--- /dev/null
+++ b/alipy/query_strategy/LAL_RL/datasets.py
@@ -0,0 +1,123 @@
+import numpy as np
+import scipy.spatial
+from sklearn.model_selection import train_test_split
+from sklearn import preprocessing
+import pickle as pkl
+
+class Dataset:
+    """The base class for all datasets.
+
+    Every dataset class should inherit from Dataset
+    and load the data. Dataset only declares the attributes.
+
+    Attributes:
+        train_data: A numpy array with data that can be labelled.
+        train_labels: A numpy array with labels of train_data.
+        test_data: A numpy array with data that will be used for testing.
+        test_labels: A numpy array with labels of test_data.
+        n_state_estimation: An integer indicating #datapoints reserved for state representation estimation.
+        distances: A numpy array with pairwise (cosine) distances between all train_data.
+    """
+
+    def __init__(self, n_state_estimation):
+        """Inits the Dataset object and initialises the attributes with given or empty values."""
+        self.train_data = np.array([[]])
+        self.train_labels = np.array([[]])
+        self.test_data = np.array([[]])
+        self.test_labels = np.array([[]])
+        self.n_state_estimation = n_state_estimation
+        self.dataset_name = ""
+        self.regenerate()
+
+    def regenerate(self):
+        """The function for generating a dataset with new parameters."""
+        pass
+
+    def _scale_data(self):
+        """Scales train data to 0 mean and unit variance. Test data is scaled with the parameters of the train data."""
+        scaler = preprocessing.StandardScaler().fit(self.train_data)
+        self.train_data = scaler.transform(self.train_data)
+        self.test_data = scaler.transform(self.test_data)
+
+    def _keep_state_data(self):
+        """self.n_state_estimation samples of the training data are reserved for estimating the state."""
+        self.train_data, self.state_data, self.train_labels, self.state_labels = train_test_split(
+            self.train_data, self.train_labels, test_size=self.n_state_estimation)
+
+    def _compute_distances(self):
+        """Computes the pairwise distances between all training datapoints."""
+        self.distances = scipy.spatial.distance.pdist(self.train_data, metric='cosine')
+        self.distances = scipy.spatial.distance.squareform(self.distances)
+
+
+
+class DatasetUCI(Dataset):
+    """Class for loading standard benchmark classification datasets.
+
+    UCI datasets. Can be downloaded here:
+    https://archive.ics.uci.edu/ml/index.php
+
+    Attributes:
+        possible_names: A list indicating the dataset names that can be used.
+        subset: An integer indicating what subset of data to use. 0: even, 1: odd, -1: all datapoints.
+        size: An integer indicating the size of the training dataset to sample, if -1 use all data.
+        path: the path to the folder that contains the datasets
+    """
+
+    def __init__(self, possible_names, n_state_estimation, subset, size=-1, path=None):
+        """Inits a few attributes and the attributes of the Dataset object."""
+        self.possible_names = possible_names
+        self.subset = subset
+        self.size = size
+        if path is None:
+            self.path = "./dataUCI/"
+        else:
+            self.path = path
+        Dataset.__init__(self, n_state_estimation)
+
+    def regenerate(self):
+        """Loads the data and splits it into train and test."""
+        # every time we select one of the possible datasets to sample data from
+        self.dataset_name = np.random.choice(self.possible_names)
+        # load data
+        data = pkl.load(open(self.path + "/" + self.dataset_name + ".p", "rb"))
+        X = data['X']
+        y = data['y']
+        if len(y.shape) == 1:
+            y = y.reshape(y.shape[0], 1)
+        dtst_size = np.size(y)
+
+        # even datapoints subset
+        if self.subset == 0:
+            valid_indeces = list(range(0, dtst_size, 2))
+        # odd datapoints subset
+        elif self.subset == 1:
+            valid_indeces = list(range(1, dtst_size, 2))
+        # all datapoints
+        elif self.subset == -1:
+            valid_indeces = list(range(dtst_size))
+        else:
+            print('Incorrect subset attribute value!')
+
+        # try to split the data into training and test subsets while ensuring that
+        # all classes from the test data are present in the train data
+        done = False
+        while not done:
+            # get a part of the dataset according to subset (even, odd or all)
+            train_test_data = X[valid_indeces, :]
+            train_test_labels = y[valid_indeces, :]
+            # use a random half/half split for train and test data
+            self.train_data, self.test_data, self.train_labels, self.test_labels = train_test_split(
+                train_test_data, train_test_labels, train_size=0.5)
+
+            self._scale_data()
+            self._keep_state_data()
+            self._compute_distances()
+
+            # keep only a part of the data for training
+            self.train_data = self.train_data[:self.size, :]
+            self.train_labels = self.train_labels[:self.size, :]
+
+            # repeat the split if some classes are missing from the train or test data
+            done = len(np.unique(self.train_labels)) == len(np.unique(self.test_labels))
+
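DatasetUCI expects one pickle file per dataset name inside path, holding a dict with keys 'X' and 'y'. A sketch of how such a file could be produced (the sklearn toy dataset is only a stand-in for a real UCI download):

    import os
    import pickle as pkl
    import numpy as np
    from sklearn.datasets import load_breast_cancer

    data = load_breast_cancer()
    os.makedirs("./dataUCI", exist_ok=True)
    with open("./dataUCI/breast_cancer.p", "wb") as f:            # the file name becomes the dataset name
        pkl.dump({"X": np.asarray(data.data, dtype=float),
                  "y": np.asarray(data.target)}, f)

    # DatasetUCI(["breast_cancer"], n_state_estimation=30, subset=-1, path="./dataUCI")
    # then samples a fresh train/test split from this file on every regenerate() call.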
diff --git a/alipy/query_strategy/LAL_RL/envs.py b/alipy/query_strategy/LAL_RL/envs.py
new file mode 100644
index 0000000..69b8f7f
--- /dev/null
+++ b/alipy/query_strategy/LAL_RL/envs.py
@@ -0,0 +1,364 @@
+import numpy as np
+from sklearn.base import clone
+import collections
+from sklearn.ensemble import RandomForestClassifier
+
+class LalEnv(object):
+    """The base class for the LAL environment.
+
+    Following the conventions of OpenAI gym,
+    this class implements an environment which
+    simulates the labelling of a given annotated
+    dataset. The subclasses differ in how the
+    reward is computed and when the terminal
+    state is reached.
+
+    Attributes:
+        dataset: An object of class Dataset.
+        model: A classifier from sklearn.
+            Should implement fit, predict and predict_proba.
+        model_rf: A random forest classifier that was fit
+            to the same data as the data used for model.
+        quality_method: A function that computes the quality of the prediction.
+            For example, can be metrics.accuracy_score or metrics.f1_score.
+        n_classes: An integer indicating the number of classes in a dataset.
+            Typically 2.
+        episode_qualities: A list of floats with the quality scores of the classifier at various steps.
+        n_actions: An integer indicating the possible number of actions
+            (the number of remaining unlabelled points).
+        indeces_known: A list of indices of datapoints whose labels can be used for training.
+        indeces_unknown: A list of indices of datapoints whose labels cannot be used for training yet.
+    """
+
+    def __init__(self, dataset, model, quality_method):
+        """Inits the environment with its attributes: dataset, model, quality function and other attributes."""
+        self.dataset = dataset
+        self.model = model
+        self.quality_method = quality_method
+        # Compute the number of classes as the number of unique labels in the train dataset
+        self.n_classes = np.size(np.unique(self.dataset.train_labels))
+        # Initialise a list where the quality at each iteration will be written
+        self.episode_qualities = []
+
+    def for_lal(self):
+        """Function that is used to compute features for lal-regr.
+
+        Fits an RF classifier to the data."""
+        known_data = self.dataset.train_data[self.indeces_known, :]
+        known_labels = self.dataset.train_labels[self.indeces_known]
+        known_labels = np.ravel(known_labels)
+        self.model_rf = RandomForestClassifier(n_estimators=50, oob_score=True, n_jobs=1)
+        self.model_rf.fit(known_data, known_labels)
+
+    def reset(self, n_start=2):
+        """Resets the environment.
+
+        1) The dataset is regenerated according to its method regenerate.
+        2) n_start datapoints are selected, at least one datapoint from each class is included.
+        3) The model is trained on the initial dataset and the corresponding state of the problem is computed.
+
+        Args:
+            n_start: An integer indicating the size of the annotated set at the beginning.
+
+        Returns:
+            classifier_state: a numpy.ndarray characterizing the current classifier,
+                of size of the number of features for the state,
+                in this case the number of datasamples in dataset.state_data
+            next_action_state: a numpy.ndarray
+                of size #features characterising actions (currently, 3) x #unlabelled datapoints
+                where each column corresponds to the vector characterizing each possible action.
+        """
+
+        # SAMPLE INITIAL DATAPOINTS
+        self.dataset.regenerate()
+        self.episode_qualities = []
+        # To train an initial classifier we need at least self.n_classes samples
+        if n_start < self.n_classes:
+            n_start = self.n_classes
+        # Sample n_start datapoints
+        self.indeces_known = []
+        self.indeces_unknown = []
+        for i in np.unique(self.dataset.train_labels):
+            # First get 1 point from each class
+            cl = np.nonzero(self.dataset.train_labels == i)[0]
+            # Ensure that we select random datapoints
+            indeces = np.random.permutation(cl)
+            self.indeces_known.append(indeces[0])
+            self.indeces_unknown.extend(indeces[1:])
+        self.indeces_known = np.array(self.indeces_known)
+        self.indeces_unknown = np.array(self.indeces_unknown)
+        # self.indeces_unknown now contains first all points of class 1, then all points of class 2 etc.
+        # So, we permute them
+        self.indeces_unknown = np.random.permutation(self.indeces_unknown)
+        # Then, sample the rest of the datapoints at random
+        if n_start > self.n_classes:
+            self.indeces_known = np.concatenate(([self.indeces_known, self.indeces_unknown[0:n_start-self.n_classes]]))
+            self.indeces_unknown = self.indeces_unknown[n_start-self.n_classes:]
+
+        # BUILD AN INITIAL MODEL
+        # Get the data corresponding to the selected indices
+        known_data = self.dataset.train_data[self.indeces_known, :]
+        known_labels = self.dataset.train_labels[self.indeces_known]
+        unknown_data = self.dataset.train_data[self.indeces_unknown, :]
+        unknown_labels = self.dataset.train_labels[self.indeces_unknown]
+        # Train a model using the data corresponding to indeces_known
+        known_labels = np.ravel(known_labels)
+        self.model.fit(known_data, known_labels)
+        # Compute the quality score
+        test_prediction = self.model.predict(self.dataset.test_data)
+        new_score = self.quality_method(self.dataset.test_labels, test_prediction)
+        self.episode_qualities.append(new_score)
+        # Get the features characterizing the state
+        classifier_state, next_action_state = self._get_state()
+        self.n_actions = np.size(self.indeces_unknown)
+        return classifier_state, next_action_state
+ """ + # Action indicates the position of a datapoint in self.indeces_unknown + # that we want to sample in unknown_data + # The index in train_data should be retrieved + selection_absolute = self.indeces_unknown[action] + # Label a datapoint: add its index to known samples and removes from unknown + self.indeces_known = np.concatenate(([self.indeces_known, np.array([selection_absolute])])) + self.indeces_unknown = np.delete(self.indeces_unknown, action) + # Train a model with new labeled data + known_data = self.dataset.train_data[self.indeces_known,:] + known_labels = self.dataset.train_labels[self.indeces_known] + known_labels = np.ravel(known_labels) + self.model.fit(known_data, known_labels) + # Get a new state + classifier_state, next_action_state = self._get_state() + # Update the number of available actions + self.n_actions = np.size(self.indeces_unknown) + # Compute the quality of the current classifier + test_prediction = self.model.predict(self.dataset.test_data) + new_score = self.quality_method(self.dataset.test_labels, test_prediction) + self.episode_qualities.append(new_score) + # Compute the reward + reward = self._compute_reward() + # Check if this episode terminated + done = self._compute_is_terminal() + return classifier_state, next_action_state, reward, done + + def _get_state(self): + """Private function for computing the state depending on the classifier and next available actions. + + This function computes 1) classifier_state that characterises + the current state of the classifier and it is computed as + a function of predictions on the hold-out dataset + 2) next_action_state that characterises all possible actions + (unlabelled datapoints) that can be taken at the next step. + + Returns: + classifier_state: a numpy.ndarray + of size of number of datapoints in dataset.state_data + characterizing the current classifier and, thus, the + state of the environment + next_action_state: a numpy.ndarray + of size #features characterising actions (currently, 3) x #unlabelled datapoints + where each column corresponds to the vector characterizing each possible action. + """ + # COMPUTE CLASSIFIER_STATE + predictions = self.model.predict_proba(self.dataset.state_data)[:,0] + predictions = np.array(predictions) + idx = np.argsort(predictions) + # the state representation is the *sorted* list of scores + classifier_state = predictions[idx] + + # COMPUTE ACTION_STATE + unknown_data = self.dataset.train_data[self.indeces_unknown,:] + # prediction (score) of classifier on each unlabelled sample + a1 = self.model.predict_proba(unknown_data)[:,0] + # average distance to every unlabelled datapoint + a2 = np.mean(self.dataset.distances[self.indeces_unknown,:][:,self.indeces_unknown],axis=0) + # average distance to every labelled datapoint + a3 = np.mean(self.dataset.distances[self.indeces_known,:][:,self.indeces_unknown],axis=0) + next_action_state = np.concatenate(([a1], [a2], [a3]), axis=0) + return classifier_state, next_action_state + + def _compute_reward(self): + """Private function to computes the reward. + + Default function always returns 0. + Every sub-class should implement its own reward function. + + Returns: + reward: a float reward + """ + reward = 0.0 + return reward + + def _compute_is_terminal(self): + """Private function to compute if the episode has reaches the terminal state. + + By default episode terminates when all the data was labelled. + Every sub-class should implement its own episode termination function. 
+ + Returns: + done: A boolean that indicates if the episode is finished. + """ + # self.n_actions contains a number of unlabelled datapoints that is left + if self.n_actions==1: + # print('We ran out of samples!') + done = True + else: + done = False + return done + + +class LalEnvIncrementalReduction(LalEnv): + """The LAL environment class with reward that is incremental error reduction. + + This class inherits from LalEnv. + The reward is the difference between + the test errors at the consequetive + iterations. The terminal state is reached + when n_horizon samples are labelled. + + Attributes: + n_horizon: An integer indicating how many steps can be made in an episode. + """ + + def __init__(self, dataset, model, quality_method, n_horizon = 10): + """Inits environment with its normal attributes + n_horizon (the length of the episode).""" + LalEnv.__init__(self, dataset, model, quality_method) + self.n_horizon = n_horizon + + def _compute_reward(self): + """Computes the reward. + + Computes the reward that is the difference + between the previous model score and new model score. + + Returns: + reward: A float reward. + """ + last_score = self.episode_qualities[-2] + new_score = self.episode_qualities[-1] + reward = new_score - last_score + return reward + + def _compute_is_terminal(self): + """Computes if the episode has reaches the terminal state. + + The end of the episode is reached when number + of labelled points reaches predifined horizon. + + Returns: + done: A boolean that indicates if the episode is finished. + """ + # by default the episode will terminate when all samples are labelled + done = LalEnv._compute_is_terminal(self) + # it also terminates when self.n_horizon datapoints were labelled + if np.size(self.indeces_known) == self.n_horizon: + done = True + return done + + +class LalEnvTargetAccuracy(LalEnv): + """The LAL environment class where the episode lasts until a classifier reaches a predifined quality. + + This class inherits from LalEnv. + The reward is -1 at every step. + The terminal state is reached + when the predefined classificarion + quality is reached. Classification + quality is defined as a proportion + of the final quality (that is obtained + when all data is labelled). + + Attributes: + tolerance_level: A float indicating what proportion of the maximum reachable score + should be attained in order to terminate the episode. + target_quality: A float indication the minimum required accuracy + after reaching which the episode is terminated. + """ + + def __init__(self, dataset, model, quality_method, tolerance_level=0.9): + """Inits environment with its normal attributes + tolerance_level (proportion of quality to reach).""" + LalEnv.__init__(self, dataset, model, quality_method) + self.tolerance_level = tolerance_level + self._set_target_quality() + + def _set_target_quality(self): + """Sets the target accuracy according to the tolerance_level. + + This function computes the best reachable quality of the model + on the full potential training data and sets the target_quality + as tolerance_level*max_qualtity. 
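As an illustration with made-up numbers: if the recorded test qualities over three consecutive steps are 0.62, 0.70 and 0.69, the incremental-reduction rewards for the second and third acquisitions are 0.70 - 0.62 = 0.08 and 0.69 - 0.70 = -0.01, so an acquisition only earns a positive reward when it actually improves the test score.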
+ """ + best_model = clone(self.model) + # train and avaluate the model on the full size of potential dataset + best_model.fit(self.dataset.train_data, np.ravel(self.dataset.train_labels)) + test_prediction = best_model.predict(self.dataset.test_data) + max_quality = self.quality_method(self.dataset.test_labels, test_prediction) + # the target_quality after which the episode stops is a proportion of the max quality + self.target_quality = self.tolerance_level*max_quality + + def reset(self, n_start=2): + """Resets the environment. + + First, do, what is done for the parent environment and then: + 4) The target quality for this experiment is set. + + Args: + n_start: An integer indicating the size of annotated set at the beginning. + + Returns: + the same as the parent class. + """ + classifier_state, next_action_state = LalEnv.reset(self, n_start=n_start) + self._set_target_quality() + return classifier_state, next_action_state + + def _compute_reward(self): + """Computes the reward. + + The reward is -1 in all states. In the terminal state, + the environment will stop issueing a negative reward + and suffering of annotating data is be ended. + + Returns: + reward: A float: -1. + """ + reward = -1 + return reward + + + def _compute_is_terminal(self): + """Computes if the episode has reached the terminal state. + + The end of the episode is reached when the + classification accuracy reaches the predefined + level. + + Returns: + done: A boolean that indicates if the episode is finished. + """ + new_score = self.episode_qualities[-1] + # by default the episode will terminate when all samples are labelled + done = LalEnv._compute_is_terminal(self) + # it also terminates when a quality reaches a predefined level + if new_score >= self.target_quality: + done = True + return done \ No newline at end of file diff --git a/alipy/query_strategy/LAL_RL/helpers.py b/alipy/query_strategy/LAL_RL/helpers.py new file mode 100644 index 0000000..6315e1b --- /dev/null +++ b/alipy/query_strategy/LAL_RL/helpers.py @@ -0,0 +1,149 @@ +import numpy as np + +class Minibatch: + """Minibatch class that helps for gradient decent training. + + Attributes: + classifier_state: A numpy.ndarray of size batch_size x #classifier features + characterising the state of classifier at the sampled iterations + action_state: A numpy.ndarray of size batch_size x #action features + characterizing the action that was taken at the sampled iterations + reward: A numpy.ndarray of size batch_size + next_classifier_state: A numpy.ndarray of size batch_size x #classifier features + next_action_state: A list of size batch_size of numpy.ndarrays + characterizing the possible actions that were available at the sampled iterations + terminal: A numpy.ndarray of size batch_size of booleans indicating if the iteration was terminal + indeces: A numpy.ndarray of size batch_size that contains indeces of samples iterations in the replay buffer + """ + def __init__(self, classifier_state, action_state, reward, next_classifier_state, next_action_state, terminal, indeces): + """Inits the Minibatch object and initialises the attributes with given values.""" + self.classifier_state = classifier_state + self.action_state = action_state + self.reward = reward + self.next_classifier_state = next_classifier_state + self.next_action_state = next_action_state + self.terminal = terminal + self.indeces = indeces + + +class ReplayBuffer: + """Replay Buffer is used to store the transactions from episodes. 
diff --git a/alipy/query_strategy/LAL_RL/helpers.py b/alipy/query_strategy/LAL_RL/helpers.py
new file mode 100644
index 0000000..6315e1b
--- /dev/null
+++ b/alipy/query_strategy/LAL_RL/helpers.py
@@ -0,0 +1,149 @@
+import numpy as np
+
+class Minibatch:
+    """Minibatch class that helps with gradient descent training.
+
+    Attributes:
+        classifier_state: A numpy.ndarray of size batch_size x #classifier features
+            characterising the state of the classifier at the sampled iterations
+        action_state: A numpy.ndarray of size batch_size x #action features
+            characterizing the action that was taken at the sampled iterations
+        reward: A numpy.ndarray of size batch_size
+        next_classifier_state: A numpy.ndarray of size batch_size x #classifier features
+        next_action_state: A list of size batch_size of numpy.ndarrays
+            characterizing the possible actions that were available at the sampled iterations
+        terminal: A numpy.ndarray of size batch_size of booleans indicating if the iteration was terminal
+        indeces: A numpy.ndarray of size batch_size that contains the indices of the sampled iterations in the replay buffer
+    """
+    def __init__(self, classifier_state, action_state, reward, next_classifier_state, next_action_state, terminal, indeces):
+        """Inits the Minibatch object and initialises the attributes with the given values."""
+        self.classifier_state = classifier_state
+        self.action_state = action_state
+        self.reward = reward
+        self.next_classifier_state = next_classifier_state
+        self.next_action_state = next_action_state
+        self.terminal = terminal
+        self.indeces = indeces
+
+
+class ReplayBuffer:
+    """Replay Buffer is used to store the transitions from episodes.
+
+    Attributes:
+        buffer_size: An integer indicating the maximum number of transitions to be stored in the replay buffer.
+        n: An integer, the maximum index to be used for sampling. It is useful when the buffer is not filled in fully.
+            It grows from 0 to buffer_size and then stops changing.
+        write_index: An integer, the index where the next transition should be written.
+            Goes from 0 till buffer_size-1 and then starts from 0 again.
+        prior_exp: A float that is used for turning the td error into a probability to be sampled.
+        all_classifier_states: A numpy.ndarray of size buffer_size x #classifier features
+            characterising the state of the classifier at the stored iterations.
+        all_action_state: A numpy.ndarray of size buffer_size x #action features
+            characterizing the action that was taken at the stored iterations.
+        all_rewards: A numpy.ndarray of size buffer_size.
+        all_next_classifier_states: A numpy.ndarray of size buffer_size x #classifier features.
+        all_next_action_states: A list of size buffer_size of numpy.ndarrays
+            characterizing the possible actions that were available at the stored iterations.
+        all_terminals: A numpy.ndarray of size buffer_size of booleans indicating if the iteration was terminal.
+        all_td_errors: A numpy.ndarray of size buffer_size with the td errors of the transitions
+            from when each of them was last used in a gradient update.
+        max_td_error: A float with the highest (absolute) td error of all transitions stored in the buffer,
+            used to initialize the td error of newly added samples.
+    """
+
+    def __init__(self, buffer_size=1e4, prior_exp=0.5):
+        """Inits a few attributes with 0 or the default values."""
+        self.buffer_size = int(buffer_size)
+        self.n = 0
+        self.write_index = 0
+        self.max_td_error = 1000.0
+        self.prior_exp = prior_exp
+
+    def _init_nparray(self, classifier_state, action_state, reward, next_classifier_state, next_action_state, terminal):
+        """Initializes the numpy arrays of the all_xxx attributes to one transition repeated buffer_size times."""
+        self.all_classifier_states = np.array([classifier_state] * self.buffer_size)
+        self.all_action_state = np.array([action_state] * self.buffer_size)
+        self.all_rewards = np.array([reward] * self.buffer_size)
+        self.all_next_classifier_states = np.array([next_classifier_state] * self.buffer_size)
+        self.all_next_action_states = [next_action_state] * self.buffer_size
+        self.all_terminals = np.array([terminal] * self.buffer_size)
+        self.all_td_errors = np.array([self.max_td_error] * self.buffer_size)
+        # set the counters to 1 as one transition is stored
+        self.n = 1
+        self.write_index = 1
+
+    def store_transition(self, classifier_state, action_state, reward, next_classifier_state, next_action_state, terminal):
+        """Adds a new transition to the replay buffer."""
+        # If the buffer arrays are not yet initialized, initialize them
+        if self.n == 0:
+            self._init_nparray(classifier_state, action_state, reward, next_classifier_state, next_action_state, terminal)
+            return
+        # write the transition at the write_index position
+        self.all_classifier_states[self.write_index] = classifier_state
+        self.all_action_state[self.write_index] = action_state
+        self.all_rewards[self.write_index] = reward
+        self.all_next_classifier_states[self.write_index] = next_classifier_state
+        self.all_next_action_states[self.write_index] = next_action_state
+        self.all_terminals[self.write_index] = terminal
+        self.all_td_errors[self.write_index] = self.max_td_error
+        # keep track of the index for writing
+        self.write_index += 1
+        if self.write_index >= self.buffer_size:
+            self.write_index = 0
+        # Keep track of the max index to be used for sampling.
+        if self.n < self.buffer_size:
+            self.n += 1
+
+    def sample_minibatch(self, batch_size=32):
+        """Samples a new minibatch from the replay buffer.
+
+        Args:
+            batch_size: An integer indicating how many transitions to sample from the replay buffer.
+
+        Returns:
+            minibatch: An object of class Minibatch with the sampled transitions.
+        """
+        # Get the td errors of the samples that were written in the buffer
+        td_errors_to_consider = self.all_td_errors[:self.n]
+        # Scale and normalize the td errors to turn them into a probability for sampling
+        p = np.power(td_errors_to_consider, self.prior_exp) / np.sum(np.power(td_errors_to_consider, self.prior_exp))
+        # choose indices to sample according to the computed probability:
+        # the higher the td error is, the more likely it is that the sample will be selected
+        # first check if the number of non-zero elements in p is smaller than the batch_size
+        non_zero = np.count_nonzero(p)
+        if non_zero < batch_size <= self.n:
+            minibatch_indices = np.random.choice(range(self.n), size=non_zero, replace=False, p=p)
+            # add the missing elements
+            missing_elements = batch_size - non_zero
+            while missing_elements > 0:
+                num = np.random.choice(range(self.n))
+                if not num in minibatch_indices:
+                    minibatch_indices = np.concatenate((minibatch_indices, [num]))
+                    missing_elements -= 1
+        else:
+            minibatch_indices = np.random.choice(range(self.n), size=batch_size, replace=False, p=p)
+        minibatch = Minibatch(
+            self.all_classifier_states[minibatch_indices],
+            self.all_action_state[minibatch_indices],
+            self.all_rewards[minibatch_indices],
+            self.all_next_classifier_states[minibatch_indices],
+            [self.all_next_action_states[i] for i in minibatch_indices],
+            self.all_terminals[minibatch_indices],
+            minibatch_indices,
+        )
+        return minibatch
+
+    def update_td_errors(self, td_errors, indeces):
+        """Updates the td errors in the replay buffer.
+
+        After a gradient step was made, we need to update the
+        td errors to the recently calculated errors.
+
+        Args:
+            td_errors: A numpy array with the new td errors.
+            indeces: A numpy array with the indices of the points whose td errors should be updated.
+        """
+        # set the values for prioritized replay to the most recent td errors
+        self.all_td_errors[indeces] = np.ravel(np.absolute(td_errors))
+        # find the max error from the replay buffer that will be used as a default value for new transitions
+        self.max_td_error = np.max(self.all_td_errors)
\ No newline at end of file
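The sampling probabilities used in sample_minibatch are simply the stored absolute td errors raised to prior_exp and normalised; a quick illustration with made-up errors and the exponent used by the learner above:

    import numpy as np

    td_errors = np.array([0.1, 0.4, 2.0])     # made-up absolute td errors
    prior_exp = 3
    p = td_errors**prior_exp / np.sum(td_errors**prior_exp)
    print(p.round(4))                         # [0.0001 0.0079 0.9919] -- large errors dominate the sampling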
diff --git a/alipy/query_strategy/query_labels.py b/alipy/query_strategy/query_labels.py
index 1876178..89cbff0 100644
--- a/alipy/query_strategy/query_labels.py
+++ b/alipy/query_strategy/query_labels.py
@@ -22,10 +22,15 @@
 from sklearn.metrics.pairwise import rbf_kernel, polynomial_kernel, linear_kernel
 from sklearn.neighbors import kneighbors_graph
 from sklearn.utils.multiclass import unique_labels
+import torch
+from toma import toma
+from tqdm.auto import tqdm
 
 from .base import BaseIndexQuery
 from ..utils.ace_warnings import *
 from ..utils.misc import nsmallestarg, randperm, nlargestarg
+from . import joint_entropy
+from .LAL_RL import QueryInstanceLAL_RL
 
 __all__ = ['QueryInstanceUncertainty',
            'QueryRandom',
@@ -39,6 +44,8 @@
            'QueryInstanceLAL',
            'QueryInstanceCoresetGreedy',
            'QueryInstanceDensityWeighted',
+           'QueryInstanceBatchBALD',
+           'QueryInstanceLAL_RL'
            ]
@@ -1956,3 +1963,160 @@ def select(self, label_index, unlabel_index, batch_size=1, model=None, proba_pre
         assert len(pat) == len(div) == len(unlabel_index)
         scores = np.multiply(pat, div)
         return np.asarray(unlabel_index)[nlargestarg(scores, batch_size)]
+
+
+class QueryInstanceBatchBALD(BaseIndexQuery):
+    """Reference Paper:
+    Andreas Kirsch, Joost van Amersfoort, Yarin Gal. 2019.
+    BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning.
+    https://arxiv.org/abs/1906.08158
+
+    The implementation is referred to
+    https://github.com/BlackHC/batchbald_redux
+
+    Parameters
+    ----------
+    X: 2D array, optional (default=None)
+        Feature matrix of the whole dataset. It is a reference which will not use additional memory.
+
+    y: array-like, optional (default=None)
+        Label matrix of the whole dataset. It is a reference which will not use additional memory.
+    """
+    def __init__(self, X=None, y=None, verbose=1, **kwargs):
+        super(QueryInstanceBatchBALD, self).__init__(X, y)
+        self.verbose = verbose
+
+    def select(self, label_index, unlabel_index, model=None, batch_size=1, num_samples=None, device=None, **kwargs):
+        """
+        Parameters
+        ----------
+        label_index: {list, np.ndarray, IndexCollection}
+            The indexes of labeled samples.
+
+        unlabel_index: {list, np.ndarray, IndexCollection}
+            The indexes of unlabeled samples.
+
+        model: object, optional (default=None)
+            Classification model, should be a bayesian model and must have the 'predict_proba' method for probabilistic output.
+            If not provided, LogisticRegression with default parameters implemented by sklearn will be used.
+
+        batch_size: int, optional (default=1)
+            Selection batch size.
+
+        num_samples: int, optional (default=None)
+            the number of samples drawn when estimating the joint entropy;
+            if None, num_classes ** batch_size is used.
+
+        device: torch.device, optional (default=None)
+            pytorch device that will be used for the calculations, default is cpu.
+            if a cuda device is given, then the model must accept a device for the predict_proba method.
+        """
+        assert (batch_size > 0)
+        assert (isinstance(unlabel_index, collections.abc.Iterable))
+        unlabel_index = np.asarray(unlabel_index)
+        if len(unlabel_index) <= batch_size:
+            return unlabel_index
+
+        if self.X is None:
+            raise Exception('Data matrix is not provided, use select_by_prediction_mat() instead.')
+        if model is None:
+            model = LogisticRegression(solver='liblinear')
+            model.fit(self.X[label_index if isinstance(label_index, (list, np.ndarray)) else label_index.index],
+                      self.y[label_index if isinstance(label_index, (list, np.ndarray)) else label_index.index])
+        unlabel_x = self.X[unlabel_index, :]
+
+        if device is not None:
+            pred = model.predict_proba(unlabel_x, device)
+        else:
+            pred = model.predict_proba(unlabel_x)
+
+        if len(pred.shape) == 2:
+            # assuming the first dim is the number of samples and the second dim is the number of classes
+            pred = pred[:, None, :]
+        elif len(pred.shape) != 3:
+            raise ValueError("predict_proba of the model should return an array of dim 2 or 3")
+
+        pred_tensor = torch.Tensor(pred).to(device)
+
+        if num_samples is None:
+            num_samples = pred.shape[2] ** batch_size
+
+        return unlabel_index[self.get_batchbald_batch(pred_tensor.log().double(), batch_size, num_samples,
+                                                      dtype=torch.double, device=device)]
+
+    def get_batchbald_batch(self, log_probs_N_K_C: torch.Tensor, batch_size: int, num_samples: int, dtype=None, device=None):
+        N, K, C = log_probs_N_K_C.shape
+
+        batch_size = min(batch_size, N)
+
+        candidate_indices = []
+        candidate_scores = []
+
+        if batch_size == 0:
+            return []
+
+        conditional_entropies_N = self.compute_conditional_entropy(log_probs_N_K_C)
+
+        batch_joint_entropy = joint_entropy.DynamicJointEntropy(
+            num_samples, batch_size - 1, K, C, dtype=dtype, device=device, verbose=self.verbose
+        )
+
+        # We always keep these on the CPU.
+        scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())
+
+        if self.verbose >= 1:
+            p_bar = tqdm(total=batch_size, desc="BatchBALD", leave=False)
+            def update():
+                p_bar.update()
+            def close():
+                p_bar.close()
+        else:
+            update = lambda *x: None
+            close = lambda *x: None
+
+        for i in range(batch_size):
+            if i > 0:
+                latest_index = candidate_indices[-1]
+                batch_joint_entropy.add_variables(log_probs_N_K_C[latest_index : latest_index + 1])
+
+            shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()
+
+            batch_joint_entropy.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)
+
+            scores_N -= conditional_entropies_N + shared_conditional_entropies
+            scores_N[candidate_indices] = -float("inf")
+
+            candidate_score, candidate_index = scores_N.max(dim=0)
+
+            candidate_indices.append(candidate_index.item())
+            candidate_scores.append(candidate_score.item())
+
+            update()
+
+        close()
+        return candidate_indices
+
+
+    def compute_conditional_entropy(self, log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
+        N, K, C = log_probs_N_K_C.shape
+
+        entropies_N = torch.empty(N, dtype=torch.double)
+
+        if self.verbose >= 1:
+            pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)
+            def update():
+                pbar.update()
+            def close():
+                pbar.close()
+        else:
+            update = lambda *x: None
+            close = lambda *x: None
+
+        @toma.execute.chunked(log_probs_N_K_C, 1024)
+        def compute(log_probs_n_K_C, start: int, end: int):
+            nats_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
+
+            entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
+            update()
+
+        close()
+        return entropies_N
diff --git a/setup.py b/setup.py
index cd42529..8e4c170 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
     author_email='tangyp@nuaa.edu.cn, GuoXiangLi@nuaa.edu.cn, huangsj@nuaa.edu.cn',
     url='https://github.com/NUAA-AL/ALiPy',
     setup_requires=[],
-    install_requires=['numpy', 'scipy', 'scikit-learn', 'matplotlib', 'prettytable'],
+    install_requires=['numpy', 'scipy', 'scikit-learn', 'matplotlib', 'prettytable', 'torch', 'tqdm', 'toma'],
     packages=[
        'alipy',
        'alipy.data_manipulate',
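A usage sketch for the new BatchBALD strategy (illustrative, not part of the patch). The BaggedModel below is a stand-in whose predict_proba returns the N x K x C array of per-inference-sample class probabilities that select() expects; a real use would plug in an MC-dropout or deep-ensemble model instead.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.datasets import make_classification
    from alipy.query_strategy.query_labels import QueryInstanceBatchBALD

    class BaggedModel:
        """Stand-in 'Bayesian' model: K bootstrapped logistic regressions."""
        def __init__(self, k=10):
            self.members = [LogisticRegression(solver='liblinear') for _ in range(k)]
        def fit(self, X, y):
            for m in self.members:
                idx = np.random.choice(len(X), len(X))
                m.fit(X[idx], y[idx])
            return self
        def predict_proba(self, X):
            # shape (N, K, C): one probability vector per ensemble member and instance
            return np.stack([m.predict_proba(X) for m in self.members], axis=1)

    X, y = make_classification(n_samples=200, n_features=20)      # stand-in data
    label_ind, unlab_ind = list(range(30)), list(range(30, 200))
    model = BaggedModel().fit(X[label_ind], y[label_ind])

    strategy = QueryInstanceBatchBALD(X, y)
    query = strategy.select(label_ind, unlab_ind, model=model, batch_size=5)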