Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.grobid.core.jni;

import org.apache.commons.lang3.StringUtils;

import org.grobid.core.GrobidModel;
import org.grobid.core.engines.label.TaggingLabels;
import org.grobid.core.exceptions.GrobidException;
Expand Down Expand Up @@ -252,7 +254,35 @@ public void run() {
LOGGER.error("DeLFT model training via JEP failed", e);
}
}
}
}

/**
 * Builds the command line used to train a DeLFT model in an external Python process.
 *
 * <p>The resulting command has the form
 * {@code python3 <module> <modelName> train --input <trainingData> --output <modelPath>},
 * optionally followed by {@code --architecture <architecture>}, {@code --use-ELMo} and any
 * extra arguments configured through {@link GrobidProperties#getDeLFTTrainArgs()}.
 *
 * @param modelName    name of the model to train; also inspected (case-insensitively) for the
 *                     substring "bert", in which case {@code --use-ELMo} is never added
 * @param trainingData training data file; its absolute path is passed as {@code --input}
 * @param architecture architecture to pass via {@code --architecture}, or {@code null} to omit it
 * @return a new, mutable list holding the command and its arguments in order
 */
protected static List<String> getTrainCommand(String modelName, File trainingData, String architecture) {
    String trainModule = GrobidProperties.getDeLFTTrainModule();
    // A blank (empty or whitespace-only) setting means "not configured": use the default script.
    if (StringUtils.isBlank(trainModule)) {
        trainModule = "grobidTagger.py";
    }

    List<String> command = new ArrayList<>(Arrays.asList(
        "python3",
        trainModule,
        modelName,
        "train",
        "--input", trainingData.getAbsolutePath(),
        "--output", GrobidProperties.getModelPath().getAbsolutePath()
    ));

    if (architecture != null) {
        command.add("--architecture");
        command.add(architecture);
    }

    if (GrobidProperties.useELMo() && !modelName.toLowerCase().contains("bert")) {
        command.add("--use-ELMo");
    }

    // StringUtils.split(String) splits on whitespace, treats adjacent separators as one and never
    // yields empty tokens, so doubled/leading/trailing spaces in the configured value do not turn
    // into empty command-line arguments. It returns null only for a null input.
    String[] extraArgs = StringUtils.split(GrobidProperties.getDeLFTTrainArgs());
    if (extraArgs != null) {
        command.addAll(Arrays.asList(extraArgs));
    }

    return command;
}

/**
* Train with an external process rather than with JNI, this approach appears to be more stable for the
Expand All @@ -261,19 +291,8 @@ public void run() {
public static void train(String modelName, File trainingData, File outputModel, String architecture) {
try {
LOGGER.info("Train DeLFT model " + modelName + "...");
List<String> command = Arrays.asList("python3",
"grobidTagger.py",
modelName,
"train",
"--input", trainingData.getAbsolutePath(),
"--output", GrobidProperties.getInstance().getModelPath().getAbsolutePath());
if (architecture != null) {
command.add("--architecture");
command.add(architecture);
}
if (GrobidProperties.getInstance().useELMo() && modelName.toLowerCase().indexOf("bert") == -1) {
command.add("--use-ELMo");
}
List<String> command = getTrainCommand(modelName, trainingData, architecture);
LOGGER.info("Running: {}", command);

ProcessBuilder pb = new ProcessBuilder(command);
File delftPath = new File(GrobidProperties.getInstance().getDeLFTFilePath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ public static boolean isDeLFTRedirectOutput() {
);
}

/**
 * Looks up the DeLFT training module setting
 * ({@link GrobidPropertyKeys#PROP_GROBID_DELFT_TRAIN_MODULE}), passing an empty string as the
 * default value for the lookup.
 *
 * @return the configured training module, or the empty-string default
 */
public static String getDeLFTTrainModule() {
    final String defaultModule = "";
    return getPropertyValue(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, defaultModule);
}

/**
 * Looks up the extra DeLFT training arguments setting
 * ({@link GrobidPropertyKeys#PROP_GROBID_DELFT_TRAIN_ARGS}), passing an empty string as the
 * default value for the lookup.
 *
 * @return the configured extra training arguments as a single string, or the empty-string default
 */
public static String getDeLFTTrainArgs() {
    final String defaultArgs = "";
    return getPropertyValue(
        GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS,
        defaultArgs
    );
}

/**
 * Looks up the value stored under {@link GrobidPropertyKeys#PROP_GLUTTON_HOST}.
 *
 * @return whatever the property lookup yields for the glutton host key
 */
public static String getGluttonHost() {
    final String gluttonHost = getPropertyValue(GrobidPropertyKeys.PROP_GLUTTON_HOST);
    return gluttonHost;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public interface GrobidPropertyKeys {
String PROP_GROBID_DELFT_REDIRECT_OUTPUT = "grobid.delft.redirect_output";
String PROP_GROBID_DELFT_ELMO = "grobid.delft.useELMo";
String PROP_DELFT_ARCHITECTURE = "grobid.delft.architecture";
// Key for the Python module (script) launched to train a DeLFT model; when the value is empty
// the training command falls back to "grobidTagger.py".
String PROP_GROBID_DELFT_TRAIN_MODULE = "grobid.delft.train.module";
// Key for extra, space-separated arguments appended to the end of the DeLFT training command.
String PROP_GROBID_DELFT_TRAIN_ARGS = "grobid.delft.train.args";
String PROP_USE_LANG_ID = "grobid.use_language_id";
String PROP_LANG_DETECTOR_FACTORY = "grobid.language_detector_factory";
String PROP_SENTENCE_DETECTOR_FACTORY = "grobid.sentence_detector_factory";
Expand Down
97 changes: 97 additions & 0 deletions grobid-core/src/test/java/org/grobid/core/jni/DeLFTModelTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package org.grobid.core.jni;

import java.io.File;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertThat;

import org.grobid.core.utilities.GrobidProperties;
import org.grobid.core.utilities.GrobidPropertyKeys;


/**
 * Unit tests for {@link DeLFTModel#getTrainCommand(String, File, String)}.
 *
 * <p>The tests drive the command builder by writing directly into the shared
 * {@link GrobidProperties} store. Because that store is global, every test snapshots the keys
 * it touches before running and restores them afterwards so that no configuration leaks into
 * other test classes.
 */
public class DeLFTModelTest {
    /** Property keys modified by this test class; snapshotted and restored around each test. */
    private static final String[] TOUCHED_KEYS = {
        GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO,
        GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE,
        GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS
    };

    /** Training data file handed to the command builder; only its path is used. */
    private final File trainingData = new File("test/train.data");

    /** Values of {@link #TOUCHED_KEYS} as they were before the current test ran (null = absent). */
    private final Object[] savedValues = new Object[TOUCHED_KEYS.length];

    /** Snapshots the touched properties, then puts them into a known baseline state. */
    @Before
    public void setUp() {
        GrobidProperties.getInstance();
        for (int i = 0; i < TOUCHED_KEYS.length; i++) {
            savedValues[i] = GrobidProperties.getProps().get(TOUCHED_KEYS[i]);
        }
        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO, "false");
        GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE);
        GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS);
    }

    /** Restores the touched properties so later tests see the configuration they started with. */
    @After
    public void tearDown() {
        for (int i = 0; i < TOUCHED_KEYS.length; i++) {
            if (savedValues[i] == null) {
                GrobidProperties.getProps().remove(TOUCHED_KEYS[i]);
            } else {
                GrobidProperties.getProps().put(TOUCHED_KEYS[i], savedValues[i]);
            }
        }
    }

    /** With baseline settings, only the fixed part of the command is produced. */
    @Test
    public void testShouldBuildTrainCommand() {
        assertThat(
            DeLFTModel.getTrainCommand("model1", trainingData, null),
            contains(
                "python3", "grobidTagger.py", "model1", "train",
                "--input", this.trainingData.getAbsolutePath(),
                "--output", GrobidProperties.getModelPath().getAbsolutePath()
            )
        );
    }

    /** Enabling ELMo appends the {@code --use-ELMo} flag. */
    @Test
    public void testShouldAddUseELMO() {
        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_ELMO, "true");
        assertThat(
            DeLFTModel.getTrainCommand("model1", trainingData, null),
            contains(
                "python3", "grobidTagger.py", "model1", "train",
                "--input", this.trainingData.getAbsolutePath(),
                "--output", GrobidProperties.getModelPath().getAbsolutePath(),
                "--use-ELMo"
            )
        );
    }

    /** A configured train module replaces the default {@code grobidTagger.py}. */
    @Test
    public void testShouldUseCustomTrainModule() {
        GrobidProperties.getProps().put(
            GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, "module1.py"
        );
        assertThat(
            DeLFTModel.getTrainCommand("model1", trainingData, null),
            contains(
                "python3", "module1.py", "model1", "train",
                "--input", this.trainingData.getAbsolutePath(),
                "--output", GrobidProperties.getModelPath().getAbsolutePath()
            )
        );
    }

    /** A single configured extra argument is appended at the end of the command. */
    @Test
    public void testShouldAddSingleCustomTrainArg() {
        GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "arg1");
        assertThat(
            DeLFTModel.getTrainCommand("model1", trainingData, null),
            contains(
                "python3", "grobidTagger.py", "model1", "train",
                "--input", this.trainingData.getAbsolutePath(),
                "--output", GrobidProperties.getModelPath().getAbsolutePath(),
                "arg1"
            )
        );
    }

    /** Space-separated extra arguments are appended as individual command elements. */
    @Test
    public void testShouldAddMultipleCustomTrainArg() {
        GrobidProperties.getProps().put(
            GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "arg1 arg2"
        );
        assertThat(
            DeLFTModel.getTrainCommand("model1", trainingData, null),
            contains(
                "python3", "grobidTagger.py", "model1", "train",
                "--input", this.trainingData.getAbsolutePath(),
                "--output", GrobidProperties.getModelPath().getAbsolutePath(),
                "arg1",
                "arg2"
            )
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,32 @@ public void testIsDeLFTRedirectOutputTrueIfSet() throws IOException {
assertTrue(GrobidProperties.isDeLFTRedirectOutput());
}

/** With no train module configured, the getter yields an empty string. */
@Test
public void testShouldReturnEmptyTrainModuleByDefault() {
    GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE);
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes a failure
    // message report the right value as "expected".
    assertEquals("", GrobidProperties.getDeLFTTrainModule());
}

/** A configured train module is returned unchanged by the getter. */
@Test
public void testShouldReturnConfiguredModule() {
    GrobidProperties.getProps().put(
        GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_MODULE, "module1"
    );
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes a failure
    // message report the right value as "expected".
    assertEquals("module1", GrobidProperties.getDeLFTTrainModule());
}

/** With no extra train arguments configured, the getter yields an empty string. */
@Test
public void testShouldReturnEmptyTrainArgsByDefault() {
    GrobidProperties.getProps().remove(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS);
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes a failure
    // message report the right value as "expected".
    assertEquals("", GrobidProperties.getDeLFTTrainArgs());
}

/** Configured extra train arguments are returned unchanged by the getter. */
@Test
public void testShouldReturnConfiguredTrainArgs() {
    GrobidProperties.getProps().put(GrobidPropertyKeys.PROP_GROBID_DELFT_TRAIN_ARGS, "args");
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes a failure
    // message report the right value as "expected".
    assertEquals("args", GrobidProperties.getDeLFTTrainArgs());
}

/*@Test(expected = GrobidPropertyException.class)
public void testCheckPropertiesException_shouldThrowException() {
GrobidProperties.getProps().put(
Expand Down