diff --git a/grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java b/grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
index 805ac1a510..2168408269 100644
--- a/grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
+++ b/grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
@@ -1,5 +1,7 @@
 package org.grobid.core.jni;
 
+import org.apache.commons.lang3.StringUtils;
+
 import org.grobid.core.GrobidModel;
 import org.grobid.core.engines.label.TaggingLabels;
 import org.grobid.core.exceptions.GrobidException;
@@ -252,7 +254,48 @@ public void run() {
                 LOGGER.error("DeLFT model training via JEP failed", e);
             }
         }
-    }
+    }
+
+    /**
+     * Build the command line used to train a DeLFT model in an external Python process.
+     * <p>
+     * The training script defaults to {@code grobidTagger.py} and can be overridden with the
+     * {@code grobid.delft.train.module} property. Extra arguments configured through
+     * {@code grobid.delft.train.args} are split on whitespace and appended last.
+     *
+     * @param modelName    name of the model to train, passed right after the script name
+     * @param trainingData training data file, passed by absolute path via {@code --input}
+     * @param architecture optional architecture name; {@code --architecture} is omitted when null
+     * @return a new mutable list holding the command and its arguments
+     */
+    protected static List<String> getTrainCommand(String modelName, File trainingData, String architecture) {
+        String trainModule = GrobidProperties.getDeLFTTrainModule();
+        if (StringUtils.isBlank(trainModule)) {
+            trainModule = "grobidTagger.py";
+        }
+        // Arrays.asList alone is fixed-size, so copy it into an ArrayList before appending options
+        List<String> command = new ArrayList<>(Arrays.asList(
+            "python3",
+            trainModule,
+            modelName,
+            "train",
+            "--input", trainingData.getAbsolutePath(),
+            "--output", GrobidProperties.getModelPath().getAbsolutePath()
+        ));
+        if (architecture != null) {
+            command.add("--architecture");
+            command.add(architecture);
+        }
+        if (GrobidProperties.useELMo() && !modelName.toLowerCase().contains("bert")) {
+            command.add("--use-ELMo");
+        }
+        String trainArgs = GrobidProperties.getDeLFTTrainArgs();
+        if (StringUtils.isNotBlank(trainArgs)) {
+            // split on runs of whitespace so repeated or surrounding spaces never yield empty arguments
+            command.addAll(Arrays.asList(trainArgs.trim().split("\\s+")));
+        }
+        return command;
+    }
 
     /**
      * Train with an external process rather than with JNI, this approach appears to be more stable for the
@@ -261,19 +304,8 @@ public void run() {
     public static void train(String modelName, File trainingData, File outputModel, String architecture) {
         try {
             LOGGER.info("Train DeLFT model " + modelName + "...");
-            List command = Arrays.asList("python3",
-                "grobidTagger.py",
-                modelName,
-                "train",
-                "--input", trainingData.getAbsolutePath(),
-                "--output", 
GrobidProperties.getInstance().getModelPath().getAbsolutePath());
-            if (architecture != null) {
-                command.add("--architecture");
-                command.add(architecture);
-            }
-            if (GrobidProperties.getInstance().useELMo() && modelName.toLowerCase().indexOf("bert") == -1) {
-                command.add("--use-ELMo");
-            }
+            List<String> command = getTrainCommand(modelName, trainingData, architecture);
+            LOGGER.info("Running: {}", command);
 
             ProcessBuilder pb = new ProcessBuilder(command);
             File delftPath = new File(GrobidProperties.getInstance().getDeLFTFilePath());
diff --git a/grobid-core/src/main/java/org/grobid/core/utilities/GrobidProperties.java b/grobid-core/src/main/java/org/grobid/core/utilities/GrobidProperties.java
index a66ab5af99..0ccfff37d3 100755
--- a/grobid-core/src/main/java/org/grobid/core/utilities/GrobidProperties.java
+++ b/grobid-core/src/main/java/org/grobid/core/utilities/GrobidProperties.java
@@ -483,6 +483,24 @@ public static boolean isDeLFTRedirectOutput() {
         );
     }
 
+    /**
+     * Returns the DeLFT training module configured under {@code grobid.delft.train.module},
+     * or an empty string when the property is not set.
+     */
+    public static String getDeLFTTrainModule() {
+        return getPropertyValue(
+            GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, ""
+        );
+    }
+
+    /**
+     * Returns the extra DeLFT training arguments configured under {@code grobid.delft.train.args},
+     * or an empty string when the property is not set.
+     */
+    public static String getDeLFTTrainArgs() {
+        return getPropertyValue(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "");
+    }
+
     public static String getGluttonHost() {
         return getPropertyValue(GrobidPropertyKeys.PROP_GLUTTON_HOST);
     }
diff --git a/grobid-core/src/main/java/org/grobid/core/utilities/GrobidPropertyKeys.java b/grobid-core/src/main/java/org/grobid/core/utilities/GrobidPropertyKeys.java
index e4fdf97fdb..81faed1d8e 100755
--- a/grobid-core/src/main/java/org/grobid/core/utilities/GrobidPropertyKeys.java
+++ b/grobid-core/src/main/java/org/grobid/core/utilities/GrobidPropertyKeys.java
@@ -21,6 +21,8 @@ public interface GrobidPropertyKeys {
     String PROP_GROBID_DELFT_REDIRECT_OUTPUT = "grobid.delft.redirect_output";
     String PROP_GROBID_DELFT_ELMO = "grobid.delft.useELMo";
     String PROP_DELFT_ARCHITECTURE = "grobid.delft.architecture";
+    String PROP_GROBID_DELFT_TRAIN_MODULE 
= "grobid.delft.train.module";
+    String PROP_GROBID_DELFT_TRAIN_ARGS = "grobid.delft.train.args";
     String PROP_USE_LANG_ID = "grobid.use_language_id";
     String PROP_LANG_DETECTOR_FACTORY = "grobid.language_detector_factory";
     String PROP_SENTENCE_DETECTOR_FACTORY = "grobid.sentence_detector_factory";
diff --git a/grobid-core/src/test/java/org/grobid/core/jni/DeLFTModelTest.java b/grobid-core/src/test/java/org/grobid/core/jni/DeLFTModelTest.java
new file mode 100644
index 0000000000..f92365bf10
--- /dev/null
+++ b/grobid-core/src/test/java/org/grobid/core/jni/DeLFTModelTest.java
@@ -0,0 +1,102 @@
+package org.grobid.core.jni;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.contains;
+
+import org.grobid.core.utilities.GrobidProperties;
+import org.grobid.core.utilities.GrobidPropertyKeys;
+
+
+/**
+ * Unit tests for the training command line assembled by {@code DeLFTModel.getTrainCommand}.
+ */
+public class DeLFTModelTest {
+    private final File trainingData = new File("test/train.data");
+
+    @Before
+    public void setUp() {
+        // Give every test the same starting point: ELMo off, no custom module, no extra args
+        GrobidProperties.getInstance();
+        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO, "false");
+        GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE);
+        GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS);
+    }
+
+    /**
+     * Builds the command expected for model {@code model1}: the fixed leading arguments
+     * followed by whatever trailing options a test expects.
+     */
+    private String[] expectedCommand(String trainModule, String... trailingArgs) {
+        List<String> expected = new ArrayList<>(Arrays.asList(
+            "python3", trainModule, "model1", "train",
+            "--input", trainingData.getAbsolutePath(),
+            "--output", GrobidProperties.getModelPath().getAbsolutePath()
+        ));
+        expected.addAll(Arrays.asList(trailingArgs));
+        return expected.toArray(new String[0]);
+    }
+
+    @Test
+    public void testShouldBuildTrainCommand() {
+        assertThat(
+            DeLFTModel.getTrainCommand("model1", trainingData, null),
+            contains(expectedCommand("grobidTagger.py"))
+        );
+    }
+
+    @Test
+    public void testShouldAddArchitecture() {
+        assertThat(
+            DeLFTModel.getTrainCommand("model1", trainingData, "architecture1"),
+            contains(expectedCommand("grobidTagger.py", "--architecture", "architecture1"))
+        );
+    }
+
+    @Test
+    public void testShouldAddUseELMO() {
+        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO, "true");
+        assertThat(
+            DeLFTModel.getTrainCommand("model1", trainingData, null),
+            contains(expectedCommand("grobidTagger.py", "--use-ELMo"))
+        );
+    }
+
+    @Test
+    public void testShouldUseCustomTrainModule() {
+        GrobidProperties.getProps().put(
+            GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, "module1.py"
+        );
+        assertThat(
+            DeLFTModel.getTrainCommand("model1", trainingData, null),
+            contains(expectedCommand("module1.py"))
+        );
+    }
+
+    @Test
+    public void testShouldAddSingleCustomTrainArg() {
+        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "arg1");
+        assertThat(
+            DeLFTModel.getTrainCommand("model1", trainingData, null),
+            contains(expectedCommand("grobidTagger.py", "arg1"))
+        );
+    }
+
+    @Test
+    public void testShouldAddMultipleCustomTrainArg() {
+        GrobidProperties.getProps().put(
+            GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "arg1 arg2"
+        );
+        assertThat(
+            DeLFTModel.getTrainCommand("model1", trainingData, null),
+            contains(expectedCommand("grobidTagger.py", "arg1", "arg2"))
+        );
+    }
+}
diff --git a/grobid-core/src/test/java/org/grobid/core/utilities/GrobidPropertiesTest.java b/grobid-core/src/test/java/org/grobid/core/utilities/GrobidPropertiesTest.java
index 5a435b9b1c..b0a1677728 100755
--- a/grobid-core/src/test/java/org/grobid/core/utilities/GrobidPropertiesTest.java
+++ b/grobid-core/src/test/java/org/grobid/core/utilities/GrobidPropertiesTest.java
@@ -88,6 +88,32 @@ public void testIsDeLFTRedirectOutputTrueIfSet() throws IOException {
         assertTrue(GrobidProperties.isDeLFTRedirectOutput());
     }
 
+    @Test
+    public void testShouldReturnEmptyTrainModuleByDefault() {
+        
GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE);
+        assertEquals("", GrobidProperties.getDeLFTTrainModule()); // JUnit order is (expected, actual)
+    }
+
+    @Test
+    public void testShouldReturnConfiguredModule() {
+        GrobidProperties.getProps().put(
+            GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, "module1"
+        );
+        assertEquals("module1", GrobidProperties.getDeLFTTrainModule());
+    }
+
+    @Test
+    public void testShouldReturnEmptyTrainArgsByDefault() {
+        GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS);
+        assertEquals("", GrobidProperties.getDeLFTTrainArgs());
+    }
+
+    @Test
+    public void testShouldReturnConfiguredTrainArgs() {
+        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "args");
+        assertEquals("args", GrobidProperties.getDeLFTTrainArgs());
+    }
+
     /*@Test(expected = GrobidPropertyException.class)
     public void testCheckPropertiesException_shouldThrowException() {
         GrobidProperties.getProps().put(