diff --git a/src/java/autoweka/ClassifierRunner.java b/src/java/autoweka/ClassifierRunner.java index 10dc769e..46913887 100644 --- a/src/java/autoweka/ClassifierRunner.java +++ b/src/java/autoweka/ClassifierRunner.java @@ -17,7 +17,7 @@ import weka.attributeSelection.AttributeSelection; import java.util.Map; import java.util.Arrays; - +import java.util.Collections; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,6 +42,7 @@ public class ClassifierRunner private boolean mDisableOutput = false; private java.io.PrintStream mSavedOutput = null; private String mPredictionsFileName = null; + private String mIndividualResultsFileName; /** * Prepares a runner with the specified properties. @@ -58,6 +59,7 @@ public ClassifierRunner(Properties props) mTestOnly = Boolean.valueOf(props.getProperty("onlyTest", "false")); mDisableOutput = Boolean.valueOf(props.getProperty("disableOutput", "false")); mPredictionsFileName = props.getProperty("predictionsFileName", null); + mIndividualResultsFileName = props.getProperty("individualResultsFile", "individual-results.tsv"); } /** @@ -313,6 +315,35 @@ private ClassifierResult _run(String instanceStr, String resultMetric, float tim attribEvalClassName, argMap.get("attributeeval"), attribSearchClassName, argMap.get("attributesearch"), instanceStr, res.getRawScore()); + + try { + FileWriter writer = new FileWriter(mIndividualResultsFileName, true); + StringBuilder builder = new StringBuilder(); + String delim = "\t"; + + List attributeEvalArgs = argMap.get("attributeeval"); + String strAttributeEvalArgs = attributeEvalArgs != null ? Util.joinStrings(" ", attributeEvalArgs) : ""; + + List attributeSearchArgs = argMap.get("attributesearch"); + String strAttributeSearchArgs = attributeSearchArgs != null ? Util.joinStrings(" ", attributeSearchArgs) : ""; + + builder + .append(targetClassifierName).append(delim) + .append(Util.joinStrings(" ", argsArraySaved)).append(delim) + .append(attribEvalClassName).append(delim) + .append(strAttributeEvalArgs).append(delim) + .append(attribSearchClassName).append(delim) + .append(strAttributeSearchArgs).append(delim) + .append(instanceStr).append(delim) + .append(res.getRawScore()).append(delim) + .append("\n"); + writer.write(builder.toString()); + writer.flush(); + writer.close(); + } + catch(IOException e) { + log.error(e.toString()); + } log.debug("Num Training: {}, num testing: {}", training.numInstances(), testing.numInstances()); return res; diff --git a/src/java/autoweka/tools/ExperimentRunner.java b/src/java/autoweka/tools/ExperimentRunner.java index 49cf3653..ba6d357a 100644 --- a/src/java/autoweka/tools/ExperimentRunner.java +++ b/src/java/autoweka/tools/ExperimentRunner.java @@ -2,6 +2,7 @@ import java.io.File; import java.net.URLDecoder; +import java.util.Arrays; import autoweka.Experiment; import autoweka.TrajectoryParser; diff --git a/src/java/weka/classifiers/meta/AutoWEKAClassifier.java b/src/java/weka/classifiers/meta/AutoWEKAClassifier.java index db97cc40..6007ee83 100644 --- a/src/java/weka/classifiers/meta/AutoWEKAClassifier.java +++ b/src/java/weka/classifiers/meta/AutoWEKAClassifier.java @@ -42,13 +42,14 @@ import java.io.BufferedReader; import java.io.File; +import java.io.FileReader; import java.io.InputStreamReader; import java.io.Serializable; import java.nio.file.Files; - +import java.text.DecimalFormat; import java.net.URLDecoder; - +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Enumeration; @@ -101,6 +102,8 @@ public class AutoWEKAClassifier extends AbstractClassifier implements Additional static final int DEFAULT_MEM_LIMIT = 1024; /** Default */ static final int DEFAULT_N_BEST = 1; + static final int DEFAULT_SHOW_DETAILED_RESULTS = 0; + /** Internal evaluation method. */ static enum Resampling { CrossValidation, @@ -165,6 +168,8 @@ static enum Resampling { protected int memLimit = DEFAULT_MEM_LIMIT; /** The amout of best configurations to return as output*/ protected int nBestConfigs = DEFAULT_N_BEST; + /** The number of detailed configurations to list in the output (SMAC score and options)*/ + protected int nDetailedResults; /** The internal evaluation method. */ protected Resampling resampling = DEFAULT_RESAMPLING; /** The arguments to the evaluation method. */ @@ -180,6 +185,8 @@ static enum Resampling { private transient weka.gui.Logger wLog; + private String detailedResultString; + /** * Main method for testing this class. * @@ -367,6 +374,43 @@ public void run() { log.info("classifier: {}, arguments: {}, attribute search: {}, attribute search arguments: {}, attribute evaluation: {}, attribute evaluation arguments: {}", classifierClass, classifierArgs, attributeSearchClass, attributeSearchArgs, attributeEvalClass, attributeEvalArgs); + if (nDetailedResults > 0) { + String indvResultsFile = "individual-results.tsv"; // TODO + BufferedReader reader = new BufferedReader(new FileReader(new File(msExperimentPath + expName, indvResultsFile))); + + Map totals = new HashMap(); + for (String line = reader.readLine(); line != null; line = reader.readLine()) { + String[] parts = line.split("\t"); + String classifierName = parts[0]; + String id = classifierName + " " + parts[1] + ", "; + id += parts[2] + " " + parts[3] + ", "; + id += parts[4] + " " + parts[5]; + ConfigurationStats stats = totals.get(id); + if (stats == null) { + stats = new ConfigurationStats(id); + totals.put(id, stats); + } + stats.addValue(Double.parseDouble(parts[7])); + } + StringBuilder builder = new StringBuilder(); + builder.append("\n======= Detailed Results ======").append("\n\n"); + List sorted = new ArrayList(totals.values()); + Collections.sort(sorted); + int remaining = Math.min(nDetailedResults, sorted.size()); + for (ConfigurationStats stats : sorted) { + // TODO: Ignore config stats that do not have enough data to be meaningful + builder.append(stats.toString()).append("\n"); + remaining--; + if (remaining == 0) { + break; + } + } + if (remaining > 0) + builder.append(remaining + " other results truncated"); + builder.append("\n"); + detailedResultString = builder.toString(); + reader.close(); + } //Print log of best configurations if (nBestConfigs>1){ @@ -446,6 +490,9 @@ public Enumeration