Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/java/autoweka/Configuration.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ public int compareTo(Object aTarget){ //Compares only the average score. If nece
this.lazyUpdateAverage();
cTarget.lazyUpdateAverage();

if (this.mFolds.size() > cTarget.mFolds.size()){
return 1;
}else if (this.mFolds.size() < cTarget.mFolds.size()){
return -1;
}else{
if (this.mAverageScore < cTarget.mAverageScore ) return 1; //Assumes smaller score is better. If that isn't the case, change that.
else if (this.mAverageScore > cTarget.mAverageScore) return -1; // @TODO make this class receive a Metric as input and do this change automatically
else return 0;
}
if (this.mFolds.size() > cTarget.mFolds.size()){
return 1;
}else if (this.mFolds.size() < cTarget.mFolds.size()){
return -1;
}else{
if (this.mAverageScore < cTarget.mAverageScore ) return 1; //Assumes smaller score is better. If that isn't the case, change that.
else if (this.mAverageScore > cTarget.mAverageScore) return -1; // @TODO make this class receive a Metric as input and do this change automatically
else return 0;
}

}

Expand Down Expand Up @@ -192,7 +192,7 @@ public double getAverageScore(){
lazyUpdateAverage();
return mAverageScore;
}


public int getEvaluationAmount(){
return this.mScores.size();
Expand Down
54 changes: 45 additions & 9 deletions src/java/autoweka/ConfigurationCollection.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ public class ConfigurationCollection extends XmlSerializable{
//Builds a collection that starts out with no configurations
public ConfigurationCollection(){
    this.mConfigurations = new ArrayList<Configuration>();
}

//Builds a collection backed by a new ArrayList holding the elements of the given list (the list itself is not retained)
public ConfigurationCollection(List<Configuration> aConfigurations){
    ArrayList<Configuration> copied = new ArrayList<Configuration>(aConfigurations);
    this.mConfigurations = copied;
}
//Builds a collection that uses the given ArrayList directly as its backing store; no copy is made,
//so the caller's list and this collection share the same elements and later modifications
public ConfigurationCollection(ArrayList<Configuration> aConfigurations){
    this.mConfigurations = aConfigurations;
}

//Appends the given configuration to the end of the backing list
public void add(Configuration c){
    this.mConfigurations.add(c);
}
Expand All @@ -38,19 +37,56 @@ public int size (){
return mConfigurations.size();
}

//Returns the amount of configurations evaluated on all folds. Assumes the thing is sorted. TODO check if it is
public int getFullyEvaluatedAmt() {
int largestFoldAmt = mConfigurations.get(0).getAmtFolds();
//Scans every configuration in the collection and returns the largest fold count found (0 when the collection is empty)
public int getHighestEvaluationAmt(){
    int best = 0;
    for (Configuration candidate : mConfigurations){
        best = Math.max(best, candidate.getAmtFolds());
    }
    return best;
}

//Returns the amount of configurations evaluated on all folds, i.e. those whose fold count equals maxAmt.
//The input integer is the maximum amount of folds possible according to user's CV options.
public int getFullyEvaluatedAmt(int maxAmt){
    int counter = 0;
    for (Configuration c : mConfigurations){
        //Only the caller-supplied maximum is compared; no assumption is made about the ordering of the list
        if(c.getAmtFolds()==maxAmt){
            counter++;
        }
    }
    return counter;
}

//Convenience overload: uses the largest fold count currently present in the collection as the maximum possible amount
public int getFullyEvaluatedAmt() {
    return this.getFullyEvaluatedAmt(this.getHighestEvaluationAmt());
}

//Returns a new ConfigurationCollection containing only the fully evaluated configurations
public ConfigurationCollection getFullyEvaluatedCollection(int maxAmt){
List<Configuration> rvConfigurations = new ArrayList<Configuration>();
for(Configuration c : mConfigurations){
if (c.getAmtFolds() == maxAmt){
rvConfigurations.add(c);
}
}
return counter;
}
ConfigurationCollection rv = new ConfigurationCollection(rvConfigurations);
return rv;
}

//Convenience overload: uses the largest fold count currently present in the collection as the maximum possible amount
public ConfigurationCollection getFullyEvaluatedCollection(){
    return this.getFullyEvaluatedCollection(this.getHighestEvaluationAmt());
}



//Exposes the backing list itself (not a copy), so changes made through the returned list affect this collection
public ArrayList<Configuration> asArrayList(){
    return this.mConfigurations;
}

//Public entry point that forwards to XmlSerializable.fromXML; the original is protected, so it is redeclared here to make it public.
public static <T extends XmlSerializable> T fromXML(String filename, Class<T> c){
    T loaded = XmlSerializable.fromXML(filename, c);
    return loaded;
}
Expand Down
2 changes: 1 addition & 1 deletion src/java/autoweka/ConfigurationRanker.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ public static void rank(int n, String temporaryDirPath, String smacBest) throws
Set<String> configHashes;

//Reading the hashes and removing duplicates
//TODO improve hashing

redundantConfigHashes = (new Scanner(hashSetFile)).nextLine().split(",");
configHashes = new HashSet<String>(Arrays.asList(redundantConfigHashes));

//

for(String hash : configHashes){
configs.add(Configuration.fromXML(cdPath+hash+".xml",Configuration.class));
Expand Down
111 changes: 111 additions & 0 deletions src/java/autoweka/Ensemble.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package autoweka;

import weka.classifiers.meta.AutoWEKAClassifier;
import weka.core.Instance;
import weka.core.Instances;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * An ordered group of {@link EnsembleElement}s whose individual predictions are combined into one.
 *
 * <p>The only combination strategy implemented is majority voting, and the only performance
 * figure computed is the error rate over a validation set.
 */
public class Ensemble {

    /** Name of the only combination strategy accepted by {@link #measureEnsemblePerformance}. */
    private static final String MAJORITY_VOTING = "MAJORITY_VOTING";

    /** Members of the ensemble, in insertion order. */
    private final List<EnsembleElement> elements;
    /** Most recent result of {@link #measureEnsemblePerformance}, keyed by the metric it was stored under. */
    private final Map<AutoWEKAClassifier.Metric,Double> lastMeasuredPerformances;

    @Override
    public String toString(){
        return elements.toString();
    }

    /** Creates an ensemble with no elements and no measured performances. */
    public Ensemble(){
        this.elements = new ArrayList<EnsembleElement>();
        this.lastMeasuredPerformances = new HashMap<AutoWEKAClassifier.Metric, Double>();
    }

    /**
     * Creates an ensemble holding the given elements.
     *
     * @param elements initial members; the list is copied, the elements themselves are shared
     */
    public Ensemble(List<EnsembleElement> elements){
        this.elements = new ArrayList<EnsembleElement>(elements); //shallow copying
        this.lastMeasuredPerformances = new HashMap<AutoWEKAClassifier.Metric, Double>();
    }

    /**
     * Returns a new ensemble with its own element list and its own performance map, both filled
     * with the same element and value references as this one.
     *
     * @return the copy
     */
    public Ensemble shallowCopy(){
        Ensemble rv = new Ensemble(this.elements);
        rv.lastMeasuredPerformances.putAll(this.lastMeasuredPerformances);
        return rv;
    }

    /**
     * Adds an element at the end of the ensemble.
     *
     * @param ee the element to append
     */
    public void appendElement(EnsembleElement ee){
        this.elements.add(ee);
    }

    /**
     * Removes the most recently appended element.
     *
     * @throws IndexOutOfBoundsException if the ensemble is empty
     */
    public void removeLastElement(){
        this.elements.remove(this.elements.size()-1);
    }

    /**
     * Computes the ensemble's error rate on a validation set and remembers it under {@code metric}.
     *
     * <p>The value computed is always the fraction of misclassified instances, whatever
     * {@code metric} is; the metric is only used as the key for {@link #getLastMeasuredPerformance}.
     * An empty validation set yields {@code NaN} (0 divided by 0 in floating point).
     *
     * @param validationSet instances to evaluate; every element must already hold a cached
     *                      prediction for each of them
     * @param metric key under which the result is stored
     * @param evaluateAlgorithm combination strategy; only {@code "MAJORITY_VOTING"} is accepted
     * @return the error rate
     * @throws IllegalArgumentException if {@code evaluateAlgorithm} is not a supported strategy
     * @throws IllegalStateException if the ensemble has no elements and the validation set is not empty
     */
    public Double measureEnsemblePerformance(Instances validationSet, AutoWEKAClassifier.Metric metric, String evaluateAlgorithm){
        //TODO implement the computation for other metrics we might want. So far it only supports error rate
        double incorrectAmount = 0;

        for(int i = 0; i<validationSet.numInstances(); i++){
            Instance currentInst = validationSet.instance(i);
            if(currentInst.classValue() != this.evaluateInstance(currentInst,evaluateAlgorithm)){
                incorrectAmount+=1;
            }
        }
        double errorRate = (incorrectAmount/(double)validationSet.numInstances());

        this.lastMeasuredPerformances.put(metric,errorRate);
        return errorRate;
    }

    /**
     * Returns the value stored by the latest {@link #measureEnsemblePerformance} call for a metric.
     *
     * @param m the metric to look up
     * @return the stored value
     * @throws RuntimeException if nothing was ever stored for {@code m}
     */
    public Double getLastMeasuredPerformance(AutoWEKAClassifier.Metric m){
        Double rv = lastMeasuredPerformances.get(m);
        if( rv==null ){
            throw new RuntimeException("Trying to check the last measured performance for a metric whose performance was never measured");
        }
        return rv;
    }

    /** Dispatches to the combination strategy named by {@code evaluateAlgorithm}. */
    private Double evaluateInstance(Instance i, String evaluateAlgorithm){
        // Constant-first comparison so a null name is reported as invalid instead of raising a NullPointerException.
        if(MAJORITY_VOTING.equals(evaluateAlgorithm)){
            return evaluateByMajorityVoting(i);
        }
        throw new IllegalArgumentException("Invalid ensemble evaluation algorithm");
    }

    /**
     * Returns the prediction that received the most votes among the elements.
     * Ties are resolved in favour of the tied prediction that comes last in the vote map's
     * iteration order.
     */
    private Double evaluateByMajorityVoting(Instance i){
        if(elements.isEmpty()){
            // Without this guard the tally below would be empty and there would be no prediction to return.
            throw new IllegalStateException("Cannot evaluate an instance with an ensemble that has no elements");
        }

        Map<Double,Integer> votes = new HashMap<Double, Integer>();
        for(EnsembleElement ee : elements){
            Double vote = ee.evaluateInstance(i);
            Integer amount = votes.get(vote);
            votes.put(vote, amount == null ? 1 : amount + 1);
        }

        Double rv = null;
        int maxVotes = 0;
        for(Map.Entry<Double,Integer> entry : votes.entrySet()){
            // ">=" lets a later entry with the same count replace an earlier one.
            if(entry.getValue()>=maxVotes){
                maxVotes = entry.getValue();
                rv = entry.getKey();
            }
        }
        return rv;
    }

    /**
     * Returns the internal element list itself (not a copy); changes made through it affect this ensemble.
     *
     * @return the live list of elements
     */
    public List<EnsembleElement> getElements() {
        return elements;
    }
}
127 changes: 127 additions & 0 deletions src/java/autoweka/EnsembleElement.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package autoweka;

import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.AttributeSelection;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * One member of an {@link Ensemble}: a classifier, optionally preceded by attribute selection,
 * built from the argument string of a {@link Configuration}.
 *
 * <p>Typical use as far as this class shows: construct, {@link #train}, {@link #cachePredictions},
 * then query {@link #evaluateInstance} for the cached instances.
 */
public class EnsembleElement{


    private final Configuration configuration;

    /** The chosen classifier. */
    protected Classifier classifier;
    /** The chosen attribute selection method. */
    protected AttributeSelection as;

    /** The class of the chosen classifier. */
    protected String classifierClass;
    /** The arguments of the chosen classifier. */
    protected String[] classifierArgs;

    /** The class of the chosen attribute search method. */
    protected String attributeSearchClass;
    /** The arguments of the chosen attribute search method. May be null when none were parsed. */
    protected String[] attributeSearchArgs;

    /** The class of the chosen attribute evaluation. */
    protected String attributeEvalClass;
    /** The arguments of the chosen attribute evaluation method. May be null when none were parsed. */
    protected String[] attributeEvalArgs;


    /**
     * Predictions stored by {@link #cachePredictions}, keyed by the instance object that was classified.
     * NOTE(review): lookups rely on Instance's equals/hashCode; the error message in
     * evaluateInstance suggests identity semantics are expected - confirm.
     */
    private final Map<Instance,Double> cachedPredictions;

    @Override
    public String toString(){
        return "Classifier Class: "+classifierClass+"\nArgs:[+\n"+String.join(",",classifierArgs)+"]\n";
    }

    /** @return the configuration this element was built from */
    public Configuration getConfiguration(){ return this.configuration; }

    /** @return the class name of the classifier, as parsed from the configuration */
    public String getClassifierClass(){
        return this.classifierClass;
    }

    /**
     * Parses the configuration's argument string into classifier, attribute search and attribute
     * evaluation class names and arguments. Nothing is trained here.
     *
     * @param configuration source of the argument string (split on single spaces before conversion)
     */
    public EnsembleElement(Configuration configuration){
        this.configuration = configuration;
        this.cachedPredictions = new HashMap<Instance, Double>();

        WekaArgumentConverter.Arguments wekaArgs = WekaArgumentConverter.convert(Arrays.asList(configuration.getArgStrings().split(" ")));
        classifierClass = wekaArgs.propertyMap.get("targetclass");
        String tempClassifierArgs = Util.joinStrings(" ", Util.quoteStrings(wekaArgs.argMap.get("classifier")));
        classifierArgs = Util.splitQuotedString(tempClassifierArgs).toArray(new String[0]);

        // Attribute selection is configured only when a search method other than "NONE" is present.
        if(wekaArgs.propertyMap.containsKey("attributesearch") && !"NONE".equals(wekaArgs.propertyMap.get("attributesearch"))){
            attributeSearchClass = wekaArgs.propertyMap.get("attributesearch");
            String tempAttributeSearchArgs = Util.joinStrings(" ", Util.quoteStrings(wekaArgs.argMap.get("attributesearch")));
            if(tempAttributeSearchArgs != null) {
                attributeSearchArgs = Util.splitQuotedString(tempAttributeSearchArgs).toArray(new String[0]);
            }

            attributeEvalClass = wekaArgs.propertyMap.get("attributeeval");
            String tempAttributeEvalArgs = Util.joinStrings(" ", Util.quoteStrings(wekaArgs.argMap.get("attributeeval")));
            if(tempAttributeEvalArgs != null) {
                attributeEvalArgs = Util.splitQuotedString(tempAttributeEvalArgs).toArray(new String[0]);
            }
        }
    }

    /**
     * Runs attribute selection on the training data, then builds the classifier on the reduced data.
     * Replaces any previously built attribute selection and classifier.
     *
     * @param trainingInstances data to select attributes on and train with
     * @throws RuntimeException wrapping whatever exception was raised during selection or training
     */
    public void train(Instances trainingInstances){

        //Training
        try{
            as = new AttributeSelection();

            // Argument arrays are copied before being handed over so the stored originals stay intact
            // (presumably the option parsing modifies the array it is given - confirm).
            if(attributeSearchClass != null) {
                ASSearch asSearch = ASSearch.forName(attributeSearchClass, copyArgs(attributeSearchArgs));
                as.setSearch(asSearch);
            }
            if(attributeEvalClass != null) {
                ASEvaluation asEval = ASEvaluation.forName(attributeEvalClass, copyArgs(attributeEvalArgs));
                as.setEvaluator(asEval);
            }
            as.SelectAttributes(trainingInstances);

            classifier = AbstractClassifier.forName(classifierClass, classifierArgs.clone());

            Instances reducedInstances = as.reduceDimensionality(trainingInstances);
            classifier.buildClassifier(reducedInstances);

        }catch (Exception e){
            // Keep the original exception as the cause so the real failure is not lost.
            throw new RuntimeException("Caught an exception while trying to train an EnsembleElement with argstrings:"+ configuration.getArgStrings(), e);
        }

    }

    /**
     * Classifies every instance of the given set (after applying the trained attribute reduction)
     * and stores each prediction for later lookup by {@link #evaluateInstance}.
     *
     * @param validationInstances instances to classify and cache
     * @throws RuntimeException wrapping any exception raised while reducing or classifying,
     *                          including the case where {@link #train} has not been called
     */
    public void cachePredictions(Instances validationInstances){

        for(int i = 0; i<validationInstances.numInstances(); i++){
            Instance inst = validationInstances.instance(i);
            try{
                Instance inst_withReduction = as.reduceDimensionality(inst);
                cachedPredictions.put(inst,classifier.classifyInstance(inst_withReduction));
            }catch(Exception e){
                throw new RuntimeException("Caught an exception while trying to cache predictions for the EnsembleElement with argstrings:"+configuration.getArgStrings()+"\nError message:"+e.toString(), e);
            }
        }

    }


    /**
     * Returns the prediction previously cached for the given instance.
     *
     * @param i an instance that was part of a set passed to {@link #cachePredictions}
     * @return the cached prediction
     * @throws RuntimeException if no prediction is cached for {@code i}
     */
    public double evaluateInstance(Instance i){
        Double rv = cachedPredictions.get(i);
        if(rv == null){
            throw new RuntimeException("Something wrong with the instance pointers. Trying to evaluate an instance whose prediction by the classifier wasn't cached");
        }
        return rv;
    }

    /** Returns a copy of the given argument array, or an empty array when the constructor left it null. */
    private static String[] copyArgs(String[] args){
        return args == null ? new String[0] : args.clone();
    }
}
Loading