···6677use crate::{
88 keccakf::KECCAK_BUFFER_SIZE,
99- strobe::{SecurityParameter, StrobeState},
99+ strobe::{Role, SecurityParameter, StrobeState},
1010};
11111212extern crate std;
13131414#[test]
1515fn test_init_128() {
1616- let s = StrobeState::new(b"", SecurityParameter::B128);
1616+ let s = StrobeState::new(b"", SecurityParameter::B128, Role::Sender);
17171818 let expected_st: [u8; KECCAK_BUFFER_SIZE] = [
1919 0x9c, 0x7f, 0x16, 0x8f, 0xf8, 0xfd, 0x55, 0xda, 0x2a, 0xa7, 0x3c, 0x23, 0x55, 0x65, 0x35,
···37373838#[test]
3939fn test_init_256() {
4040- let s = StrobeState::new(b"", SecurityParameter::B256);
4040+ let s = StrobeState::new(b"", SecurityParameter::B256, Role::Sender);
41414242 let expected_st: [u8; KECCAK_BUFFER_SIZE] = [
4343 0x37, 0xc1, 0x15, 0x06, 0xed, 0x61, 0xe7, 0xda, 0x7c, 0x1a, 0x2f, 0x2c, 0x1f, 0x49, 0x74,
···6262#[test]
6363fn test_metadata() {
6464 // We will accumulate output over 3 operations and 3 meta-operations
6565- let mut s = StrobeState::new(b"metadatatest", SecurityParameter::B256);
6565+ let mut s = StrobeState::new(b"metadatatest", SecurityParameter::B256, Role::Sender);
6666 let mut output = std::vec::Vec::new();
67676868 let buf = b"meta1";
···116116117117#[test]
118118fn test_seq() {
119119- let mut s = StrobeState::new(b"seqtest", SecurityParameter::B256);
119119+ let mut s = StrobeState::new(b"seqtest", SecurityParameter::B256, Role::Sender);
120120121121 let mut buf = [0u8; 10];
122122 s.prf(&mut buf[..]);
···172172#[test]
173173fn test_enc_correctness() {
174174 let orig_msg = b"Hello there";
175175- let mut tx = StrobeState::new(b"enccorrectnesstest", SecurityParameter::B256);
176176- let mut rx = StrobeState::new(b"enccorrectnesstest", SecurityParameter::B256);
175175+ let mut tx = StrobeState::new(b"enccorrectnesstest", SecurityParameter::B256, Role::Sender);
176176+ let mut rx = StrobeState::new(
177177+ b"enccorrectnesstest",
178178+ SecurityParameter::B256,
179179+ Role::Receiver,
180180+ );
177181178182 tx.key(b"the-combination-on-my-luggage");
179183 rx.key(b"the-combination-on-my-luggage");
···188192189193#[test]
190194fn test_mac_correctness_and_soundness() {
191191- let mut tx = StrobeState::new(b"mactest", SecurityParameter::B256);
192192- let mut rx = StrobeState::new(b"mactest", SecurityParameter::B256);
195195+ let mut tx = StrobeState::new(b"mactest", SecurityParameter::B256, Role::Sender);
196196+ let mut rx = StrobeState::new(b"mactest", SecurityParameter::B256, Role::Receiver);
193197194198 // Just do some stuff with the state
195199···217221218222#[test]
219223fn test_long_inputs() {
220220- let mut s = StrobeState::new(b"bigtest", SecurityParameter::B256);
224224+ let mut s = StrobeState::new(b"bigtest", SecurityParameter::B256, Role::Sender);
221225 const BIG_N: usize = 9823;
222226 const SMALL_N: usize = 65;
223227 let big_data = [0x34u8; BIG_N];
···274278fn test_streaming_correctness() {
275279 // Compute a few things without breaking up their inputs
276280 let one_shot_st: std::vec::Vec<u8> = {
277277- let mut s = StrobeState::new(b"streamingtest", SecurityParameter::B256);
281281+ let mut s = StrobeState::new(b"streamingtest", SecurityParameter::B256, Role::Receiver);
278282279283 s.ad(b"mynonce");
280284···290294 };
291295 // Now do the same thing but stream the inputs
292296 let streamed_st: std::vec::Vec<u8> = {
293293- let mut s = StrobeState::new(b"streamingtest", SecurityParameter::B256);
297297+ let mut s = StrobeState::new(b"streamingtest", SecurityParameter::B256, Role::Receiver);
294298295299 s.ad(b"my");
296300 s.ad(b"nonce");
+2-2
src/herding_kats/harness.rs
···4455use serde::{Deserialize, Deserializer, de};
6677-use crate::strobe::{SecurityParameter, StrobeState};
77+use crate::strobe::{Role, SecurityParameter, StrobeState};
8899/// The harness we will put on our KATs so we can herd them and make them do tests.
1010/// (This is the top-level structure of the JSON we find in the test vectors)
···113113 operations,
114114 } = serde_json::from_reader(file).unwrap();
115115116116- let mut strobe = StrobeState::new(proto_string.as_bytes(), security);
116116+ let mut strobe = StrobeState::new(proto_string.as_bytes(), security, Role::Sender);
117117118118 operations.into_iter().for_each(
119119 |KatOperation {
+2
src/lib.rs
···66#[cfg(test)]
77mod herding_kats;
88mod keccakf;
99+mod opflags;
1010+mod ops;
911pub mod strobe;
10121113/// Version of Strobe that this crate implements.
···11-use enumflags2::{BitFlag, BitFlags};
22-use subtle::ConstantTimeEq;
11+use subtle::{Choice, ConstantTimeEq};
3243use crate::{
54 GarbledError, STROBE_VERSION,
65 keccakf::{KECCAK_BUFFER_SIZE, KeccakF1600},
66+ opflags::OpFlags,
77+ ops,
78};
8999-#[enumflags2::bitflags]
1010+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1011#[repr(u8)]
1111-#[derive(Copy, Clone, Debug, PartialEq)]
1212-enum OpFlags {
1313- /// Is data being moved inbound
1414- Inbound = 0b000001, // 1<<0
1515- /// Is data being sent to the application
1616- App = 0b000010, // 1<<1
1717- /// Does this operation use cipher output
1818- Cipher = 0b000100, // 1<<2
1919- /// Is data being sent for transport
2020- Transport = 0b001000, // 1<<3
2121- /// Use exclusively for metadata operations
2222- Meta = 0b010000, // 1<<4
2323- /// Reserved and currently unimplemented. Using this will cause a panic.
2424- KeyTree = 0b100000, // 1<<5
1212+pub enum Role {
1313+ Sender,
1414+ Receiver,
2515}
26162727-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2828-enum Role {
2929- Sender,
3030- Receiver,
1717+impl ConstantTimeEq for Role {
1818+ fn ct_eq(&self, other: &Self) -> Choice {
1919+ (*self as u8).ct_eq(&(*other as u8))
2020+ }
3121}
32223323#[derive(Debug, Clone, Copy)]
···4939 position: usize,
5040 /// Index into `state`
5141 start: usize,
5252- /// Represents whether we're a sender or a receiver or uninitialized
5353- role: Option<Role>,
4242+ /// Represents whether we're a sender or a receiver
4343+ role: Role,
5444 /// The last operation performed. This is to verify that the `more` flag is only used across
5545 /// identical operations.
5656- prev_flags: BitFlags<OpFlags>,
4646+ prev_flags: OpFlags,
5747}
58485949macro_rules! define_mut_operations {
···6454 #[$doc]
6555 pub fn $name(&mut self, data: &mut [u8]) {
6656 let flags = $flags;
6767- let prev_flags = self.prev_flags.bits();
6868- let more = prev_flags.ct_eq(&flags.bits());
6969- self.operate(flags, data, bool::from(more));
5757+ let prev_flags = self.prev_flags;
5858+ let more = prev_flags.ct_eq(&flags);
5959+ self.operate(flags, data, more);
7060 }
7161 )*
7262 };
···8070 #[$doc]
8171 pub fn $name(&mut self, data: &[u8]) {
8272 let flags = $flags;
8383- let prev_flags = self.prev_flags.bits();
8484- let more = prev_flags.ct_eq(&flags.bits());
8585- self.operate_no_mutate(flags, data, bool::from(more));
7373+ let prev_flags = self.prev_flags;
7474+ let more = prev_flags.ct_eq(&flags);
7575+ self.operate_no_mutate(flags, data, more);
8676 }
8777 )*
8878 };
···112102113103impl StrobeState {
114104 /// Makes a new `StrobeTransport` object with a given protocol byte string and security parameter.
115115- pub fn new(protocol: &[u8], sec: SecurityParameter) -> Self {
105105+ pub fn new(protocol: &[u8], sec: SecurityParameter, role: Role) -> Self {
116106 let rate = KECCAK_BUFFER_SIZE - (sec as usize) / 4 - 2;
117107 assert!((1..254).contains(&rate));
118108···132122 rate,
133123 position: 0,
134124 start: 0,
135135- role: None,
136136- prev_flags: OpFlags::empty(),
125125+ role,
126126+ prev_flags: OpFlags::EMPTY,
137127 };
138128139129 // Mix the protocol into the state
···154144 pub fn reset_ops(&mut self) {
155145 // This prevents streaming so to always make the prev_flags == flags
156146 // comparison always fail
157157- self.prev_flags = OpFlags::empty();
147147+ self.prev_flags = OpFlags::EMPTY;
158148 }
159149160150 // Runs the permutation function on the internal state
···263253264254 /// Mixes the current state index and flags into the state, accounting for whether we are
265255 /// sending or receiving
266266- fn begin_op(&mut self, mut flags: BitFlags<OpFlags>) {
267267- if flags.contains(OpFlags::Transport) {
268268- let op_role = if flags.contains(OpFlags::Inbound) {
256256+ fn begin_op(&mut self, mut flags: OpFlags) {
257257+ if flags.contains(OpFlags::TRANSPORT).into() {
258258+ let op_role = if flags.contains(OpFlags::INBOUND).into() {
269259 Role::Receiver
270260 } else {
271261 Role::Sender
272262 };
273273-274274- // If uninitialized, take on the direction of the first directional operation we get
275275- if self.role.is_none() {
276276- self.role = Some(op_role);
277277- }
278263279264 // So that the sender and receiver agree, toggle the I flag as necessary
280280- // This is equivalent to flags ^= is_receiver
281281- flags.set(OpFlags::Inbound, self.role.unwrap() != op_role);
265265+ flags.set(OpFlags::INBOUND, self.role.ct_ne(&op_role));
282266 }
283267284268 let old_start = self.start;
···287271 // Mix in the position and flags
288272 self.absorb(&[old_start as u8, flags.bits()]);
289273290290- let force_permutation = flags.contains(OpFlags::Cipher) || flags.contains(OpFlags::KeyTree);
291291- if force_permutation && self.position != 0 {
274274+ let mut force_permutation = flags.intersects(OpFlags::CIPHER | OpFlags::KEYTREE);
275275+ force_permutation &= self.position.ct_ne(&0);
276276+277277+ if force_permutation.into() {
292278 self.permutation_f();
293279 }
294280 }
···296282 /// Performs the state / data transformation that corresponds to the given flags. If `more` is
297283 /// given, this will treat `data` as a continuation of the data given in the previous
298284 /// call to `operate`.
299299- fn operate(&mut self, flags: BitFlags<OpFlags>, data: &mut [u8], more: bool) {
285285+ fn operate(&mut self, flags: OpFlags, data: &mut [u8], more: Choice) {
300286 self.prev_flags = flags;
301287302288 // If `more` isn't set, this is a new operation. Do the begin_op sequence
303303- if !more {
289289+ if !bool::from(more) {
304290 self.begin_op(flags);
305291 }
306292307293 // Meta-ness is only relevant for `begin_op`. Remove it to simplify the below logic.
308308- let flags = flags & !OpFlags::Meta;
294294+ let flags = flags & !OpFlags::META;
295295+296296+ // Flags that don't pass this assertion should normally call `absorb`, but `absorb` does not mutate,
297297+ // so the implementor should have used operate_no_mutate instead
298298+ // RATCHET is special-cased to never call operate directly
299299+ debug_assert!(flags != ops::KEY && bool::from(flags.contains(OpFlags::CIPHER)));
309300310310- // TODO?: Assert that input is empty under some flag conditions
311311- if flags.contains(OpFlags::Cipher | OpFlags::Transport) && !flags.contains(OpFlags::Inbound)
312312- {
313313- // This is equivalent to the `duplex` operation in the Python implementation, with
314314- // `cafter = True`
315315- if flags == OpFlags::Cipher | OpFlags::Transport {
316316- // This is `send_mac`. Pretend the input is all zeros
317317- self.copy_state(data);
318318- } else {
319319- self.absorb_and_set(data);
320320- }
321321- } else if flags == OpFlags::Inbound | OpFlags::App | OpFlags::Cipher {
322322- // Special case of case below. This is PRF. Use `squeeze` instead of `exchange`.
323323- self.squeeze(data);
324324- } else if flags.contains(OpFlags::Cipher) {
325325- // This is equivalent to the `duplex` operation in the Python implementation, with
326326- // `cbefore = True`
327327- self.exchange(data);
328328- } else {
329329- // This should normally call `absorb`, but `absorb` does not mutate, so the implementor
330330- // should have used operate_no_mutate instead
331331- unreachable!("operate should not be called for operations that do not require mutation")
301301+ match flags {
302302+ ops::PRF => self.squeeze(data),
303303+ ops::SEND_MAC => self.copy_state(data),
304304+ ops::SEND_ENC => self.absorb_and_set(data),
305305+ _ => self.exchange(data),
332306 }
333307 }
334308335309 /// Performs the state transformation that corresponds to the given flags. If `more` is given,
336310 /// this will treat `data` as a continuation of the data given in the previous call to
337311 /// `operate`. This uses non-mutating variants of the specializations of the `duplex` function.
338338- fn operate_no_mutate(&mut self, flags: BitFlags<OpFlags>, data: &[u8], more: bool) {
312312+ fn operate_no_mutate(&mut self, flags: OpFlags, data: &[u8], more: Choice) {
339313 self.prev_flags = flags;
340314341315 // If `more` isn't set, this is a new operation. Do the begin_op sequence
342342- if !more {
316316+ if !bool::from(more) {
343317 self.begin_op(flags);
344318 }
345319320320+ // Meta-ness is only relevant for `begin_op`. Remove it to simplify the below logic.
321321+ let flags = flags & !OpFlags::META;
322322+323323+ // Flags that trigger the assertion to fail are mutating operations.
324324+ // RATCHET is special cased to never call operate/operate_no_mutate directly
325325+ debug_assert!(
326326+ flags != ops::PRF && !bool::from(flags.contains(OpFlags::CIPHER | OpFlags::TRANSPORT))
327327+ || bool::from(flags.contains(OpFlags::INBOUND))
328328+ );
329329+346330 // There are no non-mutating variants of things with flags & (C | T | I) == C | T
347347- if flags.contains(OpFlags::Cipher | OpFlags::Transport) && !flags.contains(OpFlags::Inbound)
348348- {
349349- unreachable!("operate_no_mutate called on something that requires mutation")
350350- } else if flags.contains(OpFlags::Cipher) {
331331+ if flags.contains(OpFlags::CIPHER).into() {
351332 // This is equivalent to a non-mutating form of the `duplex` operation in the Python
352333 // implementation, with `cbefore = True`
353334 self.overwrite(data);
···358339 }
359340 }
360341361361- fn recv_mac_inner(
362362- &mut self,
363363- mac_copy: &mut [u8],
364364- flags: BitFlags<OpFlags>,
365365- ) -> Result<(), GarbledError> {
342342+ fn recv_mac_inner(&mut self, mac_copy: &mut [u8], flags: OpFlags) -> Result<(), GarbledError> {
366343 // recv_mac can never be streamed
367367- self.operate(flags, mac_copy, false);
344344+ self.operate(flags, mac_copy, Choice::from(0u8));
368345369346 // Constant-time MAC check. This accumulates the truth values of byte == 0
370347 let all_zero: bool = mac_copy
371348 .iter()
372372- .fold(subtle::Choice::from(1u8), |all_zero, b| {
373373- all_zero & 0u8.ct_eq(b)
374374- })
349349+ .fold(Choice::from(1u8), |all_zero, b| all_zero & 0u8.ct_eq(b))
375350 .into();
376351377352 if all_zero { Ok(()) } else { Err(GarbledError) }
···380355 pub fn recv_mac<const N: usize>(&mut self, mac: &[u8; N]) -> Result<(), GarbledError> {
381356 let mut mac_copy = *mac;
382357383383- self.recv_mac_inner(
384384- &mut mac_copy,
385385- OpFlags::Inbound | OpFlags::Cipher | OpFlags::Transport,
386386- )
358358+ self.recv_mac_inner(&mut mac_copy, ops::RECV_MAC)
387359 }
388360389361 pub fn meta_recv_mac<const N: usize>(&mut self, mac: &[u8; N]) -> Result<(), GarbledError> {
390362 let mut mac_copy = *mac;
391363392392- self.recv_mac_inner(
393393- &mut mac_copy,
394394- OpFlags::Inbound | OpFlags::Cipher | OpFlags::Transport | OpFlags::Meta,
395395- )
364364+ self.recv_mac_inner(&mut mac_copy, ops::META_RECV_MAC)
396365 }
397366398398- fn ratchet_inner(&mut self, num_bytes_to_zero: usize, flags: BitFlags<OpFlags>) {
367367+ fn ratchet_inner(&mut self, num_bytes_to_zero: usize, flags: OpFlags) {
399368 let more = self.prev_flags.bits().ct_eq(&flags.bits());
400369401370 // We don't make an `operate` call, since this is a super special case. That means we have
···410379 }
411380412381 pub fn ratchet(&mut self, num_bytes_to_zero: usize) {
413413- let flags = BitFlags::from(OpFlags::Cipher);
414414-415415- self.ratchet_inner(num_bytes_to_zero, flags);
382382+ self.ratchet_inner(num_bytes_to_zero, ops::RATCHET);
416383 }
417384418385 pub fn meta_ratchet(&mut self, num_bytes_to_zero: usize) {
419419- let flags = OpFlags::Cipher | OpFlags::Meta;
420420-421421- self.ratchet_inner(num_bytes_to_zero, flags);
386386+ self.ratchet_inner(num_bytes_to_zero, ops::META_RATCHET);
422387 }
423388424389 define_mut_operations! {
425390 /// SEND ENC
426426- pub fn send_enc(OpFlags::App | OpFlags::Cipher | OpFlags::Transport);
391391+ pub fn send_enc(ops::SEND_ENC);
427392 /// META SEND ENC
428428- pub fn meta_send_enc(OpFlags::App | OpFlags::Cipher | OpFlags::Transport | OpFlags::Meta);
393393+ pub fn meta_send_enc(ops::META_SEND_ENC);
429394 /// RECV ENV
430430- pub fn recv_enc(OpFlags::Inbound | OpFlags::App | OpFlags::Cipher | OpFlags::Transport);
395395+ pub fn recv_enc(ops::RECV_ENC);
431396 /// META RECV ENC
432432- pub fn meta_recv_enc(OpFlags::Inbound | OpFlags::App | OpFlags::Cipher | OpFlags::Transport | OpFlags::Meta);
397397+ pub fn meta_recv_enc(ops::META_RECV_ENC);
433398 /// SEND MAC
434434- pub fn send_mac(OpFlags::Cipher | OpFlags::Transport);
399399+ pub fn send_mac(ops::SEND_MAC);
435400 /// META SEND MAC
436436- pub fn meta_send_mac(OpFlags::Cipher | OpFlags::Transport | OpFlags::Meta);
401401+ pub fn meta_send_mac(ops::META_SEND_MAC);
437402 /// PRF
438438- pub fn prf(OpFlags::Inbound | OpFlags::App | OpFlags::Cipher);
403403+ pub fn prf(ops::PRF);
439404 /// META PRF
440440- pub fn meta_prf(OpFlags::Inbound | OpFlags::App | OpFlags::Cipher | OpFlags::Meta);
405405+ pub fn meta_prf(ops::META_PRF);
441406 }
442407443408 define_non_mut_operations! {
444409 /// AD
445445- pub fn ad(BitFlags::from(OpFlags::App));
410410+ pub fn ad(ops::AD);
446411 /// META AD
447447- pub fn meta_ad(OpFlags::App | OpFlags::Meta);
412412+ pub fn meta_ad(ops::META_AD);
448413 /// KEY
449449- pub fn key(OpFlags::App | OpFlags::Cipher);
414414+ pub fn key(ops::KEY);
450415 /// META KEY
451451- pub fn meta_key(OpFlags::App | OpFlags::Cipher | OpFlags::Meta);
416416+ pub fn meta_key(ops::META_KEY);
452417 /// SEND CLR
453453- pub fn send_clr(OpFlags::App | OpFlags::Transport);
418418+ pub fn send_clr(ops::SEND_CLR);
454419 /// META SEND CLR
455455- pub fn meta_send_clr(OpFlags::App | OpFlags::Transport | OpFlags::Meta);
420420+ pub fn meta_send_clr(ops::META_SEND_CLR);
456421 /// RECV CLR
457457- pub fn recv_clr(OpFlags::Inbound | OpFlags::App | OpFlags::Transport);
422422+ pub fn recv_clr(ops::RECV_CLR);
458423 /// META RECV CLR
459459- pub fn meta_recv_clr(OpFlags::Inbound | OpFlags::App | OpFlags::Transport | OpFlags::Meta);
424424+ pub fn meta_recv_clr(ops::META_RECV_CLR);
460425 }
461426}
462427···468433469434 #[test]
470435 fn version_formatting() {
471471- let s = StrobeState::new(b"", SecurityParameter::B128);
436436+ let s = StrobeState::new(b"", SecurityParameter::B128, Role::Sender);
472437473438 let display = std::format!("{s}");
474439 let debug = std::format!("{s:?}");