diff --git a/src/joserfc/_keys.py b/src/joserfc/_keys.py index 1ad7167..cbb5d26 100644 --- a/src/joserfc/_keys.py +++ b/src/joserfc/_keys.py @@ -127,16 +127,13 @@ def __eq__(self, other: t.Any) -> bool: assert isinstance(other, KeySet) return self.keys == other.keys - def as_dict(self, private: bool | None = None, **params: t.Any) -> KeySetSerialization: + def as_dict(self, private: bool = False, **params: t.Any) -> KeySetSerialization: keys: list[DictKey] = [] for key in self.keys: # trigger key to generate kid via thumbprint key.ensure_kid() - if isinstance(key, OctKey): - keys.append(key.as_dict(**params)) - else: - keys.append(key.as_dict(private=private, **params)) + keys.append(key.as_dict(private=private, **params)) return {"keys": keys} def get_by_kid(self, kid: str | None = None, parameters: KeyParameters | None = None) -> Key: diff --git a/src/joserfc/_rfc7517/models.py b/src/joserfc/_rfc7517/models.py index fd08ff7..152dc9d 100644 --- a/src/joserfc/_rfc7517/models.py +++ b/src/joserfc/_rfc7517/models.py @@ -49,7 +49,7 @@ def import_from_bytes(cls, value: bytes, password: t.Any = None) -> t.Any: def as_bytes( key: GenericKey, encoding: t.Literal["PEM", "DER"] | None = None, - private: bool | None = None, + private: bool = False, password: str | None = None, ) -> bytes: raise NotImplementedError() @@ -177,7 +177,7 @@ def thumbprint_uri(self) -> str: value = self.thumbprint() return concat_thumbprint_uri(value, self.thumbprint_digest_method) - def as_dict(self, private: bool | None = None, **params: t.Any) -> DictKey: + def as_dict(self, private: bool = False, **params: t.Any) -> DictKey: """Output this key to a JWK format (in dict). By default, it will return the ``dict_value`` of this key. @@ -190,7 +190,7 @@ def as_dict(self, private: bool | None = None, **params: t.Any) -> DictKey: raise ValueError("This key is not a private key.") data = self.dict_value.copy() - if private is not False: + if private: data.update(params) return data @@ -322,15 +322,15 @@ def raw_value(self) -> t.Union[NativePublicKey, NativePrivateKey]: def as_bytes( self, encoding: t.Literal["PEM", "DER"] | None = None, - private: bool | None = None, + private: bool = False, password: str | None = None, ) -> bytes: return self.binding.as_bytes(self, encoding, private, password) - def as_pem(self, private: bool | None = None, password: str | None = None) -> bytes: + def as_pem(self, private: bool = False, password: str | None = None) -> bytes: return self.as_bytes(private=private, password=password) - def as_der(self, private: bool | None = None, password: str | None = None) -> bytes: + def as_der(self, private: bool = False, password: str | None = None) -> bytes: return self.as_bytes(encoding="DER", private=private, password=password) diff --git a/src/joserfc/_rfc7517/pem.py b/src/joserfc/_rfc7517/pem.py index 5c94af2..6fed5e5 100644 --- a/src/joserfc/_rfc7517/pem.py +++ b/src/joserfc/_rfc7517/pem.py @@ -134,12 +134,9 @@ def import_from_bytes(cls, value: bytes, password: Any | None = None) -> Any: def as_bytes( key: GenericKey, encoding: Literal["PEM", "DER"] | None = None, - private: bool | None = False, + private: bool = False, password: Any | None = None, ) -> bytes: - if private is None: - private = key.is_private - if private: return dump_pem_key(key.private_key, encoding, private, password) else: diff --git a/src/joserfc/_rfc7518/oct_key.py b/src/joserfc/_rfc7518/oct_key.py index 24b6a9a..fc254f4 100644 --- a/src/joserfc/_rfc7518/oct_key.py +++ b/src/joserfc/_rfc7518/oct_key.py @@ -1,3 +1,4 @@ +import typing as t from typing import Any import secrets import warnings @@ -101,3 +102,6 @@ def generate_key( if auto_kid: key.ensure_kid() return key + + def as_dict(self, private: bool = False, **params: t.Any) -> DictKey: + return super().as_dict(private=True, **params) diff --git a/tests/jwk/test_ec_key.py b/tests/jwk/test_ec_key.py index 443f63e..ded584e 100644 --- a/tests/jwk/test_ec_key.py +++ b/tests/jwk/test_ec_key.py @@ -84,14 +84,14 @@ def test_import_invalid_pem_key(self): def test_output_with_password(self): key = ECKey.import_key(read_key("ec-p256-private.pem")) - pem = key.as_pem(password="secret") + pem = key.as_pem(private=True, password="secret") self.assertRaises(TypeError, ECKey.import_key, pem) key2 = ECKey.import_key(pem, password="secret") self.assertEqual(key.as_dict(), key2.as_dict()) def test_key_eq(self): key1 = self.default_key - key2 = ECKey.import_key(key1.as_dict()) + key2 = ECKey.import_key(key1.as_dict(private=True)) self.assertEqual(key1, key2) key3 = ECKey.generate_key() self.assertNotEqual(key1, key3) diff --git a/tests/jwk/test_jwk_set.py b/tests/jwk/test_jwk_set.py index 06342d8..45f9e38 100644 --- a/tests/jwk/test_jwk_set.py +++ b/tests/jwk/test_jwk_set.py @@ -19,7 +19,7 @@ def test_generate_and_import_key_set(self): # we will ensure kid when generating the key set self.assertIsNotNone(key.kid) - jwks1_data = jwks1.as_dict() + jwks1_data = jwks1.as_dict(private=True) self.assertEqual(list(jwks1_data.keys()), ["keys"]) for d1 in jwks1_data["keys"]: self.assertIn("d", d1) @@ -90,6 +90,6 @@ def test_key_eq_with_same_keys(self): def test_key_eq_with_new_keys(self): key_set1 = KeySet.generate_key_set("RSA", 2048) - key_set2 = KeySet([RSAKey.import_key(k.as_dict()) for k in key_set1]) + key_set2 = KeySet([RSAKey.import_key(k.as_dict(private=True)) for k in key_set1]) self.assertIsNot(key_set1, key_set2) self.assertEqual(key_set1, key_set2) diff --git a/tests/jwk/test_okp_key.py b/tests/jwk/test_okp_key.py index 1e15d30..3aa2785 100644 --- a/tests/jwk/test_okp_key.py +++ b/tests/jwk/test_okp_key.py @@ -46,11 +46,11 @@ def test_import_pem_key(self): private_key: OKPKey = OKPKey.import_key(private_pem) public_key: OKPKey = OKPKey.import_key(public_pem) - self.assertEqual(private_key.as_pem(), private_pem) + self.assertEqual(private_key.as_pem(private=True), private_pem) self.assertEqual(private_key.as_pem(private=False), public_pem) self.assertEqual(public_key.as_pem(), public_pem) - self.assertIn("d", private_key.as_dict()) + self.assertIn("d", private_key.as_dict(private=True)) self.assertNotIn("d", public_key.as_dict()) def test_import_invalid_pem_key(self): @@ -88,7 +88,7 @@ def test_all_as_methods(self): key: OKPKey = OKPKey.import_key(private_json) # as_dict - data = key.as_dict() + data = key.as_dict(private=True) self.assertIn("d", data) self.assertEqual(data, private_json) data = key.as_dict(private=False) @@ -96,7 +96,7 @@ def test_all_as_methods(self): self.assertEqual(data, public_json) # as_pem - data = key.as_pem() + data = key.as_pem(private=True) self.assertIn(b"PRIVATE", data) data = key.as_pem(private=False) self.assertIn(b"PUBLIC", data) @@ -107,14 +107,14 @@ def test_all_as_methods(self): def test_output_with_password(self): key = OKPKey.import_key(read_key("okp-ed25519-private.json")) - pem = key.as_pem(password="secret") + pem = key.as_pem(private=True, password="secret") self.assertRaises(TypeError, OKPKey.import_key, pem) key2 = OKPKey.import_key(pem, password="secret") self.assertEqual(key.as_pem(), key2.as_pem()) def test_key_eq(self): key1 = OKPKey.generate_key() - key2 = OKPKey.import_key(key1.as_dict()) + key2 = OKPKey.import_key(key1.as_dict(private=True)) self.assertIsNot(key1, key2) self.assertEqual(key1, key2) key3 = OKPKey.generate_key() diff --git a/tests/jwk/test_rsa_key.py b/tests/jwk/test_rsa_key.py index 8992080..b8a871c 100644 --- a/tests/jwk/test_rsa_key.py +++ b/tests/jwk/test_rsa_key.py @@ -101,16 +101,12 @@ def test_output_as_methods(self): key: RSAKey = RSAKey.import_key(private_pem) # as_dict - data = key.as_dict() - self.assertIn("d", data) data = key.as_dict(private=True) self.assertIn("d", data) data = key.as_dict(private=False) self.assertNotIn("d", data) # as_pem - data = key.as_pem() - self.assertIn(b"PRIVATE", data) data = key.as_pem(private=True) self.assertIn(b"PRIVATE", data) data = key.as_pem(private=False) @@ -172,14 +168,14 @@ def test_import_invalid_pem_key(self): def test_output_with_password(self): private_pem = read_key("rsa-openssl-private.pem") key: RSAKey = RSAKey.import_key(private_pem) - pem = key.as_pem(password="secret") + pem = key.as_pem(private=True, password="secret") self.assertRaises(TypeError, RSAKey.import_key, pem) key2 = RSAKey.import_key(pem, password="secret") self.assertEqual(key.as_dict(), key2.as_dict()) def test_key_eq(self): key1 = self.default_key - key2 = RSAKey.import_key(key1.as_dict()) + key2 = RSAKey.import_key(key1.as_dict(private=True)) self.assertIsNot(key1, key2) self.assertEqual(key1, key2) key3 = RSAKey.generate_key()