Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ Fri 0 4
* [Cplex](https://www.ibm.com/de-de/analytics/cplex-optimizer)
* [MOSEK](https://www.mosek.com/)
* [COPT](https://www.shanshu.ai/copt)
* [cuPDLPx](https://github.com/MIT-Lu-Lab/cuPDLPx)

Note that these do have to be installed by the user separately.

Expand Down
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Upcoming Version
* Harmonize dtypes before concatenation in lp file writing to avoid dtype mismatch errors. This error occurred when creating and storing models in netcdf format using windows machines and loading and solving them on linux machines.
* Add option to use polars series as constant input
* Fix expression merge to explicitly use outer join when combining expressions with disjoint coordinates for consistent behavior across xarray versions
* Add support for GPU-accelerated solver [cuPDLPx](https://github.com/MIT-Lu-Lab/cuPDLPx)

Version 0.5.6
--------------
Expand Down
54 changes: 54 additions & 0 deletions linopy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from linopy.objective import Objective

if TYPE_CHECKING:
from cupdlpx import Model as cupdlpxModel
from highspy.highs import Highs

from linopy.model import Model
Expand Down Expand Up @@ -756,6 +757,59 @@ def to_highspy(m: Model, explicit_coordinate_names: bool = False) -> Highs:
return h


def to_cupdlpx(m: Model, explicit_coordinate_names: bool = False) -> cupdlpxModel:
"""
Export the model to cupdlpx.

This function does not write the model to intermediate files but directly
passes it to cupdlpx.

cuPDLPx does not support named variables and constraints, so names
are not tracked by this function.

Parameters
----------
m : linopy.Model

Returns
-------
model : cupdlpx.Model
"""
import cupdlpx

# build model using canonical form matrices and vectors
# see https://github.com/MIT-Lu-Lab/cuPDLPx/tree/main/python#modeling
M = m.matrices
A = M.A.tocsr() # cuDPLPx only support CSR sparse matrix format
# linopy stores constraints as Ax ?= b and keeps track of inequality
# sense in M.sense. Convert to separate lower and upper bound vectors.
l = np.where(
np.logical_or(np.equal(M.sense, ">"), np.equal(M.sense, "=")),
M.b,
-np.inf,
)
u = np.where(
np.logical_or(np.equal(M.sense, "<"), np.equal(M.sense, "=")),
M.b,
np.inf,
)

cu_model = cupdlpx.Model(
objective_vector=M.c,
constraint_matrix=A,
constraint_lower_bound=l,
constraint_upper_bound=u,
variable_lower_bound=M.lb,
variable_upper_bound=M.ub,
)

# change objective sense
if m.objective.sense == "max":
cu_model.ModelSense = cupdlpx.PDLP.MAXIMIZE

return cu_model


def to_block_files(m: Model, fn: Path) -> None:
"""
Write out the linopy model to a block structured output.
Expand Down
3 changes: 3 additions & 0 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
)
from linopy.io import (
to_block_files,
to_cupdlpx,
to_file,
to_gurobipy,
to_highspy,
Expand Down Expand Up @@ -1510,4 +1511,6 @@ def reset_solution(self) -> None:

to_highspy = to_highspy

to_cupdlpx = to_cupdlpx

to_block_files = to_block_files
255 changes: 255 additions & 0 deletions linopy/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pandas as pd
from packaging.version import parse as parse_version

import linopy.io
from linopy.constants import (
Result,
Solution,
Expand Down Expand Up @@ -59,6 +60,7 @@
"scip",
"copt",
"mindopt",
"cupdlpx",
]

FILE_IO_APIS = ["lp", "lp-polars", "mps"]
Expand Down Expand Up @@ -134,6 +136,16 @@
except coptpy.CoptError:
pass

with contextlib.suppress(ModuleNotFoundError):
import cupdlpx

try:
cupdlpx.Model(np.array([0.0]), np.array([[0.0]]), None, None)
available_solvers.append("cupdlpx")
except ImportError:
pass


quadratic_solvers = [s for s in QUADRATIC_SOLVERS if s in available_solvers]
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,6 +177,7 @@ class SolverName(enum.Enum):
COPT = "copt"
MindOpt = "mindopt"
PIPS = "pips"
cuPDLPx = "cupdlpx"


def path_to_string(path: Path) -> str:
Expand Down Expand Up @@ -2261,3 +2274,245 @@ def __init__(
super().__init__(**solver_options)
msg = "The PIPS solver interface is not yet implemented."
raise NotImplementedError(msg)


class cuPDLPx(Solver[None]):
"""
Solver subclass for the cuPDLPx solver. cuPDLPx must be installed
with working GPU support for usage. Find the installation instructions
at https://github.com/MIT-Lu-Lab/cuPDLPx.

The full list of solver options provided with the python interface
is documented at https://github.com/MIT-Lu-Lab/cuPDLPx/tree/main/python.

Some example options are:
* LogToConsole : False by default.
* TimeLimit : 3600.0 by default.
* IterationLimit : 2147483647 by default.

Attributes
----------
**solver_options
options for the given solver
"""

def __init__(
self,
**solver_options: Any,
) -> None:
super().__init__(**solver_options)

def solve_problem_from_file(
self,
problem_fn: Path,
solution_fn: Path | None = None,
log_fn: Path | None = None,
warmstart_fn: Path | None = None,
basis_fn: Path | None = None,
env: EnvType | None = None,
) -> Result:
"""
Solve a linear problem from a problem file using the solver cuPDLPx.
cuPDLPx does not currently support its own file IO, so this function
reads the problem file using linopy (only support netcf files) and
then passes the model to cuPDLPx for solving.
If the solution is feasible the function returns the
objective, solution and dual constraint variables.

Parameters
----------
problem_fn : Path
Path to the problem file.
solution_fn : Path, optional
Path to the solution file.
log_fn : Path, optional
Path to the log file.
warmstart_fn : Path, optional
Path to the warmstart file.
basis_fn : Path, optional
Path to the basis file.
env : None, optional
Environment for the solver

Returns
-------
Result
"""
logger.warning(
"cuPDLPx doesn't currently support file IO. Building model from file using linopy."
)
problem_fn_ = path_to_string(problem_fn)

if problem_fn_.endswith(".netcdf"):
model: Model = linopy.io.read_netcdf(problem_fn_)
else:
msg = "linopy currently only supports reading models from netcdf files. Try using io_api='direct' instead."
raise NotImplementedError(msg)

return self.solve_problem_from_model(
model,
solution_fn=solution_fn,
log_fn=log_fn,
warmstart_fn=warmstart_fn,
basis_fn=basis_fn,
env=env,
)

def solve_problem_from_model(
self,
model: Model,
solution_fn: Path | None = None,
log_fn: Path | None = None,
warmstart_fn: Path | None = None,
basis_fn: Path | None = None,
env: EnvType | None = None,
explicit_coordinate_names: bool = False,
) -> Result:
"""
Solve a linear problem directly from a linopy model using the solver cuPDLPx.
If the solution is feasible the function returns the
objective, solution and dual constraint variables.

Parameters
----------
model : linopy.model
Linopy model for the problem.
solution_fn : Path, optional
Path to the solution file.
log_fn : Path, optional
Path to the log file.
warmstart_fn : Path, optional
Path to the warmstart file.
basis_fn : Path, optional
Path to the basis file.
env : None, optional
Environment for the solver
explicit_coordinate_names : bool, optional
Transfer variable and constraint names to the solver (default: False)

Returns
-------
Result
"""

if model.type in ["QP", "MILP"]:
msg = "cuPDLPx does not currently support QP or MILP problems."
raise NotImplementedError(msg)

cu_model = model.to_cupdlpx()

return self._solve(
cu_model,
l_model=model,
solution_fn=solution_fn,
log_fn=log_fn,
warmstart_fn=warmstart_fn,
basis_fn=basis_fn,
io_api="direct",
sense=model.sense,
)

def _solve(
self,
cu_model: cupdlpx.Model,
l_model: Model | None = None,
solution_fn: Path | None = None,
log_fn: Path | None = None,
warmstart_fn: Path | None = None,
basis_fn: Path | None = None,
io_api: str | None = None,
sense: str | None = None,
) -> Result:
"""
Solve a linear problem from a cupdlpx.Model object.

Parameters
----------
cu_model: cupdlpx.Model
cupdlpx object.
solution_fn : Path, optional
Path to the solution file.
log_fn : Path, optional
Path to the log file.
warmstart_fn : Path, optional
Path to the warmstart file.
basis_fn : Path, optional
Path to the basis file.
model : linopy.model, optional
Linopy model for the problem.
io_api: str
io_api of the problem. For direct API from linopy model this is "direct".
sense: str
"min" or "max"

Returns
-------
Result
"""

# see https://github.com/MIT-Lu-Lab/cuPDLPx/blob/main/python/cupdlpx/PDLP.py
CONDITION_MAP: dict[int, TerminationCondition] = {
cupdlpx.PDLP.OPTIMAL: TerminationCondition.optimal,
cupdlpx.PDLP.PRIMAL_INFEASIBLE: TerminationCondition.infeasible,
cupdlpx.PDLP.DUAL_INFEASIBLE: TerminationCondition.infeasible_or_unbounded,
cupdlpx.PDLP.TIME_LIMIT: TerminationCondition.time_limit,
cupdlpx.PDLP.ITERATION_LIMIT: TerminationCondition.iteration_limit,
cupdlpx.PDLP.UNSPECIFIED: TerminationCondition.unknown,
}

self._set_solver_params(cu_model)

if warmstart_fn is not None:
# cuPDLPx supports warmstart, but there currently isn't the tooling
# to read it in from a file
raise NotImplementedError("Warmstarting not yet implemented for cuPDLPx.")
else:
cu_model.clearWarmStart()

if basis_fn is not None:
logger.warning("Basis files are not supported by cuPDLPx. Ignoring.")

# solve
cu_model.optimize()

# parse solution and output
if solution_fn is not None:
raise NotImplementedError(
"Solution file output not yet implemented for cuPDLPx."
)

termination_condition = CONDITION_MAP.get(
cu_model.StatusCode, cu_model.StatusCode
)
status = Status.from_termination_condition(termination_condition)
status.legacy_status = cu_model.Status # cuPDLPx status message

def get_solver_solution() -> Solution:
objective = cu_model.ObjVal

vlabels = None if l_model is None else l_model.matrices.vlabels
clabels = None if l_model is None else l_model.matrices.clabels

sol = pd.Series(cu_model.X, vlabels, dtype=float)
dual = pd.Series(cu_model.Pi, clabels, dtype=float)

if cu_model.ModelSense == cupdlpx.PDLP.MAXIMIZE:
dual *= -1 # flip sign of duals for max problems

return Solution(sol, dual, objective)

solution = self.safe_get_solution(status=status, func=get_solver_solution)
solution = maybe_adjust_objective_sign(solution, io_api, sense)

# see https://github.com/MIT-Lu-Lab/cuPDLPx/tree/main/python#solution-attributes
return Result(status, solution, cu_model)

def _set_solver_params(self, cu_model: cupdlpx.Model) -> None:
"""
Set solver options for cuPDLPx model.

For list of available options, see
https://github.com/MIT-Lu-Lab/cuPDLPx/tree/main/python#parameters
"""
for k, v in self.solver_options.items():
cu_model.setParam(k, v)
Loading