33import os
44import struct
55import unittest
6+ from pathlib import Path
7+
8+ import pytest
69from cryptography .hazmat .backends import default_backend
710from cryptography .hazmat .primitives .asymmetric import ec
811from cryptography .hazmat .primitives .serialization import Encoding , PublicFormat
912
10- from pytest import raises
11-
1213import http_ece as ece
1314from http_ece import ECEException
1415
15-
16- TEST_VECTORS = os .path .join (os .sep , ".." , "encrypt_data.json" )[1 :]
16+ TEST_VECTORS = Path (__file__ ).parent .parent / "encrypt_data.json"
1717
1818
1919def logmsg (arg ):
@@ -59,7 +59,7 @@ def setUp(self):
5959 self .m_salt = os .urandom (16 )
6060
6161 def test_derive_key_invalid_mode (self ):
62- with raises (ECEException ) as ex :
62+ with pytest . raises (ECEException , match = "unknown 'mode' specified: invalid" ) :
6363 ece .derive_key (
6464 "invalid" ,
6565 version = "aes128gcm" ,
@@ -70,10 +70,9 @@ def test_derive_key_invalid_mode(self):
7070 auth_secret = None ,
7171 keyid = "valid" ,
7272 )
73- assert ex .value .message == "unknown 'mode' specified: invalid"
7473
7574 def test_derive_key_invalid_salt (self ):
76- with raises (ECEException ) as ex :
75+ with pytest . raises (ECEException , match = "'salt' must be a 16 octet value" ) :
7776 ece .derive_key (
7877 "encrypt" ,
7978 version = "aes128gcm" ,
@@ -84,10 +83,9 @@ def test_derive_key_invalid_salt(self):
8483 auth_secret = None ,
8584 keyid = "valid" ,
8685 )
87- assert ex .value .message == "'salt' must be a 16 octet value"
8886
8987 def test_derive_key_invalid_version (self ):
90- with raises (ECEException ) as ex :
88+ with pytest . raises (ECEException , match = "Invalid version" ) :
9189 ece .derive_key (
9290 "encrypt" ,
9391 version = "invalid" ,
@@ -98,10 +96,9 @@ def test_derive_key_invalid_version(self):
9896 auth_secret = None ,
9997 keyid = "valid" ,
10098 )
101- assert ex .value .message == "Invalid version"
10299
103100 def test_derive_key_no_private_key (self ):
104- with raises (ECEException ) as ex :
101+ with pytest . raises (ECEException , match = "DH requires a private_key" ) :
105102 ece .derive_key (
106103 "encrypt" ,
107104 version = "aes128gcm" ,
@@ -112,10 +109,9 @@ def test_derive_key_no_private_key(self):
112109 auth_secret = None ,
113110 keyid = "valid" ,
114111 )
115- assert ex .value .message == "DH requires a private_key"
116112
117113 def test_derive_key_no_secret (self ):
118- with raises (ECEException ) as ex :
114+ with pytest . raises (ECEException , match = "unable to determine the secret" ) :
119115 ece .derive_key (
120116 "encrypt" ,
121117 version = "aes128gcm" ,
@@ -126,12 +122,10 @@ def test_derive_key_no_secret(self):
126122 auth_secret = None ,
127123 keyid = "valid" ,
128124 )
129- assert ex .value .message == "unable to determine the secret"
130125
131126 def test_iv_bad_counter (self ):
132- with raises (ECEException ) as ex :
127+ with pytest . raises (ECEException , match = "Counter too big" ) :
133128 ece .iv (os .urandom (8 ), pow (2 , 64 ) + 1 )
134- assert ex .value .message == "Counter too big"
135129
136130
137131class TestEceChecking (unittest .TestCase ):
@@ -144,76 +138,69 @@ def setUp(self):
144138 self .m_header += struct .pack ("!L" , 32 ) + b"\0 "
145139
146140 def test_encrypt_small_rs (self ):
147- with raises (ECEException ) as ex :
141+ with pytest . raises (ECEException , match = "Record size too small" ) :
148142 ece .encrypt (
149143 self .m_input ,
150144 version = "aes128gcm" ,
151145 key = self .m_key ,
152146 rs = 1 ,
153147 )
154- assert ex .value .message == "Record size too small"
155148
156149 def test_decrypt_small_rs (self ):
157150 header = os .urandom (16 ) + struct .pack ("!L" , 2 ) + b"\0 "
158- with raises (ECEException ) as ex :
151+ with pytest . raises (ECEException , match = "Record size too small" ) :
159152 ece .decrypt (
160153 header + self .m_input ,
161154 version = "aes128gcm" ,
162155 key = self .m_key ,
163156 rs = 1 ,
164157 )
165- assert ex .value .message == "Record size too small"
166158
167159 def test_encrypt_bad_version (self ):
168- with raises (ECEException ) as ex :
160+ with pytest . raises (ECEException , match = "Invalid version" ) :
169161 ece .encrypt (
170162 self .m_input ,
171163 version = "bogus" ,
172164 key = self .m_key ,
173165 )
174- assert ex .value .message == "Invalid version"
175166
176167 def test_decrypt_bad_version (self ):
177- with raises (ECEException ) as ex :
168+ with pytest . raises (ECEException , match = "Invalid version" ) :
178169 ece .decrypt (
179170 self .m_input ,
180171 version = "bogus" ,
181172 key = self .m_key ,
182173 )
183- assert ex .value .message == "Invalid version"
184174
185175 def test_decrypt_bad_header (self ):
186- with raises (ECEException ) as ex :
176+ with pytest . raises (ECEException , match = "Could not parse the content header" ) :
187177 ece .decrypt (
188178 os .urandom (4 ),
189179 version = "aes128gcm" ,
190180 key = self .m_key ,
191181 )
192- assert ex .value .message == "Could not parse the content header"
193182
194183 def test_encrypt_long_keyid (self ):
195- with raises (ECEException ) as ex :
184+ with pytest . raises (ECEException , match = "keyid is too long" ) :
196185 ece .encrypt (
197186 self .m_input ,
198187 version = "aes128gcm" ,
199188 key = self .m_key ,
200189 keyid = b64e (os .urandom (192 )), # 256 bytes
201190 )
202- assert ex .value .message == "keyid is too long"
203191
204192 def test_overlong_padding (self ):
205- with raises (ECEException ) as ex :
193+ with pytest . raises (ECEException , match = "all zero record plaintext" ) :
206194 ece .decrypt (
207195 self .m_header + b"\xbb \xc7 \xb9 ev\x0b \xf0 f+\x93 \xf4 "
208196 b"\xe5 \xd6 \x94 \xb7 e\xf0 \xcd \x15 \x9b (\x01 \xa5 " ,
209197 version = "aes128gcm" ,
210198 key = b"d\xc7 \x0e d\xa7 %U\x14 Q\xf2 \x08 \xdf \xba \xa0 \xb9 r" ,
211199 keyid = b64e (os .urandom (192 )), # 256 bytes
212200 )
213- assert ex .value .message == "all zero record plaintext"
214201
215202 def test_bad_early_delimiter (self ):
216- with raises (ECEException ) as ex :
203+ with pytest . raises (ECEException , match = "record delimiter != 1" ) :
217204 ece .decrypt (
218205 self .m_header + b"\xb9 \xc7 \xb9 ev\x0b \xf0 \x9e B\xb1 \x08 C8u"
219206 b"\xa3 \x06 \xc9 x\x06 \n \xfc |}\xe9 R\x85 \x91 "
@@ -224,29 +211,26 @@ def test_bad_early_delimiter(self):
224211 key = b"d\xc7 \x0e d\xa7 %U\x14 Q\xf2 \x08 \xdf \xba \xa0 \xb9 r" ,
225212 keyid = b64e (os .urandom (192 )), # 256 bytes
226213 )
227- assert ex .value .message == "record delimiter != 1"
228214
229215 def test_bad_final_delimiter (self ):
230- with raises (ECEException ) as ex :
216+ with pytest . raises (ECEException , match = "last record delimiter != 2" ) :
231217 ece .decrypt (
232218 self .m_header + b"\xba \xc7 \xb9 ev\x0b \xf0 \x9e B\xb1 \x08 Ji"
233219 b"\xe4 P\x1b \x8d I\xdb \xc6 y#MG\xc2 W\x16 " ,
234220 version = "aes128gcm" ,
235221 key = b"d\xc7 \x0e d\xa7 %U\x14 Q\xf2 \x08 \xdf \xba \xa0 \xb9 r" ,
236222 keyid = b64e (os .urandom (192 )), # 256 bytes
237223 )
238- assert ex .value .message == "last record delimiter != 2"
239224
240225 def test_damage (self ):
241- with raises (ECEException ) as ex :
226+ with pytest . raises (ECEException , match = r"Decryption error: InvalidTag()" ) :
242227 ece .decrypt (
243228 self .m_header + b"\xbb \xc6 \xb1 \x1d F:~\x0f \x07 +\xbe \xaa D"
244229 b"\xe0 \xd6 .K\xe5 \xf9 ]%\xe3 \x86 q\xe0 }" ,
245230 version = "aes128gcm" ,
246231 key = b"d\xc7 \x0e d\xa7 %U\x14 Q\xf2 \x08 \xdf \xba \xa0 \xb9 r" ,
247232 keyid = b64e (os .urandom (192 )), # 256 bytes
248233 )
249- assert ex .value .message == "Decryption error: InvalidTag()"
250234
251235
252236class TestEceIntegration (unittest .TestCase ):
@@ -350,9 +334,8 @@ def detect_truncation(self, version):
350334 chunk = encrypted [0 : 21 + rs ]
351335 else :
352336 chunk = encrypted [0 : rs + 16 ]
353- with raises (ECEException ) as ex :
337+ with pytest . raises (ECEException , match = "Message truncated" ) :
354338 ece .decrypt (chunk , salt = salt , key = key , rs = rs , version = version )
355- assert ex .value .message == "Message truncated"
356339
357340 def use_dh (self , version ):
358341 def pubbytes (k ):
@@ -427,9 +410,9 @@ class TestNode(unittest.TestCase):
427410 """Testing using data from the node.js version."""
428411
429412 def setUp (self ):
430- if not os . path . exists (TEST_VECTORS ):
431- self .skipTest ("No %s file found" % TEST_VECTORS )
432- f = open (TEST_VECTORS , "r" )
413+ if not Path ( TEST_VECTORS ). exists ():
414+ self .skipTest (f "No { TEST_VECTORS } file found" )
415+ f = Path (TEST_VECTORS ). open ( )
433416 self .legacy_data = json .loads (f .read ())
434417 f .close ()
435418
@@ -446,7 +429,7 @@ def _run(self, mode):
446429 outp = "input"
447430
448431 for data in self .legacy_data :
449- logmsg ("%s: %s" % (mode , data ["test" ]))
432+ logmsg ("{}: {}" . format (mode , data ["test" ]))
450433 p = data ["params" ][mode ]
451434
452435 if "pad" in p and mode == "encrypt" :
0 commit comments