|
13 | 13 | *)
|
14 | 14 |
|
15 | 15 | module Pam = struct
|
| 16 | + include Pam |
| 17 | + |
16 | 18 | let unsafe_crypt_r = Pam.unsafe_crypt_r [@@alert "-unsafe"]
|
17 | 19 | (* Suppress the alert the purpose of testing. *)
|
18 | 20 | end
|
@@ -124,8 +126,6 @@ let test_crypt_r_many_threads () =
|
124 | 126 | let start = now () in
|
125 | 127 | while now () -. start < 0.2 do
|
126 | 128 | let actual = unsafe_crypt_r ~key ~setting in
|
127 |
| - Printf.printf "thread %d computed %s\n" i actual ; |
128 |
| - flush stdout ; |
129 | 129 | if actual <> expectation then
|
130 | 130 | failwith (Printf.sprintf "%s <> %s" actual expectation)
|
131 | 131 | done
|
@@ -210,12 +210,65 @@ let test_c_truncation () =
|
210 | 210 | if hash <> hash' then
|
211 | 211 | failwith "Expected truncation using C-style null termination failed"
|
212 | 212 |
|
| 213 | +(* Make following tests fail if the safe API fails to return a valid result. *) |
| 214 | +let crypt ~algo ~key ~salt = |
| 215 | + let open struct exception CryptException of Pam.crypt_err end in |
| 216 | + match Pam.crypt ~algo ~key ~salt with |
| 217 | + | Ok hash -> |
| 218 | + hash |
| 219 | + | Error e -> |
| 220 | + raise (CryptException e) |
| 221 | + |
| 222 | +(* Test trivial correspondence between safe API invocation and unsafe calls. *) |
| 223 | +let test_api_correspondence () = |
| 224 | + let cases = |
| 225 | + [ |
| 226 | + ("$5$salt123$", Pam.SHA256, "salt123") |
| 227 | + ; ("$6$salt456$", Pam.SHA512, "salt456") |
| 228 | + ] |
| 229 | + in |
| 230 | + let go (setting, algo, salt) = |
| 231 | + let key = "password" in |
| 232 | + let h = unsafe_crypt_r ~key ~setting in |
| 233 | + let h' = crypt ~algo ~key ~salt in |
| 234 | + if h <> h' then |
| 235 | + failwith |
| 236 | + "Hashes differ between invocations of safe and unsafe crypt_r APIs" |
| 237 | + in |
| 238 | + List.iter go cases |
| 239 | + |
| 240 | +(** Ensure the safe API fails in the way you expect. *) |
| 241 | +let test_safe_failures () = |
| 242 | + let key = "password" in |
| 243 | + let cases = |
| 244 | + [ |
| 245 | + (* Salt exceeding maximum length. *) |
| 246 | + ( (fun () -> |
| 247 | + Pam.crypt ~algo:SHA256 ~key ~salt:"asaltthatexceedsthemaximumlength" |
| 248 | + ) |
| 249 | + , Pam.SaltTooLong |
| 250 | + ) |
| 251 | + ] |
| 252 | + in |
| 253 | + let test (case, expected_error) = |
| 254 | + match case () with |
| 255 | + | Ok _ -> |
| 256 | + failwith "Expected crypt error" |
| 257 | + | Error e when e <> expected_error -> |
| 258 | + failwith "Actual crypt error does not match expectation" |
| 259 | + | Error _ -> |
| 260 | + () |
| 261 | + in |
| 262 | + List.iter test cases |
| 263 | + |
213 | 264 | let tests () =
|
214 | 265 | [
|
215 | 266 | ("Valid salts", `Quick, test_valid_salts)
|
216 | 267 | ; ("Invalid salts", `Quick, test_invalid_salts)
|
217 | 268 | ; ("Implicit salt truncation", `Quick, test_salt_truncation)
|
218 | 269 | ; ("Increasing string length", `Quick, test_increasing_length)
|
219 | 270 | ; ("C-style termination", `Quick, test_c_truncation)
|
| 271 | + ; ("Safe and unsafe API", `Quick, test_api_correspondence) |
| 272 | + ; ("Safe API error reporting", `Quick, test_safe_failures) |
220 | 273 | ; ("Multiple threads", `Quick, test_crypt_r_many_threads)
|
221 | 274 | ]
|
0 commit comments