dkls23/proto/tags.rs
// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
// This software is licensed under the Silence Laboratories License Agreement.

//! Message tags and filtered message relaying.
//!
//! This module wraps a message relay so that only explicitly requested
//! messages are delivered, filtered by message tag and sender, and grouped
//! into protocol rounds. It provides the structures that track expected
//! messages and drive round-based message reception.
9
10use std::{
11 collections::HashMap,
12 ops::{Deref, DerefMut},
13};
14
15use bytemuck::{AnyBitPattern, NoUninit};
16use zeroize::Zeroizing;
17
18use sl_mpc_mate::coord::*;
19
20use crate::{
21 pairs::Pairs,
22 proto::{
23 check_abort, EncryptedMessage, EncryptionScheme, MessageTag, MsgId,
24 Relay, SignedMessage, Wrap,
25 },
26 setup::{ProtocolParticipant, ABORT_MESSAGE_TAG},
27};
28
/// Errors that can occur during message relay operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Protocol was aborted by a participant; the payload identifies the
    /// aborting party (presumably its participant index — confirm against
    /// `check_abort`).
    Abort(usize),
    /// Error receiving a message (the relay stream ended or flushing the
    /// relay failed).
    Recv,
    /// Error sending a message.
    Send,
    /// Received message was invalid (e.g. its payload failed to decode).
    InvalidMessage,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Abort(party) => {
                write!(f, "protocol aborted by party {}", party)
            }
            Error::Recv => f.write_str("failed to receive message"),
            Error::Send => f.write_str("failed to send message"),
            Error::InvalidMessage => f.write_str("received invalid message"),
        }
    }
}

impl std::error::Error for Error {}
41
42/// A message relay that filters messages based on expected tags and party IDs.
43///
44/// This struct wraps an underlying relay and provides additional functionality
45/// for filtering messages based on expected tags and party IDs. It maintains
46/// a buffer of received messages and tracks expected messages.
47///
48/// # Type Parameters
49/// * `R` - The type of the underlying relay implementation
50pub struct FilteredMsgRelay<R> {
51 relay: R,
52 in_buf: Vec<(Vec<u8>, usize, MessageTag)>,
53 expected: HashMap<MsgId, (usize, MessageTag)>,
54}
55
56impl<R: Relay> FilteredMsgRelay<R> {
57 /// Creates a new `FilteredMsgRelay` by wrapping an existing relay.
58 ///
59 /// # Arguments
60 /// * `relay` - The underlying relay to wrap
61 ///
62 /// # Returns
63 /// A new `FilteredMsgRelay` instance
64 pub fn new(relay: R) -> Self {
65 Self {
66 relay,
67 expected: HashMap::new(),
68 in_buf: vec![],
69 }
70 }
71
72 /// Returns the underlying relay object.
73 ///
74 /// # Returns
75 /// The wrapped relay object
76 pub fn into_inner(self) -> R {
77 self.relay
78 }
79
80 /// Marks a message with the given ID as expected and associates it with a party ID and tag.
81 ///
82 /// # Arguments
83 /// * `id` - The message ID to expect
84 /// * `tag` - The expected message tag
85 /// * `party_id` - The ID of the party sending the message
86 /// * `ttl` - Time-to-live for the message request
87 ///
88 /// # Returns
89 /// `Ok(())` if successful, or an error if the message request fails
90 pub async fn expect_message(
91 &mut self,
92 id: MsgId,
93 tag: MessageTag,
94 party_id: usize,
95 ttl: u32,
96 ) -> Result<(), MessageSendError> {
97 self.relay.ask(&id, ttl).await?;
98 self.expected.insert(id, (party_id, tag));
99
100 Ok(())
101 }
102
103 /// Returns a message back to the expected messages queue.
104 ///
105 /// # Arguments
106 /// * `msg` - The message to put back
107 /// * `tag` - The message tag
108 /// * `party_id` - The ID of the party that sent the message
109 fn put_back(&mut self, msg: &[u8], tag: MessageTag, party_id: usize) {
110 self.expected
111 .insert(msg.try_into().unwrap(), (party_id, tag));
112 }
113
114 /// Receives an expected message with the given tag and returns the associated party ID.
115 ///
116 /// # Arguments
117 /// * `tag` - The expected message tag
118 ///
119 /// # Returns
120 /// A tuple containing:
121 /// - The received message
122 /// - The party ID of the sender
123 /// - A boolean indicating if this is an abort message
124 pub async fn recv(
125 &mut self,
126 tag: MessageTag,
127 ) -> Result<(Vec<u8>, usize, bool), Error> {
128 // flush output message messages.
129 self.relay.flush().await.map_err(|_| Error::Recv)?;
130
131 if let Some(idx) = self.in_buf.iter().position(|ent| ent.2 == tag) {
132 let (msg, p, _) = self.in_buf.swap_remove(idx);
133 return Ok((msg, p, false));
134 }
135
136 loop {
137 let msg = self.relay.next().await.ok_or(Error::Recv)?;
138
139 if let Ok(id) = <&MsgId>::try_from(msg.as_slice()) {
140 if let Some(&(p, t)) = self.expected.get(id) {
141 self.expected.remove(id);
142 match t {
143 ABORT_MESSAGE_TAG => {
144 return Ok((msg, p, true));
145 }
146
147 _ if t == tag => {
148 return Ok((msg, p, false));
149 }
150
151 _ => {
152 // some expected but not required right
153 // now message.
154 self.in_buf.push((msg, p, t));
155 }
156 }
157 }
158 }
159 }
160 }
161
162 /// Adds expected messages and asks the underlying relay to receive them.
163 ///
164 /// # Arguments
165 /// * `setup` - The protocol participant setup
166 /// * `tag` - The expected message tag
167 /// * `p2p` - Whether this is a peer-to-peer message
168 ///
169 /// # Returns
170 /// The number of messages with the same tag
171 pub async fn ask_messages<P: ProtocolParticipant>(
172 &mut self,
173 setup: &P,
174 tag: MessageTag,
175 p2p: bool,
176 ) -> Result<usize, MessageSendError> {
177 self.ask_messages_from_iter(
178 setup,
179 tag,
180 setup.all_other_parties(),
181 p2p,
182 )
183 .await
184 }
185
186 /// Asks for messages with a given tag from a set of parties.
187 ///
188 /// Filters out the current party's index from the list of parties.
189 ///
190 /// # Arguments
191 /// * `setup` - The protocol participant setup
192 /// * `tag` - The expected message tag
193 /// * `from_parties` - Iterator over party indices to receive from
194 /// * `p2p` - Whether this is a peer-to-peer message
195 ///
196 /// # Returns
197 /// The number of messages with the same tag
198 pub async fn ask_messages_from_iter<P, I>(
199 &mut self,
200 setup: &P,
201 tag: MessageTag,
202 from_parties: I,
203 p2p: bool,
204 ) -> Result<usize, MessageSendError>
205 where
206 P: ProtocolParticipant,
207 I: IntoIterator<Item = usize>,
208 {
209 let my_party_index = setup.participant_index();
210 let receiver = p2p.then_some(my_party_index);
211 let mut count = 0;
212 for sender_index in from_parties.into_iter() {
213 if sender_index == my_party_index {
214 continue;
215 }
216
217 count += 1;
218 self.expect_message(
219 setup.msg_id_from(sender_index, receiver, tag),
220 tag,
221 sender_index,
222 setup.message_ttl().as_secs() as _,
223 )
224 .await?;
225 }
226
227 Ok(count)
228 }
229
230 /// Similar to `ask_messages_from_iter` but accepts a slice of indices.
231 ///
232 /// # Arguments
233 /// * `setup` - The protocol participant setup
234 /// * `tag` - The expected message tag
235 /// * `from_parties` - Slice of party indices to receive from
236 /// * `p2p` - Whether this is a peer-to-peer message
237 ///
238 /// # Returns
239 /// The number of messages with the same tag
240 pub async fn ask_messages_from_slice<'a, P, I>(
241 &mut self,
242 setup: &P,
243 tag: MessageTag,
244 from_parties: I,
245 p2p: bool,
246 ) -> Result<usize, MessageSendError>
247 where
248 P: ProtocolParticipant,
249 I: IntoIterator<Item = &'a usize>,
250 {
251 self.ask_messages_from_iter(
252 setup,
253 tag,
254 from_parties.into_iter().copied(),
255 p2p,
256 )
257 .await
258 }
259
260 /// Creates a new round for receiving messages.
261 ///
262 /// # Arguments
263 /// * `count` - Number of messages to receive in this round
264 /// * `tag` - The expected message tag
265 ///
266 /// # Returns
267 /// A new `Round` instance
268 pub fn round(&mut self, count: usize, tag: MessageTag) -> Round<'_, R> {
269 Round::new(count, tag, self)
270 }
271}
272
273impl<R> Deref for FilteredMsgRelay<R> {
274 type Target = R;
275
276 fn deref(&self) -> &Self::Target {
277 &self.relay
278 }
279}
280
281impl<R> DerefMut for FilteredMsgRelay<R> {
282 fn deref_mut(&mut self) -> &mut Self::Target {
283 &mut self.relay
284 }
285}
286
287/// A structure for receiving a round of messages.
288///
289/// This struct manages the reception of a fixed number of messages with a specific tag
290/// in a single round of communication.
291///
292/// # Type Parameters
293/// * `'a` - The lifetime of the parent `FilteredMsgRelay`
294/// * `R` - The type of the underlying relay
295pub struct Round<'a, R> {
296 tag: MessageTag,
297 count: usize,
298 pub(crate) relay: &'a mut FilteredMsgRelay<R>,
299}
300
301impl<'a, R: Relay> Round<'a, R> {
302 /// Creates a new round with a given number of messages to receive.
303 ///
304 /// # Arguments
305 /// * `count` - Number of messages to receive in this round
306 /// * `tag` - The expected message tag
307 /// * `relay` - The parent message relay
308 ///
309 /// # Returns
310 /// A new `Round` instance
311 pub fn new(
312 count: usize,
313 tag: MessageTag,
314 relay: &'a mut FilteredMsgRelay<R>,
315 ) -> Self {
316 Self { count, tag, relay }
317 }
318
319 /// Receives the next message in the round.
320 ///
321 /// # Returns
322 /// - `Ok(Some(message, party_index, is_abort_flag))` on successful reception
323 /// - `Ok(None)` when the round is complete
324 /// - `Err(Error)` if an error occurs
325 pub async fn recv(
326 &mut self,
327 ) -> Result<Option<(Vec<u8>, usize, bool)>, Error> {
328 Ok(if self.count > 0 {
329 let msg = self.relay.recv(self.tag).await;
330 #[cfg(feature = "tracing")]
331 if msg.is_err() {
332 for (id, (p, t)) in &self.relay.expected {
333 if t == &self.tag {
334 tracing::debug!("waiting for {:X} {} {:?}", id, p, t);
335 }
336 }
337 }
338 let msg = msg?;
339 self.count -= 1;
340 Some(msg)
341 } else {
342 None
343 })
344 }
345
346 /// Returns a message back to the expected messages queue.
347 ///
348 /// This is used when a message is received but found to be invalid.
349 ///
350 /// # Arguments
351 /// * `msg` - The message to put back
352 /// * `tag` - The message tag
353 /// * `party_id` - The ID of the party that sent the message
354 pub fn put_back(&mut self, msg: &[u8], tag: MessageTag, party_id: usize) {
355 self.relay.put_back(msg, tag, party_id);
356 self.count += 1;
357
358 // TODO Should we ASK it again?
359 }
360
361 /// Receives all messages in the round, verifies them, decodes them, and passes them to a handler.
362 ///
363 /// # Type Parameters
364 /// * `T` - The type of the message payload
365 /// * `F` - The handler function type
366 /// * `S` - The protocol participant type
367 /// * `E` - The error type
368 ///
369 /// # Arguments
370 /// * `setup` - The protocol participant setup
371 /// * `abort_err` - Function to create an error from an abort message
372 /// * `handler` - Function to handle each received message
373 ///
374 /// # Returns
375 /// `Ok(())` if all messages are successfully processed, or an error if any message fails
376 pub async fn of_signed_messages<T, F, S, E>(
377 mut self,
378 setup: &S,
379 abort_err: impl Fn(usize) -> E,
380 mut handler: F,
381 ) -> Result<(), E>
382 where
383 T: AnyBitPattern + NoUninit,
384 S: ProtocolParticipant,
385 F: FnMut(&T, usize) -> Result<(), E>,
386 E: From<Error>,
387 {
388 while let Some((msg, party_idx, is_abort)) = self.recv().await? {
389 if is_abort {
390 check_abort(setup, &msg, party_idx, &abort_err)?;
391 self.put_back(&msg, ABORT_MESSAGE_TAG, party_idx);
392 continue;
393 }
394
395 let vk = setup.verifier(party_idx);
396 let msg: &T = match SignedMessage::verify(&msg, vk) {
397 Some(refs) => refs,
398 _ => {
399 self.put_back(&msg, self.tag, party_idx);
400 continue;
401 }
402 };
403
404 handler(msg, party_idx)?;
405 }
406
407 Ok(())
408 }
409
410 /// Receives all encrypted messages in the round, decrypts them, and passes them to a handler.
411 ///
412 /// # Type Parameters
413 /// * `T` - The type of the message payload
414 /// * `F` - The handler function type
415 /// * `P` - The protocol participant type
416 /// * `E` - The error type
417 ///
418 /// # Arguments
419 /// * `setup` - The protocol participant setup
420 /// * `scheme` - The encryption scheme to use
421 /// * `trailer` - Size of the trailer data
422 /// * `err` - Function to create an error from an abort message
423 /// * `handler` - Function to handle each received message
424 ///
425 /// # Returns
426 /// `Ok(())` if all messages are successfully processed, or an error if any message fails
427 pub async fn of_encrypted_messages<T, F, P, E>(
428 mut self,
429 setup: &P,
430 scheme: &mut dyn EncryptionScheme,
431 trailer: usize,
432 err: impl Fn(usize) -> E,
433 mut handler: F,
434 ) -> Result<(), E>
435 where
436 T: AnyBitPattern + NoUninit,
437 P: ProtocolParticipant,
438 F: FnMut(
439 &T,
440 usize,
441 &[u8],
442 &mut dyn EncryptionScheme,
443 ) -> Result<Option<Vec<u8>>, E>,
444 E: From<Error>,
445 {
446 while let Some((msg, party_index, is_abort)) = self.recv().await? {
447 if is_abort {
448 check_abort(setup, &msg, party_index, &err)?;
449 self.put_back(&msg, ABORT_MESSAGE_TAG, party_index);
450 continue;
451 }
452
453 let mut msg = Zeroizing::new(msg);
454
455 let (msg, trailer) = match EncryptedMessage::<T>::decrypt(
456 &mut msg,
457 trailer,
458 scheme,
459 party_index,
460 ) {
461 Some(refs) => refs,
462 _ => {
463 self.put_back(&msg, self.tag, party_index);
464 continue;
465 }
466 };
467
468 if let Some(replay) = handler(msg, party_index, trailer, scheme)?
469 {
470 self.relay.send(replay).await.map_err(|_| Error::Send)?;
471 }
472 }
473
474 Ok(())
475 }
476
477 /// Broadcasts four different types of messages to all participants.
478 ///
479 /// # Type Parameters
480 /// * `P` - The protocol participant type
481 /// * `T1` - The type of the first message
482 /// * `T2` - The type of the second message
483 /// * `T3` - The type of the third message
484 /// * `T4` - The type of the fourth message
485 ///
486 /// # Arguments
487 /// * `setup` - The protocol participant setup
488 /// * `msg` - Tuple of four messages to broadcast
489 ///
490 /// # Returns
491 /// A tuple of four `Pairs` containing the broadcast messages and their senders
492 pub async fn broadcast_4<P, T1, T2, T3, T4>(
493 self,
494 setup: &P,
495 msg: (T1, T2, T3, T4),
496 ) -> Result<
497 (
498 Pairs<T1, usize>,
499 Pairs<T2, usize>,
500 Pairs<T3, usize>,
501 Pairs<T4, usize>,
502 ),
503 Error,
504 >
505 where
506 P: ProtocolParticipant,
507 T1: Wrap,
508 T2: Wrap,
509 T3: Wrap,
510 T4: Wrap,
511 {
512 #[cfg(feature = "tracing")]
513 tracing::debug!("enter broadcast {:?}", self.tag);
514
515 let my_party_id = setup.participant_index();
516
517 let sizes = [
518 msg.0.external_size(),
519 msg.1.external_size(),
520 msg.2.external_size(),
521 msg.3.external_size(),
522 ];
523 let trailer: usize = sizes.iter().sum();
524
525 let buffer = {
526 // Do not hold SignedMessage across an await point to avoid
527 // forcing ProtocolParticipant::MessageSignature to be Send
528 // in case if the future returned by run() have to be Send.
529 let mut buffer = SignedMessage::<(), _>::new(
530 &setup.msg_id(None, self.tag),
531 setup.message_ttl().as_secs() as _,
532 0,
533 trailer,
534 );
535
536 let (_, mut out) = buffer.payload();
537
538 out = msg.0.encode(out);
539 out = msg.1.encode(out);
540 out = msg.2.encode(out);
541 msg.3.encode(out);
542
543 buffer.sign(setup.signer())
544 };
545
546 self.relay.send(buffer).await.map_err(|_| Error::Send)?;
547
548 let (mut p0, mut p1, mut p2, mut p3) =
549 self.recv_broadcast_4(setup, &sizes).await?;
550
551 p0.push(my_party_id, msg.0);
552 p1.push(my_party_id, msg.1);
553 p2.push(my_party_id, msg.2);
554 p3.push(my_party_id, msg.3);
555
556 Ok((p0, p1, p2, p3))
557 }
558
559 /// Receives four different types of broadcast messages from all participants.
560 ///
561 /// # Type Parameters
562 /// * `P` - The protocol participant type
563 /// * `T1` - The type of the first message
564 /// * `T2` - The type of the second message
565 /// * `T3` - The type of the third message
566 /// * `T4` - The type of the fourth message
567 ///
568 /// # Arguments
569 /// * `setup` - The protocol participant setup
570 /// * `sizes` - Array of sizes for each message type
571 ///
572 /// # Returns
573 /// A tuple of four `Pairs` containing the received messages and their senders
574 pub async fn recv_broadcast_4<P, T1, T2, T3, T4>(
575 mut self,
576 setup: &P,
577 sizes: &[usize; 4],
578 ) -> Result<
579 (
580 Pairs<T1, usize>,
581 Pairs<T2, usize>,
582 Pairs<T3, usize>,
583 Pairs<T4, usize>,
584 ),
585 Error,
586 >
587 where
588 P: ProtocolParticipant,
589 T1: Wrap,
590 T2: Wrap,
591 T3: Wrap,
592 T4: Wrap,
593 {
594 let trailer: usize = sizes.iter().sum();
595
596 let mut p0 = Pairs::new();
597 let mut p1 = Pairs::new();
598 let mut p2 = Pairs::new();
599 let mut p3 = Pairs::new();
600
601 while let Some((msg, party_id, is_abort)) = self.recv().await? {
602 if is_abort {
603 check_abort(setup, &msg, party_id, Error::Abort)?;
604 self.put_back(&msg, ABORT_MESSAGE_TAG, party_id);
605 continue;
606 }
607
608 let buf = match SignedMessage::<(), _>::verify_with_trailer(
609 &msg,
610 trailer,
611 setup.verifier(party_id),
612 ) {
613 Some((_, msg)) => msg,
614 None => {
615 // We got message with a right ID but with broken signature.
616 self.put_back(&msg, self.tag, party_id);
617 continue;
618 }
619 };
620
621 let (buf, v1) =
622 T1::decode(buf, sizes[0]).ok_or(Error::InvalidMessage)?;
623 let (buf, v2) =
624 T2::decode(buf, sizes[1]).ok_or(Error::InvalidMessage)?;
625 let (buf, v3) =
626 T3::decode(buf, sizes[2]).ok_or(Error::InvalidMessage)?;
627 let (_bu, v4) =
628 T4::decode(buf, sizes[3]).ok_or(Error::InvalidMessage)?;
629
630 p0.push(party_id, v1);
631 p1.push(party_id, v2);
632 p2.push(party_id, v3);
633 p3.push(party_id, v4);
634 }
635
636 Ok((p0, p1, p2, p3))
637 }
638}