Skip to content

Commit

Permalink
ACMEClient: Remove locks and hide internal variables
Browse files Browse the repository at this point in the history
  • Loading branch information
dvolodin7 committed Nov 15, 2023
1 parent c20a125 commit 9c259f0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 66 deletions.
99 changes: 48 additions & 51 deletions src/gufo/acme/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,14 @@ def __init__(
timeout: Optional[float] = None,
user_agent: Optional[str] = None,
) -> None:
self.directory_url = directory_url
self.directory_lock = asyncio.Lock()
self.directory: Optional[ACMEDirectory] = None
self.key = key
self.alg = alg
self.account_url = account_url
self.nonces: Set[bytes] = set()
self.nonce_lock = asyncio.Lock()
self.timeout = timeout or self.DEFAULT_TIMEOUT
self.user_agent = user_agent or f"Gufo ACME/{__version__}"
self._directory_url = directory_url
self._directory: Optional[ACMEDirectory] = None
self._key = key
self._alg = alg
self._account_url = account_url
self._nonces: Set[bytes] = set()
self._timeout = timeout or self.DEFAULT_TIMEOUT
self._user_agent = user_agent or f"Gufo ACME/{__version__}"

async def __aenter__(self: "ACMEClient") -> "ACMEClient":
"""
Expand Down Expand Up @@ -167,7 +165,7 @@ def is_bound(self: "ACMEClient") -> bool:
True - if the client is bound to account,
False - otherwise.
"""
return self.account_url is not None
return self._account_url is not None

def _check_bound(self: "ACMEClient") -> None:
"""
Expand Down Expand Up @@ -200,7 +198,7 @@ def _get_client(self: "ACMEClient") -> httpx.AsyncClient:
Async HTTP client instance.
"""
return httpx.AsyncClient(
http2=True, headers={"User-Agent": self.user_agent}
http2=True, headers={"User-Agent": self._user_agent}
)

@staticmethod
Expand All @@ -223,27 +221,26 @@ async def _get_directory(self: "ACMEClient") -> ACMEDirectory:
Raises:
ACMEError: In case of the errors.
"""
async with self.directory_lock:
if self.directory is not None:
return self.directory
async with self._get_client() as client:
logger.warning(
"Fetching ACME directory from %s", self.directory_url
)
try:
r = await self._wait_for(
client.get(self.directory_url), self.timeout
)
except httpx.HTTPError as e:
raise ACMEConnectError from e
self._check_response(r)
data = r.json()
self.directory = ACMEDirectory(
new_account=data["newAccount"],
new_nonce=data.get("newNonce"),
new_order=data["newOrder"],
if self._directory is not None:
return self._directory
async with self._get_client() as client:
logger.warning(
"Fetching ACME directory from %s", self._directory_url
)
return self.directory
try:
r = await self._wait_for(
client.get(self._directory_url), self._timeout
)
except httpx.HTTPError as e:
raise ACMEConnectError from e
self._check_response(r)
data = r.json()
self._directory = ACMEDirectory(
new_account=data["newAccount"],
new_nonce=data.get("newNonce"),
new_order=data["newOrder"],
)
return self._directory

@staticmethod
def _email_to_contacts(email: Union[str, Iterable[str]]) -> List[str]:
Expand Down Expand Up @@ -313,8 +310,8 @@ async def new_account(
"contact": contacts,
},
)
self.account_url = resp.headers["Location"]
return self.account_url
self._account_url = resp.headers["Location"]
return self._account_url

async def deactivate_account(self: "ACMEClient") -> None:
"""
Expand Down Expand Up @@ -345,16 +342,16 @@ async def deactivate_account(self: "ACMEClient") -> None:
ACMEError: In case of the errors.
ACMENotRegistred: If the client is not bound to account.
"""
logger.warning("Deactivating account: %s", self.account_url)
logger.warning("Deactivating account: %s", self._account_url)
# Check account is really bound
self._check_bound()
# Send deactivation request
await self._post(
self.account_url, # type: ignore
self._account_url, # type: ignore
{"status": "deactivated"},
)
# Unbind client
self.account_url = None
self._account_url = None

@staticmethod
def _domain_to_identifiers(
Expand Down Expand Up @@ -686,7 +683,7 @@ async def _head(self: "ACMEClient", url: str) -> httpx.Response:
try:
r = await self._wait_for(
client.head(url),
self.timeout,
self._timeout,
)
except httpx.HTTPError as e:
raise ACMEConnectError from e
Expand Down Expand Up @@ -758,7 +755,7 @@ async def _post_once(
"Content-Type": self.JOSE_CONTENT_TYPE,
},
),
self.timeout,
self._timeout,
)
except httpx.HTTPError as e:
raise ACMEConnectError from e
Expand All @@ -780,13 +777,13 @@ async def _get_nonce(self: "ACMEClient", url: str) -> bytes:
Returns:
nonce value as bytes.
"""
if not self.nonces:
if not self._nonces:
d = await self._get_directory()
nonce_url = url if d.new_nonce is None else d.new_nonce
logger.warning("Fetching nonce from %s", nonce_url)
resp = await self._head(nonce_url)
self._check_response(resp)
return self.nonces.pop()
return self._nonces.pop()

def _nonce_from_response(self: "ACMEClient", resp: httpx.Response) -> None:
"""
Expand All @@ -806,10 +803,10 @@ def _nonce_from_response(self: "ACMEClient", resp: httpx.Response) -> None:
try:
logger.warning("Registering new nonce %s", nonce)
b_nonce = decode_b64jose(nonce)
if b_nonce in self.nonces:
if b_nonce in self._nonces:
msg = "Duplicated nonce"
raise ACMEError(msg)
self.nonces.add(b_nonce)
self._nonces.add(b_nonce)
except DeserializationError as e:
logger.error("Bad nonce: %s", e)
raise ACMEBadNonceError from e
Expand All @@ -834,11 +831,11 @@ def _to_jws(
"""
return AcmeJWS.sign(
json.dumps(data, indent=2).encode() if data is not None else b"",
alg=self.alg,
alg=self._alg,
nonce=nonce,
url=url,
key=self.key,
kid=self.account_url,
key=self._key,
kid=self._account_url,
).json_dumps(indent=2)

@staticmethod
Expand Down Expand Up @@ -1070,7 +1067,7 @@ def get_key_authorization(
[
challenge.token,
".",
encode_b64jose(self.key.thumbprint(hash_function=SHA256)),
encode_b64jose(self._key.thumbprint(hash_function=SHA256)),
]
).encode()

Expand Down Expand Up @@ -1142,11 +1139,11 @@ def get_state(self: "ACMEClient") -> bytes:
State of the client as a stream of bytes
"""
state = {
"directory": self.directory_url,
"key": self.key.fields_to_partial_json(),
"directory": self._directory_url,
"key": self._key.fields_to_partial_json(),
}
if self.account_url is not None:
state["account_url"] = self.account_url
if self._account_url is not None:
state["account_url"] = self._account_url
return json.dumps(state, indent=2).encode()

@classmethod
Expand Down
30 changes: 15 additions & 15 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_post_timeout():
async def inner():
async with BlackholeACMEClient(DIRECTORY, key=KEY) as client:
# Avoid HTTP call in get_nonce
client.nonces.add(
client._nonces.add(
b"\xa0[\xe7\x94S\xf5\xc0\x88Q\x95\x84\xb6\x8d6\x97l"
)
with pytest.raises(ACMETimeoutError):
Expand All @@ -263,7 +263,7 @@ def test_post_error():
async def inner():
async with BuggyACMEClient(DIRECTORY, key=KEY) as client:
# Avoid HTTP call in get_nonce
client.nonces.add(
client._nonces.add(
b"\xa0[\xe7\x94S\xf5\xc0\x88Q\x95\x84\xb6\x8d6\x97l"
)
with pytest.raises(ACMEConnectError):
Expand All @@ -276,7 +276,7 @@ def test_post_retry():
async def inner():
async with BlackholeACMEClientBadNonce(DIRECTORY, key=KEY) as client:
# Avoid HTTP call in get_nonce
client.nonces.add(
client._nonces.add(
b"\xa0[\xe7\x94S\xf5\xc0\x88Q\x95\x84\xb6\x8d6\x97l"
)
with pytest.raises(ACMEBadNonceError):
Expand Down Expand Up @@ -356,33 +356,33 @@ def test_check_response_err(j, etype):

def test_nonce_from_response():
client = ACMEClient(DIRECTORY, key=KEY)
assert not client.nonces
assert not client._nonces
resp = Response(200, headers={"Replay-Nonce": "oFvnlFP1wIhRlYS2jTaXbA"})
client._nonce_from_response(resp)
assert client.nonces == {
assert client._nonces == {
b"\xa0[\xe7\x94S\xf5\xc0\x88Q\x95\x84\xb6\x8d6\x97l"
}


def test_nonce_from_response_none():
client = ACMEClient(DIRECTORY, key=KEY)
assert not client.nonces
assert not client._nonces
resp = Response(200)
client._nonce_from_response(resp)
assert not client.nonces
assert not client._nonces


def test_nonce_from_response_decode_error():
client = ACMEClient(DIRECTORY, key=KEY)
assert not client.nonces
assert not client._nonces
resp = Response(200, headers={"Replay-Nonce": "x"})
with pytest.raises(ACMEBadNonceError):
client._nonce_from_response(resp)


def test_nonce_from_response_duplicated():
client = ACMEClient(DIRECTORY, key=KEY)
assert not client.nonces
assert not client._nonces
resp = Response(200, headers={"Replay-Nonce": "oFvnlFP1wIhRlYS2jTaXbA"})
client._nonce_from_response(resp)
with pytest.raises(ACMEError):
Expand Down Expand Up @@ -586,9 +586,9 @@ def test_state1() -> None:
state = client.get_state()
client2 = ACMEClient.from_state(state)
assert client is not client2
assert client.directory == client2.directory
assert client.key == client2.key
assert client2.account_url is None
assert client._directory == client2._directory
assert client._key == client2._key
assert client2._account_url is None


def test_state2() -> None:
Expand All @@ -598,6 +598,6 @@ def test_state2() -> None:
state = client.get_state()
client2 = ACMEClient.from_state(state)
assert client is not client2
assert client.directory == client2.directory
assert client.key == client2.key
assert client.account_url == client2.account_url
assert client._directory == client2._directory
assert client._key == client2._key
assert client._account_url == client2._account_url

0 comments on commit 9c259f0

Please sign in to comment.