diff --git a/src/internal/keys.rs b/src/internal/keys.rs index 989907d..d4f3e5c 100644 --- a/src/internal/keys.rs +++ b/src/internal/keys.rs @@ -23,7 +23,7 @@ use internal::util::{fmt_hex, opt, Bytes32, Bytes64}; use sodiumoxide::crypto::scalarmult as ecdh; use sodiumoxide::crypto::sign; use sodiumoxide::randombytes; -use std::fmt::{self, Debug, Error, Formatter}; +use std::fmt::{self, Debug}; use std::io::{Cursor, Read, Write}; use std::u16; use std::vec::Vec; @@ -32,14 +32,18 @@ use std::vec::Vec; #[derive(Clone, PartialEq, Eq, Debug)] pub struct IdentityKey { - pub public_key: PublicKey, + pub public_key: IdentityPublicKey, } impl IdentityKey { - pub fn new(k: PublicKey) -> IdentityKey { + pub fn new(k: IdentityPublicKey) -> IdentityKey { IdentityKey { public_key: k } } + pub fn verify(&self, s: &Signature, m: &[u8]) -> bool { + self.public_key.verify(s, m) + } + pub fn fingerprint(&self) -> String { self.public_key.fingerprint() } @@ -55,12 +59,16 @@ impl IdentityKey { let mut public_key = None; for _ in 0..n { match d.u8()? { - 0 => uniq!("IdentityKey::public_key", public_key, PublicKey::decode(d)?), + 0 => uniq!( + "IdentityPublicKey::public_key", + public_key, + IdentityPublicKey::decode(d)? + ), _ => d.skip()?, } } Ok(IdentityKey { - public_key: to_field!(public_key, "IdentityKey::public_key"), + public_key: to_field!(public_key, "IdentityPublicKey::public_key"), }) } } @@ -70,7 +78,7 @@ impl IdentityKey { #[derive(Clone)] pub struct IdentityKeyPair { pub version: u8, - pub secret_key: SecretKey, + pub secret_key: IdentitySecretKey, pub public_key: IdentityKey, } @@ -82,13 +90,24 @@ impl Default for IdentityKeyPair { impl IdentityKeyPair { pub fn new() -> IdentityKeyPair { - let k = KeyPair::new(); + let (public_edward, secret_edward) = sign::gen_keypair(); + + let es = from_ed25519_sk(&secret_edward).expect("invalid ed25519 secret key"); + let ep = from_ed25519_pk(&public_edward).expect("invalid ed25519 public key"); + + let secret_key = IdentitySecretKey { + sec_edward: secret_edward, + sec_curve: ecdh::Scalar(es), + }; + let public_key = IdentityKey::new(IdentityPublicKey { + pub_edward: public_edward, + pub_curve: ecdh::GroupElement(ep), + }); + IdentityKeyPair { - version: 1, - secret_key: k.secret_key, - public_key: IdentityKey { - public_key: k.public_key, - }, + version: 2, + secret_key, + public_key, } } @@ -103,9 +122,10 @@ impl IdentityKeyPair { } pub fn encode(&self, e: &mut Encoder) -> EncodeResult<()> { + let version = 2; e.object(3)?; e.u8(0)?; - e.u8(self.version)?; + e.u8(version)?; e.u8(1)?; self.secret_key.encode(e)?; e.u8(2)?; @@ -114,40 +134,61 @@ impl IdentityKeyPair { pub fn decode(d: &mut Decoder) -> DecodeResult { let n = d.object()?; - let mut version = None; - let mut secret_key = None; - let mut public_key = None; + let mut version_option = None; + let mut secret_key_v1_option = None; + let mut public_key_v1_option = None; + let mut secret_key_v2_option = None; + let mut public_key_v2_option = None; for _ in 0..n { match d.u8()? { - 0 => uniq!("IdentityKeyPair::version", version, d.u8()?), + 0 => uniq!("IdentityKeyPair::version", version_option, d.u8()?), 1 => uniq!( - "IdentityKeyPair::secret_key", - secret_key, - SecretKey::decode(d)? + "IdentityKeyPair::secret_key_v1", + secret_key_v1_option, + IdentitySecretKey::decode(d)? ), 2 => uniq!( - "IdentityKeyPair::public_key", - public_key, + "IdentityKeyPair::public_key_v1", + public_key_v1_option, + IdentityKey::decode(d)? + ), + 3 => uniq!( + "IdentityKeyPair::secret_key_v2", + secret_key_v2_option, + IdentitySecretKey::decode(d)? + ), + 4 => uniq!( + "IdentityKeyPair::public_key_v2", + public_key_v2_option, IdentityKey::decode(d)? ), _ => d.skip()?, } } - Ok(IdentityKeyPair { - version: to_field!(version, "IdentityKeyPair::version"), - secret_key: to_field!(secret_key, "IdentityKeyPair::secret_key"), - public_key: to_field!(public_key, "IdentityKeyPair::public_key"), - }) + let version = to_field!(version_option, "IdentityKeyPair::version"); + match version { + 1 => Ok(IdentityKeyPair { + version, + secret_key: to_field!(secret_key_v1_option, "IdentityKeyPair::secret_key_v2"), + public_key: to_field!(public_key_v1_option, "IdentityKeyPair::public_key_v2"), + }), + 2 => Ok(IdentityKeyPair { + version, + secret_key: to_field!(secret_key_v2_option, "IdentityKeyPair::secret_key_v2"), + public_key: to_field!(public_key_v2_option, "IdentityKeyPair::public_key_v2"), + }), + _ => Err(DecodeError::InvalidField("IdentitySecretKey::sec_curve")), + } } } // Prekey /////////////////////////////////////////////////////////////////// -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct PreKey { pub version: u8, pub key_id: PreKeyId, - pub key_pair: KeyPair, + pub key_pair: DHKeyPair, } impl PreKey { @@ -155,7 +196,7 @@ impl PreKey { PreKey { version: 1, key_id: i, - key_pair: KeyPair::new(), + key_pair: DHKeyPair::new(), } } @@ -192,7 +233,7 @@ impl PreKey { match d.u8()? { 0 => uniq!("PreKey::version", version, d.u8()?), 1 => uniq!("PreKey::key_id", key_id, PreKeyId::decode(d)?), - 2 => uniq!("PreKey::key_pair", key_pair, KeyPair::decode(d)?), + 2 => uniq!("PreKey::key_pair", key_pair, DHKeyPair::decode(d)?), _ => d.skip()?, } } @@ -225,7 +266,7 @@ pub enum PreKeyAuth { pub struct PreKeyBundle { pub version: u8, pub prekey_id: PreKeyId, - pub public_key: PublicKey, + pub public_key: DHPublicKey, pub identity_key: IdentityKey, pub signature: Option, } @@ -243,7 +284,7 @@ impl PreKeyBundle { pub fn signed(ident: &IdentityKeyPair, key: &PreKey) -> PreKeyBundle { let ratchet_key = key.key_pair.public_key.clone(); - let signature = ident.secret_key.sign(&ratchet_key.pub_edward.0); + let signature = ident.secret_key.sign(&ratchet_key.pub_curve.0); PreKeyBundle { version: 1, prekey_id: key.key_id, @@ -255,15 +296,17 @@ impl PreKeyBundle { pub fn verify(&self) -> PreKeyAuth { match self.signature { - Some(ref sig) => if self - .identity_key - .public_key - .verify(sig, &self.public_key.pub_edward.0) - { - PreKeyAuth::Valid - } else { - PreKeyAuth::Invalid - }, + Some(ref sig) => { + if self + .identity_key + .public_key + .verify(sig, &self.public_key.pub_curve.0) + { + PreKeyAuth::Valid + } else { + PreKeyAuth::Invalid + } + } None => PreKeyAuth::Unknown, } } @@ -309,7 +352,7 @@ impl PreKeyBundle { 2 => uniq!( "PreKeyBundle::public_key", public_key, - PublicKey::decode(d)? + DHPublicKey::decode(d)? ), 3 => uniq!( "PreKeyBundle::identity_key", @@ -365,35 +408,33 @@ impl fmt::Display for PreKeyId { } } -// Keypair ////////////////////////////////////////////////////////////////// +// DHKeypair ////////////////////////////////////////////////////////////////// -#[derive(Clone)] -pub struct KeyPair { - pub secret_key: SecretKey, - pub public_key: PublicKey, +#[derive(Debug, Clone)] +pub struct DHKeyPair { + pub secret_key: DHSecretKey, + pub public_key: DHPublicKey, } -impl Default for KeyPair { +impl Default for DHKeyPair { fn default() -> Self { Self::new() } } -impl KeyPair { - pub fn new() -> KeyPair { - let (p, s) = sign::gen_keypair(); - - let es = from_ed25519_sk(&s).expect("invalid ed25519 secret key"); - let ep = from_ed25519_pk(&p).expect("invalid ed25519 public key"); +impl DHKeyPair { + pub fn new() -> DHKeyPair { + let random_bytes = rand_bytes(ecdh::SCALARBYTES); + let mut private_key = ecdh::Scalar([0u8; ecdh::SCALARBYTES]); + private_key.0.clone_from_slice(&random_bytes); + let public_key = ecdh::scalarmult_base(&private_key); - KeyPair { - secret_key: SecretKey { - sec_edward: s, - sec_curve: ecdh::Scalar(es), + DHKeyPair { + secret_key: DHSecretKey { + sec_curve: private_key, }, - public_key: PublicKey { - pub_edward: p, - pub_curve: ecdh::GroupElement(ep), + public_key: DHPublicKey { + pub_curve: public_key, }, } } @@ -406,101 +447,149 @@ impl KeyPair { self.public_key.encode(e) } - pub fn decode(d: &mut Decoder) -> DecodeResult { + pub fn decode(d: &mut Decoder) -> DecodeResult { let n = d.object()?; let mut secret_key = None; let mut public_key = None; for _ in 0..n { match d.u8()? { - 0 => uniq!("KeyPair::secret_key", secret_key, SecretKey::decode(d)?), - 1 => uniq!("KeyPair::public_key", public_key, PublicKey::decode(d)?), + 0 => uniq!("KeyPair::secret_key", secret_key, DHSecretKey::decode(d)?), + 1 => uniq!("KeyPair::public_key", public_key, DHPublicKey::decode(d)?), _ => d.skip()?, } } - Ok(KeyPair { + Ok(DHKeyPair { secret_key: to_field!(secret_key, "KeyPair::secret_key"), public_key: to_field!(public_key, "KeyPair::public_key"), }) } } -// SecretKey //////////////////////////////////////////////////////////////// +// IdentitySecretKey //////////////////////////////////////////////////////////////// #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct Zero {} #[derive(Clone)] -pub struct SecretKey { - sec_edward: sign::SecretKey, - sec_curve: ecdh::Scalar, +pub struct IdentitySecretKey { + pub sec_edward: sign::SecretKey, + pub sec_curve: ecdh::Scalar, } -impl SecretKey { +impl IdentitySecretKey { pub fn sign(&self, m: &[u8]) -> Signature { Signature { sig: sign::sign_detached(m, &self.sec_edward), } } - pub fn shared_secret(&self, p: &PublicKey) -> Result<[u8; 32], Zero> { + pub fn shared_secret(&self, p: &DHPublicKey) -> Result<[u8; 32], Zero> { ecdh::scalarmult(&self.sec_curve, &p.pub_curve) .map(|ge| ge.0) .map_err(|()| Zero {}) } pub fn encode(&self, e: &mut Encoder) -> EncodeResult<()> { - e.object(1)?; + e.object(2)?; e.u8(0).and(e.bytes(&self.sec_edward.0))?; + e.u8(1).and(e.bytes(&self.sec_curve.0))?; Ok(()) } - pub fn decode(d: &mut Decoder) -> DecodeResult { + pub fn decode(d: &mut Decoder) -> DecodeResult { let n = d.object()?; - let mut sec_edward = None; + let mut sec_edward_option = None; + let mut sec_curve_option = None; for _ in 0..n { match d.u8()? { 0 => uniq!( - "SecretKey::sec_edward", - sec_edward, + "IdentitySecretKey::sec_edward", + sec_edward_option, Bytes64::decode(d).map(|v| sign::SecretKey(v.array))? ), + 1 => uniq!( + "IdentitySecretKey::sec_curve", + sec_curve_option, + Bytes32::decode(d).map(|v| ecdh::Scalar(v.array))? + ), _ => d.skip()?, } } - let sec_edward = sec_edward.ok_or(DecodeError::MissingField("SecretKey::sec_edward"))?; - let sec_curve = from_ed25519_sk(&sec_edward) - .map(ecdh::Scalar) - .map_err(|()| DecodeError::InvalidField("SecretKey::sec_edward"))?; - Ok(SecretKey { + let sec_edward = + sec_edward_option.ok_or(DecodeError::MissingField("IdentitySecretKey::sec_edward"))?; + let sec_curve = sec_curve_option.unwrap_or( + from_ed25519_sk(&sec_edward) + .map(ecdh::Scalar) + .map_err(|()| DecodeError::InvalidField("IdentitySecretKey::sec_curve"))?, + ); + Ok(IdentitySecretKey { sec_edward, sec_curve, }) } } -// PublicKey //////////////////////////////////////////////////////////////// +// DHSecretKey //////////////////////////////////////////////////////////////// -#[derive(Clone)] -pub struct PublicKey { - pub_edward: sign::PublicKey, - pub_curve: ecdh::GroupElement, +#[derive(Debug, Clone)] +pub struct DHSecretKey { + pub sec_curve: ecdh::Scalar, } -impl PartialEq for PublicKey { - fn eq(&self, other: &PublicKey) -> bool { - self.pub_edward.0 == other.pub_edward.0 && self.pub_curve.0 == other.pub_curve.0 +impl DHSecretKey { + pub fn shared_secret(&self, p: &DHPublicKey) -> Result<[u8; 32], Zero> { + ecdh::scalarmult(&self.sec_curve, &p.pub_curve) + .map(|ge| ge.0) + .map_err(|()| Zero {}) } -} -impl Eq for PublicKey {} + pub fn encode(&self, e: &mut Encoder) -> EncodeResult<()> { + e.object(1)?; + e.u8(1).and(e.bytes(&self.sec_curve.0))?; + Ok(()) + } -impl Debug for PublicKey { - fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { - write!(f, "{:?}", &self.pub_edward.0) + pub fn decode(d: &mut Decoder) -> DecodeResult { + let n = d.object()?; + let mut sec_edward_option = None; + let mut sec_curve_option = None; + for _ in 0..n { + match d.u8()? { + 0 => uniq!( + "DHSecretKey::sec_edward", + sec_edward_option, + Bytes64::decode(d).map(|v| sign::SecretKey(v.array))? + ), + 1 => uniq!( + "DHSecretKey::sec_curve", + sec_curve_option, + Bytes32::decode(d).map(|v| ecdh::Scalar(v.array))? + ), + _ => d.skip()?, + } + } + if let Some(sec_curve) = sec_curve_option { + Ok(DHSecretKey { sec_curve }) + } else if let Some(pub_edward) = sec_edward_option { + let sec_curve = from_ed25519_sk(&pub_edward) + .map(ecdh::Scalar) + .map_err(|()| DecodeError::InvalidField("DHSecretKey::sec_curve"))?; + Ok(DHSecretKey { sec_curve }) + } else { + Err(DecodeError::MissingField("DHSecretKey::sec_edward")) + } } } -impl PublicKey { +// IdentityPublicKey //////////////////////////////////////////////////////////////// + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct IdentityPublicKey { + pub pub_edward: sign::PublicKey, + pub pub_curve: ecdh::GroupElement, +} + +impl IdentityPublicKey { pub fn verify(&self, s: &Signature, m: &[u8]) -> bool { sign::verify_detached(&s.sig, m, &self.pub_edward) } @@ -509,36 +598,98 @@ impl PublicKey { fmt_hex(&self.pub_edward.0) } + pub fn to_dh_public_key(&self) -> DHPublicKey { + DHPublicKey { + pub_curve: self.pub_curve.clone(), + } + } + pub fn encode(&self, e: &mut Encoder) -> EncodeResult<()> { - e.object(1)?; + e.object(2)?; e.u8(0).and(e.bytes(&self.pub_edward.0))?; + e.u8(1).and(e.bytes(&self.pub_curve.0))?; Ok(()) } - pub fn decode(d: &mut Decoder) -> DecodeResult { + pub fn decode(d: &mut Decoder) -> DecodeResult { let n = d.object()?; - let mut pub_edward = None; + let mut pub_edward_option = None; + let mut pub_curve_option = None; for _ in 0..n { match d.u8()? { 0 => uniq!( - "PublicKey::pub_edward", - pub_edward, + "IdentityPublicKey::pub_edward", + pub_edward_option, Bytes32::decode(d).map(|v| sign::PublicKey(v.array))? ), + 1 => uniq!( + "IdentityPublicKey::pub_curve", + pub_curve_option, + Bytes32::decode(d).map(|v| ecdh::GroupElement(v.array))? + ), _ => d.skip()?, } } - let pub_edward = pub_edward.ok_or(DecodeError::MissingField("PublicKey::pub_edward"))?; - let pub_curve = from_ed25519_pk(&pub_edward) - .map(ecdh::GroupElement) - .map_err(|()| DecodeError::InvalidField("PublicKey::pub_edward"))?; - Ok(PublicKey { + let pub_edward = + pub_edward_option.ok_or(DecodeError::MissingField("IdentityPublicKey::pub_edward"))?; + let pub_curve = pub_curve_option.unwrap_or( + from_ed25519_pk(&pub_edward) + .map(ecdh::GroupElement) + .map_err(|()| DecodeError::InvalidField("IdentityPublicKey::pub_curve"))?, + ); + Ok(IdentityPublicKey { pub_edward, pub_curve, }) } } +// DHPublicKey //////////////////////////////////////////////////////////////// + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct DHPublicKey { + pub pub_curve: ecdh::GroupElement, +} + +impl DHPublicKey { + pub fn encode(&self, e: &mut Encoder) -> EncodeResult<()> { + e.object(1)?; + e.u8(1).and(e.bytes(&self.pub_curve.0))?; + Ok(()) + } + + pub fn decode(d: &mut Decoder) -> DecodeResult { + let n = d.object()?; + let mut pub_edward_option = None; + let mut pub_curve_option = None; + for _ in 0..n { + match d.u8()? { + 0 => uniq!( + "DHPublicKey::pub_edward", + pub_edward_option, + Bytes32::decode(d).map(|v| sign::PublicKey(v.array))? + ), + 1 => uniq!( + "DHPublicKey::pub_curve", + pub_curve_option, + Bytes32::decode(d).map(|v| ecdh::GroupElement(v.array))? + ), + _ => d.skip()?, + } + } + if let Some(pub_curve) = pub_curve_option { + Ok(DHPublicKey { pub_curve }) + } else if let Some(pub_edward) = pub_edward_option { + let pub_curve = from_ed25519_pk(&pub_edward) + .map(ecdh::GroupElement) + .map_err(|()| DecodeError::InvalidField("DHPublicKey::pub_curve"))?; + Ok(DHPublicKey { pub_curve }) + } else { + Err(DecodeError::MissingField("DHPublicKey::pub_edward")) + } + } +} + // Random /////////////////////////////////////////////////////////////////// pub fn rand_bytes(size: usize) -> Vec { @@ -620,8 +771,8 @@ mod tests { #[test] fn dh_agreement() { - let a = KeyPair::new(); - let b = KeyPair::new(); + let a = DHKeyPair::new(); + let b = DHKeyPair::new(); let sa = a.secret_key.shared_secret(&b.public_key); let sb = b.secret_key.shared_secret(&a.public_key); assert_eq!(&sa, &sb) @@ -629,7 +780,7 @@ mod tests { #[test] fn sign_and_verify() { - let a = KeyPair::new(); + let a = IdentityKeyPair::new(); let s = a.secret_key.sign(b"foobarbaz"); assert!(a.public_key.verify(&s, b"foobarbaz")); assert!(!a.public_key.verify(&s, b"foobar")); @@ -637,25 +788,33 @@ mod tests { #[test] fn enc_dec_pubkey() { - let k = KeyPair::new(); + let k = DHKeyPair::new(); let r = roundtrip( |mut e| k.public_key.encode(&mut e), - |mut d| PublicKey::decode(&mut d), + |mut d| DHPublicKey::decode(&mut d), ); assert_eq!(k.public_key, r) } #[test] - fn enc_dec_seckey() { - let k = KeyPair::new(); + fn identity() { + let k = IdentityKeyPair::new(); let r = roundtrip( |mut e| k.secret_key.encode(&mut e), - |mut d| SecretKey::decode(&mut d), + |mut d| IdentitySecretKey::decode(&mut d), ); assert_eq!(&k.secret_key.sec_edward.0[..], &r.sec_edward.0[..]); assert_eq!(&k.secret_key.sec_curve.0[..], &r.sec_curve.0[..]) } - + #[test] + fn enc_dec_dhkey() { + let k = DHKeyPair::new(); + let r = roundtrip( + |mut e| k.secret_key.encode(&mut e), + |mut d| DHSecretKey::decode(&mut d), + ); + assert_eq!(&k.secret_key.sec_curve.0[..], &r.sec_curve.0[..]) + } #[test] fn enc_dec_prekey_bundle() { let i = IdentityKeyPair::new(); @@ -685,7 +844,7 @@ mod tests { #[test] fn degenerated_key() { - let mut k = KeyPair::new(); + let mut k = DHKeyPair::new(); for i in 0..k.public_key.pub_curve.0.len() { k.public_key.pub_curve.0[i] = 0 } diff --git a/src/internal/message.rs b/src/internal/message.rs index 01d1a2f..2a5a080 100644 --- a/src/internal/message.rs +++ b/src/internal/message.rs @@ -18,7 +18,7 @@ use cbor::skip::Skip; use cbor::{Config, Decoder, Encoder}; use internal::derived::{Mac, MacKey, Nonce}; -use internal::keys::{IdentityKey, PreKeyId, PublicKey}; +use internal::keys::{DHPublicKey, IdentityKey, PreKeyId}; use internal::types::{DecodeError, DecodeResult, EncodeResult}; use internal::util::fmt_hex; use sodiumoxide::randombytes; @@ -139,7 +139,7 @@ impl<'r> Message<'r> { pub struct PreKeyMessage<'r> { pub prekey_id: PreKeyId, - pub base_key: Cow<'r, PublicKey>, + pub base_key: Cow<'r, DHPublicKey>, pub identity_key: Cow<'r, IdentityKey>, pub message: CipherMessage<'r>, } @@ -175,7 +175,7 @@ impl<'r> PreKeyMessage<'r> { for _ in 0..n { match d.u8()? { 0 => uniq!("PreKeyMessage::prekey_id", prekey_id, PreKeyId::decode(d)?), - 1 => uniq!("PreKeyMessage::base_key", base_key, PublicKey::decode(d)?), + 1 => uniq!("PreKeyMessage::base_key", base_key, DHPublicKey::decode(d)?), 2 => uniq!( "PreKeyMessage::identity_key", identity_key, @@ -200,7 +200,7 @@ pub struct CipherMessage<'r> { pub session_tag: SessionTag, pub counter: Counter, pub prev_counter: Counter, - pub ratchet_key: Cow<'r, PublicKey>, + pub ratchet_key: Cow<'r, DHPublicKey>, pub cipher_text: Vec, } @@ -253,7 +253,7 @@ impl<'r> CipherMessage<'r> { 3 => uniq!( "CipherMessage::ratchet_key", ratchet_key, - PublicKey::decode(d)? + DHPublicKey::decode(d)? ), 4 => uniq!("CipherMessage::cipher_text", cipher_text, d.bytes()?), _ => d.skip()?, @@ -375,15 +375,15 @@ impl<'r> Envelope<'r> { mod tests { use super::*; use internal::derived::MacKey; - use internal::keys::{IdentityKey, KeyPair, PreKeyId}; + use internal::keys::{DHKeyPair, IdentityKeyPair, PreKeyId}; use std::borrow::Cow; #[test] fn enc_dec_envelope() { let mk = MacKey::new([1; 32]); - let bk = KeyPair::new().public_key; - let ik = IdentityKey::new(KeyPair::new().public_key); - let rk = KeyPair::new().public_key; + let bk = DHKeyPair::new().public_key; + let ik = IdentityKeyPair::new().public_key; + let rk = DHKeyPair::new().public_key; let tg = SessionTag::new(); let m1 = Message::Keyed(PreKeyMessage { diff --git a/src/internal/session.rs b/src/internal/session.rs index 0d6174c..dcdab5b 100644 --- a/src/internal/session.rs +++ b/src/internal/session.rs @@ -27,7 +27,7 @@ use cbor::skip::Skip; use cbor::{self, Config, Decoder, Encoder}; use hkdf::{Info, Input, Salt}; use internal::derived::{CipherKey, DerivedSecrets, MacKey}; -use internal::keys::{self, KeyPair, PublicKey}; +use internal::keys::{self, DHKeyPair, DHPublicKey}; use internal::keys::{IdentityKey, IdentityKeyPair, PreKey, PreKeyBundle, PreKeyId}; use internal::message::{CipherMessage, Counter, Envelope, Message, PreKeyMessage, SessionTag}; use internal::types::{DecodeError, DecodeResult, EncodeResult, InternalError}; @@ -46,8 +46,8 @@ impl RootKey { pub fn dh_ratchet( &self, - ours: &KeyPair, - theirs: &PublicKey, + ours: &DHKeyPair, + theirs: &DHPublicKey, ) -> Result<(RootKey, ChainKey), Error> { let secret = ours.secret_key.shared_secret(theirs)?; let dsecs = DerivedSecrets::kdf(Input(&secret), Salt(&self.key), Info(b"dh_ratchet")); @@ -139,11 +139,11 @@ impl ChainKey { #[derive(Clone)] pub struct SendChain { chain_key: ChainKey, - ratchet_key: KeyPair, + ratchet_key: DHKeyPair, } impl SendChain { - pub fn new(ck: ChainKey, rk: KeyPair) -> SendChain { + pub fn new(ck: ChainKey, rk: DHKeyPair) -> SendChain { SendChain { chain_key: ck, ratchet_key: rk, @@ -165,7 +165,7 @@ impl SendChain { for _ in 0..n { match d.u8()? { 0 => uniq!("SendChain::chain_key", chain_key, ChainKey::decode(d)?), - 1 => uniq!("SendChain::ratchet_key", ratchet_key, KeyPair::decode(d)?), + 1 => uniq!("SendChain::ratchet_key", ratchet_key, DHKeyPair::decode(d)?), _ => d.skip()?, } } @@ -183,12 +183,12 @@ const MAX_COUNTER_GAP: usize = 1000; #[derive(Clone)] pub struct RecvChain { chain_key: ChainKey, - ratchet_key: PublicKey, + ratchet_key: DHPublicKey, message_keys: VecDeque, } impl RecvChain { - pub fn new(ck: ChainKey, rk: PublicKey) -> RecvChain { + pub fn new(ck: ChainKey, rk: DHPublicKey) -> RecvChain { RecvChain { chain_key: ck, ratchet_key: rk, @@ -291,7 +291,11 @@ impl RecvChain { for _ in 0..n { match d.u8()? { 0 => uniq!("RecvChain::chain_key", chain_key, ChainKey::decode(d)?), - 1 => uniq!("RecvChain::ratchet_key", ratchet_key, PublicKey::decode(d)?), + 1 => uniq!( + "RecvChain::ratchet_key", + ratchet_key, + DHPublicKey::decode(d)? + ), 2 => uniq!("RecvChain::message_keys", message_keys, { let lv = d.array()?; let mut vm = VecDeque::with_capacity(lv); @@ -406,26 +410,26 @@ pub struct Session { counter: usize, local_identity: I, remote_identity: IdentityKey, - pending_prekey: Option<(PreKeyId, PublicKey)>, + pending_prekey: Option<(PreKeyId, DHPublicKey)>, session_states: BTreeMap>, } struct AliceParams<'r> { alice_ident: &'r IdentityKeyPair, - alice_base: &'r KeyPair, + alice_base: &'r DHKeyPair, bob: &'r PreKeyBundle, } struct BobParams<'r> { bob_ident: &'r IdentityKeyPair, - bob_prekey: KeyPair, + bob_prekey: DHKeyPair, alice_ident: &'r IdentityKey, - alice_base: &'r PublicKey, + alice_base: &'r DHPublicKey, } impl> Session { pub fn init_from_prekey(alice: I, pk: PreKeyBundle) -> Result, Error> { - let alice_base = KeyPair::new(); + let alice_base = DHKeyPair::new(); let state = SessionState::init_as_alice(&AliceParams { alice_ident: alice.borrow(), alice_base: &alice_base, @@ -447,7 +451,6 @@ impl> Session { Ok(session) } - #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))] pub fn init_from_message( ours: I, store: &mut S, @@ -580,7 +583,6 @@ impl> Session { // state left is the one to be inserted, but if Alice and Bob do not // manage to agree on a session state within `usize::MAX` it is probably // of least concern. - #[cfg_attr(feature = "cargo-clippy", allow(map_entry))] fn insert_session_state(&mut self, t: SessionTag, s: SessionState) { if self.session_states.contains_key(&t) { if let Some(x) = self.session_states.get_mut(&t) { @@ -694,7 +696,7 @@ impl> Session { for _ in 0..n { match d.u8()? { 0 => uniq!("PendingPreKey::id", id, PreKeyId::decode(d)?), - 1 => uniq!("PendingPreKey::pk", pk, PublicKey::decode(d)?), + 1 => uniq!("PendingPreKey::pk", pk, DHPublicKey::decode(d)?), _ => d.skip()?, } } @@ -750,7 +752,7 @@ impl SessionState { buf.extend( &p.alice_base .secret_key - .shared_secret(&p.bob.identity_key.public_key)?, + .shared_secret(&p.bob.identity_key.public_key.to_dh_public_key())?, ); buf.extend(&p.alice_base.secret_key.shared_secret(&p.bob.public_key)?); buf @@ -766,7 +768,7 @@ impl SessionState { recv_chains.push_front(RecvChain::new(chainkey, p.bob.public_key.clone())); // sending chain - let send_ratchet = KeyPair::new(); + let send_ratchet = DHKeyPair::new(); let (rok, chk) = rootkey.dh_ratchet(&send_ratchet, &p.bob.public_key)?; let send_chain = SendChain::new(chk, send_ratchet); @@ -784,7 +786,7 @@ impl SessionState { buf.extend( &p.bob_prekey .secret_key - .shared_secret(&p.alice_ident.public_key)?, + .shared_secret(&p.alice_ident.public_key.to_dh_public_key())?, ); buf.extend(&p.bob_ident.secret_key.shared_secret(p.alice_base)?); buf.extend(&p.bob_prekey.secret_key.shared_secret(p.alice_base)?); @@ -806,14 +808,15 @@ impl SessionState { }) } - fn ratchet(&mut self, ratchet_key: PublicKey) -> Result<(), Error> { - let new_ratchet = KeyPair::new(); + fn ratchet(&mut self, ratchet_key: DHPublicKey) -> Result<(), Error> { + let new_ratchet = DHKeyPair::new(); let (recv_root_key, recv_chain_key) = self .root_key .dh_ratchet(&self.send_chain.ratchet_key, &ratchet_key)?; - let (send_root_key, send_chain_key) = recv_root_key.dh_ratchet(&new_ratchet, &ratchet_key)?; + let (send_root_key, send_chain_key) = + recv_root_key.dh_ratchet(&new_ratchet, &ratchet_key)?; let recv_chain = RecvChain::new(recv_chain_key, ratchet_key); let send_chain = SendChain::new(send_chain_key, new_ratchet); @@ -833,7 +836,7 @@ impl SessionState { fn encrypt<'r>( self: &'r mut SessionState, ident: &'r IdentityKey, - pending: &'r Option<(PreKeyId, PublicKey)>, + pending: &'r Option<(PreKeyId, DHPublicKey)>, tag: SessionTag, plain: &[u8], ) -> EncodeResult { @@ -1013,7 +1016,7 @@ impl std::error::Error for Error { self.as_str() } - fn cause(&self) -> Option<&std::error::Error> { + fn cause(&self) -> Option<&dyn std::error::Error> { match *self { Error::PreKeyStoreError(ref e) => Some(e), _ => None, @@ -1099,8 +1102,9 @@ mod tests { &bob_ident, &mut bob_store, &alices[0].encrypt(b"hello").unwrap().into_owned(), - ).unwrap() - .0; + ) + .unwrap() + .0; for a in &mut alices { for _ in 0..900 { @@ -1114,10 +1118,9 @@ mod tests { assert_eq!(total_size, bob.session_states.len()); for a in &mut alices { - assert!( - bob.decrypt(&mut bob_store, &a.encrypt(b"Hello Bob!").unwrap()) - .is_ok() - ); + assert!(bob + .decrypt(&mut bob_store, &a.encrypt(b"Hello Bob!").unwrap()) + .is_ok()); } } diff --git a/src/internal/types.rs b/src/internal/types.rs index 805fe54..aec25d3 100644 --- a/src/internal/types.rs +++ b/src/internal/types.rs @@ -63,7 +63,7 @@ impl Error for EncodeError { "EncodeError" } - fn cause(&self) -> Option<&Error> { + fn cause(&self) -> Option<&dyn Error> { match *self { EncodeError::Internal(ref e) => Some(e), EncodeError::Encoder(ref e) => Some(e), @@ -117,7 +117,7 @@ impl Error for DecodeError { "DecodeError" } - fn cause(&self) -> Option<&Error> { + fn cause(&self) -> Option<&dyn Error> { match *self { DecodeError::Decoder(ref e) => Some(e), _ => None, diff --git a/src/keys.rs b/src/keys.rs index abdd2b5..6db07e4 100644 --- a/src/keys.rs +++ b/src/keys.rs @@ -15,14 +15,14 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +pub use internal::keys::DHPublicKey; +pub use internal::keys::DHSecretKey; pub use internal::keys::IdentityKey; pub use internal::keys::IdentityKeyPair; pub use internal::keys::PreKey; pub use internal::keys::PreKeyAuth; pub use internal::keys::PreKeyBundle; pub use internal::keys::PreKeyId; -pub use internal::keys::PublicKey; -pub use internal::keys::SecretKey; pub use internal::keys::Signature; pub use internal::keys::Zero; pub use internal::keys::MAX_PREKEY_ID;