Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/java/autoweka/ClassifierRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import weka.attributeSelection.AttributeSelection;
import java.util.Map;
import java.util.Arrays;

import java.util.Collections;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -42,6 +42,7 @@ public class ClassifierRunner
private boolean mDisableOutput = false;
private java.io.PrintStream mSavedOutput = null;
private String mPredictionsFileName = null;
private String mIndividualResultsFileName;

/**
* Prepares a runner with the specified properties.
Expand All @@ -58,6 +59,7 @@ public ClassifierRunner(Properties props)
mTestOnly = Boolean.valueOf(props.getProperty("onlyTest", "false"));
mDisableOutput = Boolean.valueOf(props.getProperty("disableOutput", "false"));
mPredictionsFileName = props.getProperty("predictionsFileName", null);
mIndividualResultsFileName = props.getProperty("individualResultsFile", "individual-results.tsv");
}

/**
Expand Down Expand Up @@ -313,6 +315,35 @@ private ClassifierResult _run(String instanceStr, String resultMetric, float tim
attribEvalClassName, argMap.get("attributeeval"),
attribSearchClassName, argMap.get("attributesearch"),
instanceStr, res.getRawScore());

try {
FileWriter writer = new FileWriter(mIndividualResultsFileName, true);
StringBuilder builder = new StringBuilder();
String delim = "\t";

List<String> attributeEvalArgs = argMap.get("attributeeval");
String strAttributeEvalArgs = attributeEvalArgs != null ? Util.joinStrings(" ", attributeEvalArgs) : "";

List<String> attributeSearchArgs = argMap.get("attributesearch");
String strAttributeSearchArgs = attributeSearchArgs != null ? Util.joinStrings(" ", attributeSearchArgs) : "";

builder
.append(targetClassifierName).append(delim)
.append(Util.joinStrings(" ", argsArraySaved)).append(delim)
.append(attribEvalClassName).append(delim)
.append(strAttributeEvalArgs).append(delim)
.append(attribSearchClassName).append(delim)
.append(strAttributeSearchArgs).append(delim)
.append(instanceStr).append(delim)
.append(res.getRawScore()).append(delim)
.append("\n");
writer.write(builder.toString());
writer.flush();
writer.close();
}
catch(IOException e) {
log.error(e.toString());
}

log.debug("Num Training: {}, num testing: {}", training.numInstances(), testing.numInstances());
return res;
Expand Down
1 change: 1 addition & 0 deletions src/java/autoweka/tools/ExperimentRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.io.File;
import java.net.URLDecoder;
import java.util.Arrays;

import autoweka.Experiment;
import autoweka.TrajectoryParser;
Expand Down
124 changes: 122 additions & 2 deletions src/java/weka/classifiers/meta/AutoWEKAClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.InputStreamReader;
import java.io.Serializable;

import java.nio.file.Files;

import java.text.DecimalFormat;
import java.net.URLDecoder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
Expand Down Expand Up @@ -101,6 +102,8 @@ public class AutoWEKAClassifier extends AbstractClassifier implements Additional
static final int DEFAULT_MEM_LIMIT = 1024;
/** Default */
static final int DEFAULT_N_BEST = 1;
static final int DEFAULT_SHOW_DETAILED_RESULTS = 0;

/** Internal evaluation method. */
static enum Resampling {
CrossValidation,
Expand Down Expand Up @@ -165,6 +168,8 @@ static enum Resampling {
protected int memLimit = DEFAULT_MEM_LIMIT;
/** The amout of best configurations to return as output*/
protected int nBestConfigs = DEFAULT_N_BEST;
/** The number of detailed configurations to list in the output (SMAC score and options)*/
protected int nDetailedResults;
/** The internal evaluation method. */
protected Resampling resampling = DEFAULT_RESAMPLING;
/** The arguments to the evaluation method. */
Expand All @@ -180,6 +185,8 @@ static enum Resampling {

private transient weka.gui.Logger wLog;

private String detailedResultString;

/**
* Main method for testing this class.
*
Expand Down Expand Up @@ -367,6 +374,43 @@ public void run() {
log.info("classifier: {}, arguments: {}, attribute search: {}, attribute search arguments: {}, attribute evaluation: {}, attribute evaluation arguments: {}",
classifierClass, classifierArgs, attributeSearchClass, attributeSearchArgs, attributeEvalClass, attributeEvalArgs);

if (nDetailedResults > 0) {
String indvResultsFile = "individual-results.tsv"; // TODO
BufferedReader reader = new BufferedReader(new FileReader(new File(msExperimentPath + expName, indvResultsFile)));

Map<String, ConfigurationStats> totals = new HashMap<String, ConfigurationStats>();
for (String line = reader.readLine(); line != null; line = reader.readLine()) {
String[] parts = line.split("\t");
String classifierName = parts[0];
String id = classifierName + " " + parts[1] + ", ";
id += parts[2] + " " + parts[3] + ", ";
id += parts[4] + " " + parts[5];
ConfigurationStats stats = totals.get(id);
if (stats == null) {
stats = new ConfigurationStats(id);
totals.put(id, stats);
}
stats.addValue(Double.parseDouble(parts[7]));
}
StringBuilder builder = new StringBuilder();
builder.append("\n======= Detailed Results ======").append("\n\n");
List<ConfigurationStats> sorted = new ArrayList<ConfigurationStats>(totals.values());
Collections.sort(sorted);
int remaining = Math.min(nDetailedResults, sorted.size());
for (ConfigurationStats stats : sorted) {
// TODO: Ignore config stats that do not have enough data to be meaningful
builder.append(stats.toString()).append("\n");
remaining--;
if (remaining == 0) {
break;
}
}
if (remaining > 0)
builder.append(remaining + " other results truncated");
builder.append("\n");
detailedResultString = builder.toString();
reader.close();
}

//Print log of best configurations
if (nBestConfigs>1){
Expand Down Expand Up @@ -446,6 +490,9 @@ public Enumeration<Option> listOptions() {
result.addElement(
new Option("\tThe amount of best configurations to return.\n" + "\t(default: " + DEFAULT_MEM_LIMIT + ")",
"nBestConfigs", 1, "-nBestConfigs <limit>"));
result.addElement(
new Option("\tThe number of detailed results that should be shown.\n" + "\t(default: " + DEFAULT_SHOW_DETAILED_RESULTS + ")",
"nDetailedResults", 1, "-nDetailedResults <n>"));
//result.addElement(
// new Option("\tThe type of resampling used.\n" + "\t(default: " + String.valueOf(DEFAULT_RESAMPLING) + ")",
// "resampling", 1, "-resampling <resampling>"));
Expand Down Expand Up @@ -481,6 +528,8 @@ public String[] getOptions() {
result.add("" + memLimit);
result.add("-nBestConfigs");
result.add("" + nBestConfigs);
result.add("-nDetailedResults");
result.add("" + nDetailedResults);
//result.add("-resampling");
//result.add("" + resampling);
//result.add("-resamplingArgs");
Expand Down Expand Up @@ -528,6 +577,13 @@ public void setOptions(String[] options) throws Exception {
} else {
nBestConfigs = DEFAULT_N_BEST;
}

tmpStr = Utils.getOption("nDetailedResults", options);
if (tmpStr.length() != 0) {
nDetailedResults = Integer.parseInt(tmpStr);
} else {
nDetailedResults = DEFAULT_SHOW_DETAILED_RESULTS;
}


//tmpStr = Utils.getOption("resampling", options);
Expand Down Expand Up @@ -643,6 +699,21 @@ public int getnBestConfigs() {
return nBestConfigs;
}

/**
* @return the number of detailed configs to show
*/
public int getnDetailedResults() {
return nDetailedResults;
}

/**
* Sets the number of detailed configs to show
* @param nDetailedResults the number of detailed configs to show
*/
public void setnDetailedResults(int nDetailedResults) {
this.nDetailedResults = nDetailedResults;
}

/**
* Returns the tip text for this property.
* @return tip text for this property
Expand All @@ -651,7 +722,15 @@ public String nBestConfigsTipText() {
return "How many of the best configurations should be returned as output";
}

/**
* Returns the tip text for this property.
* @return tip text for this property
*/
public String nDetailedResultsTipText() {
return "How many of the best individual runs should be show in the output";
}


//public void setResampling(Resampling r) {
// resampling = r;
// resamplingArgs = resamplingArgsMap.get(r);
Expand Down Expand Up @@ -784,6 +863,12 @@ public String toString() {
res += eval.toClassDetailsString();
} catch(Exception e) { /*TODO treat*/ }

if (nDetailedResults > 0) {
res += "\n";
res += detailedResultString;
res += "\n";
}

if(nBestConfigs>1){

ConfigurationCollection cc = ConfigurationCollection.fromXML(msExperimentPath+expName+"/"+configurationRankingPath,ConfigurationCollection.class);
Expand Down Expand Up @@ -836,4 +921,39 @@ public double getMeasure(String additionalMeasureName) {
+ " not supported (Auto-WEKA)");
}
}

private static class ConfigurationStats implements Comparable<ConfigurationStats> {
private static final DecimalFormat df = new DecimalFormat("0.00");
int n;
double total;
String id;

public ConfigurationStats(String id) {
this.id = id;
}

public void addValue(double value) {
total += value;
// TODO: might be useful to hold on to the individual values and/or compute additional summary stats
n++;
}

public double getTotalRuns() {
return total;
}

public double getAverageScore() {
return total / n;
}

@Override
public int compareTo(ConfigurationStats o) {
return Double.compare(this.getAverageScore(), o.getAverageScore());
}

@Override
public String toString() {
return df.format(total/n) + " (" + df.format(total) + " / " + n + "): " + id;
}
}
}