Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions mpqp/execution/connection/aws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,17 @@ def get_aws_braket_account_info() -> str:
return result


def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevice":
def get_braket_device(
device: AWSDevice,
is_noisy: bool = False,
is_gate_model: bool = True,
) -> "BraketDevice":
"""Returns the AwsDevice device associate with the AWSDevice in parameter.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc for arguments

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


Args:
device: AWSDevice element describing which remote/local AwsDevice we want.
is_noisy: If the expected device is noisy or not.
is_gate_model: If the expected device is gate-model or not.

Raises:
AWSBraketRemoteExecutionError: If the device or the region could not be
Expand All @@ -378,26 +383,35 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic
"""
from braket.devices import LocalSimulator

from mpqp.tools.errors import (
AWSBraketRemoteExecutionError,
DeviceJobIncompatibleError,
)

if not device.is_remote():
if is_noisy:
return LocalSimulator("braket_dm")
else:
return LocalSimulator()

import pkg_resources
from botocore.exceptions import NoRegionError
from braket.aws import AwsDevice, AwsSession

import mpqp

try:
import boto3

braket_client = boto3.client("braket", region_name=device.get_region())
aws_session = AwsSession(braket_client=braket_client)
mpqp_version = pkg_resources.get_distribution("mpqp").version[:3]

mpqp_version = getattr(mpqp, "__version__", "0.0.0+unknown")

aws_session.add_braket_user_agent(
user_agent="APN/1.0 ColibriTD/1.0 MPQP/" + mpqp_version
)
return AwsDevice(device.get_arn(), aws_session=aws_session)
braket_device = AwsDevice(device.get_arn(), aws_session=aws_session)

except ValueError as ve:
raise AWSBraketRemoteExecutionError(
"Failed to retrieve remote AWS device. Please check the arn, or if the "
Expand All @@ -410,6 +424,23 @@ def get_braket_device(device: AWSDevice, is_noisy: bool = False) -> "BraketDevic
"\nTrace: " + str(err)
)

if is_gate_model:
actions = getattr(getattr(braket_device, "properties", None), "action", None)
if actions is not None:
supported = [getattr(k, "value", str(k)) for k in actions.keys()]
supports_gate_model = any(
("openqasm" in action.lower()) or ("jaqcd" in action.lower())
for action in supported
)
if not supports_gate_model:
raise DeviceJobIncompatibleError(
f"{device.name} does not support gate-model workloads. "
f"Supported Braket action types: {supported}. "
"This is an AHS device, which cannot run MPQP QCircuit."
)

return braket_device


def get_all_task_ids() -> list[str]:
"""Retrieves all the task ids of this account/group from AWS.
Expand Down
4 changes: 4 additions & 0 deletions mpqp/execution/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def __init__(
while before it is set to the right value (For instance, a job
submission can require handshake protocols to conclude before
attributing an id to the job)."""
self.status_message: Optional[str] = None
"""Optional message associated with the current job status, especially
for execution errors."""

@property
def measure(self) -> Optional[Measure]:
Expand Down Expand Up @@ -188,6 +191,7 @@ def to_dict(self):
"measure": self.measure,
"id": self.id,
"status": self.status,
"status_message": self.status_message,
}

@staticmethod
Expand Down
40 changes: 31 additions & 9 deletions mpqp/execution/providers/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from mpqp.execution.job import Job, JobStatus, JobType
from mpqp.execution.result import Result, Sample, StateVector
from mpqp.noise.noise_model import NoiseModel
from mpqp.tools.errors import AWSBraketRemoteExecutionError, DeviceJobIncompatibleError
from mpqp.tools.errors import (
AWSBraketRemoteExecutionError,
DeviceJobIncompatibleError,
DeviceJobIncompatibleWarning,
)

if TYPE_CHECKING:
from braket.circuits import Circuit
Expand Down Expand Up @@ -109,16 +113,33 @@ def run_braket(job: Job) -> Result:
f"{job.device} instead"
)

import warnings

from braket.tasks import GateModelQuantumTaskResult

if isinstance(job.measure, ExpectationMeasure):
return run_braket_observable(job)
_, task = submit_job_braket(job)
res = task.result()
if TYPE_CHECKING:
assert isinstance(res, GateModelQuantumTaskResult)
try:
if isinstance(job.measure, ExpectationMeasure):
return run_braket_observable(job)

_, task = submit_job_braket(job)
res = task.result()
if TYPE_CHECKING:
assert isinstance(res, GateModelQuantumTaskResult)

return extract_result(res, job, job.device)
return extract_result(res, job, job.device)

except DeviceJobIncompatibleError as e:
warnings.warn(str(e), DeviceJobIncompatibleWarning, stacklevel=1)

job.status = JobStatus.ERROR
job.status_message = "Job execution failed. See warning for details."

return Result(
job,
data=None,
errors=None,
shots=0,
)


def run_braket_observable(job: Job):
Expand Down Expand Up @@ -151,6 +172,7 @@ def run_braket_observable(job: Job):
job.device,
is_noisy=bool(job.circuit.noises),
)

if job.measure is None:
raise NotImplementedError("job.measure is None")
assert isinstance(job.measure, ExpectationMeasure)
Expand Down Expand Up @@ -270,7 +292,7 @@ def run_braket_observable(job: Job):
)

if braket_sum is not None:
from braket.program_sets import ProgramSet, CircuitBinding
from braket.program_sets import CircuitBinding, ProgramSet
from braket.tasks.program_set_quantum_task_result import (
ProgramSetQuantumTaskResult,
)
Expand Down
12 changes: 10 additions & 2 deletions mpqp/execution/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy.typing as npt

from mpqp.core.instruction.measurement.basis_measure import BasisMeasure
from mpqp.execution import Job, JobType
from mpqp.execution import Job, JobStatus, JobType
from mpqp.execution.devices import AvailableDevice
from mpqp.tools.display import clean_1D_array, clean_number_repr
from mpqp.tools.errors import ResultAttributeError
Expand Down Expand Up @@ -288,7 +288,7 @@ class Result:
def __init__(
self,
job: Job,
data: float | dict["str", float] | StateVector | list[Sample],
data: float | dict["str", float] | StateVector | list[Sample] | None,
errors: Optional[float | dict[Any, Any]] = None,
shots: int = 0,
):
Expand All @@ -305,6 +305,11 @@ def __init__(
"""See parameter description."""
self._data = data

if data is None:
if job.status != JobStatus.ERROR:
raise TypeError("Result data cannot be None unless job.status == ERROR")
return

# depending on the type of job, fills the result info from the data in parameter
if job.job_type == JobType.OBSERVABLE:
if not isinstance(data, float) and not isinstance(data, dict):
Expand Down Expand Up @@ -458,6 +463,9 @@ def __str__(self):
label = "" if self.job.circuit.label is None else self.job.circuit.label + ", "
header = f"Result: {label}{type(self.device).__name__}, {self.device.name}"

if self.job.status == JobStatus.ERROR:
return f"{header}\n Status: ERROR\n Message: {self.job.status_message}"

if self.job.job_type == JobType.SAMPLE:
measures = self.job.circuit.measurements
if not len(measures) == 1:
Expand Down
4 changes: 4 additions & 0 deletions mpqp/tools/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class DeviceJobIncompatibleError(ValueError):
for the selected device (for example SAMPLE job on a statevector simulator)."""


class DeviceJobIncompatibleWarning(UserWarning):
"""A warning is issued when a job is not compatible with the selected device."""


class RemoteExecutionError(ConnectionError):
"""Raised when an error occurred during a remote connection, submission or
execution."""
Expand Down
Loading