Skip to content

Commit

Permalink
Merge branch 'main' into ruff_test
Browse files Browse the repository at this point in the history
  • Loading branch information
math411 authored Oct 18, 2024
2 parents 93d286d + cc3698c commit ad56d22
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions src/braket/aws/aws_quantum_task_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import time
from concurrent.futures.thread import ThreadPoolExecutor
from itertools import repeat
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Union

from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.annealing import Problem
Expand All @@ -30,6 +30,14 @@
from braket.registers.qubit_set import QubitSet
from braket.tasks.quantum_task_batch import QuantumTaskBatch

if TYPE_CHECKING:
from braket.tasks.analog_hamiltonian_simulation_quantum_task_result import (
AnalogHamiltonianSimulationQuantumTaskResult,
)
from braket.tasks.annealing_quantum_task_result import AnnealingQuantumTaskResult
from braket.tasks.gate_model_quantum_task_result import GateModelQuantumTaskResult
from braket.tasks.photonic_model_quantum_task_result import PhotonicModelQuantumTaskResult


class AwsQuantumTaskBatch(QuantumTaskBatch):
"""Executes a batch of quantum tasks in parallel.
Expand Down Expand Up @@ -329,7 +337,12 @@ def results(
fail_unsuccessful: bool = False,
max_retries: int = MAX_RETRIES,
use_cached_value: bool = True,
) -> list[AwsQuantumTask]:
) -> list[
GateModelQuantumTaskResult
| AnnealingQuantumTaskResult
| PhotonicModelQuantumTaskResult
| AnalogHamiltonianSimulationQuantumTaskResult
]:
"""Retrieves the result of every quantum task in the batch.
Polling for results happens in parallel; this method returns when all quantum tasks
Expand All @@ -346,7 +359,8 @@ def results(
even when results have already been cached. Default: `True`.
Returns:
list[AwsQuantumTask]: The results of all of the quantum tasks in the batch.
list[GateModelQuantumTaskResult | AnnealingQuantumTaskResult | PhotonicModelQuantumTaskResult | AnalogHamiltonianSimulationQuantumTaskResult]: The # noqa: E501
results of all of the quantum tasks in the batch.
`FAILED`, `CANCELLED`, or timed out quantum tasks will have a result of None
"""
if not self._results or not use_cached_value:
Expand All @@ -367,7 +381,14 @@ def results(
return self._results

@staticmethod
def _retrieve_results(tasks: list[AwsQuantumTask], max_workers: int) -> list[AwsQuantumTask]:
def _retrieve_results(
tasks: list[AwsQuantumTask], max_workers: int
) -> list[
GateModelQuantumTaskResult
| AnnealingQuantumTaskResult
| PhotonicModelQuantumTaskResult
| AnalogHamiltonianSimulationQuantumTaskResult
]:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
result_futures = [executor.submit(task.result) for task in tasks]
return [future.result() for future in result_futures]
Expand Down

0 comments on commit ad56d22

Please sign in to comment.