// dkls23/proto/tags.rs

// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
// This software is licensed under the Silence Laboratories License Agreement.

//! Module for handling message tags and message relay functionality.
//!
//! This module provides [`FilteredMsgRelay`], which wraps a relay and only
//! surfaces messages whose ID and tag were registered as expected, and
//! [`Round`], which receives a fixed number of messages for one protocol round.

use std::{
    collections::HashMap,
    fmt,
    ops::{Deref, DerefMut},
};

use bytemuck::{AnyBitPattern, NoUninit};
use zeroize::Zeroizing;

use sl_mpc_mate::coord::*;

use crate::{
    pairs::Pairs,
    proto::{
        check_abort, EncryptedMessage, EncryptionScheme, MessageTag, MsgId,
        Relay, SignedMessage, Wrap,
    },
    setup::{ProtocolParticipant, ABORT_MESSAGE_TAG},
};

/// Errors that can occur during message relay operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Protocol was aborted by a participant.
    ///
    /// The value is the party index handed to the abort constructor by
    /// `check_abort` (presumably the index of the aborting party).
    Abort(usize),
    /// Error receiving a message
    Recv,
    /// Error sending a message
    Send,
    /// Received message was invalid
    InvalidMessage,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::Abort(party) => {
                write!(f, "protocol aborted by party {party}")
            }
            Error::Recv => f.write_str("failed to receive a message"),
            Error::Send => f.write_str("failed to send a message"),
            Error::InvalidMessage => f.write_str("received an invalid message"),
        }
    }
}

// Lets the relay error be boxed or propagated as a generic
// `std::error::Error` (e.g. `Box<dyn Error>`).
impl std::error::Error for Error {}

42/// A message relay that filters messages based on expected tags and party IDs.
43///
44/// This struct wraps an underlying relay and provides additional functionality
45/// for filtering messages based on expected tags and party IDs. It maintains
46/// a buffer of received messages and tracks expected messages.
47///
48/// # Type Parameters
49/// * `R` - The type of the underlying relay implementation
50pub struct FilteredMsgRelay<R> {
51    relay: R,
52    in_buf: Vec<(Vec<u8>, usize, MessageTag)>,
53    expected: HashMap<MsgId, (usize, MessageTag)>,
54}
55
56impl<R: Relay> FilteredMsgRelay<R> {
57    /// Creates a new `FilteredMsgRelay` by wrapping an existing relay.
58    ///
59    /// # Arguments
60    /// * `relay` - The underlying relay to wrap
61    ///
62    /// # Returns
63    /// A new `FilteredMsgRelay` instance
64    pub fn new(relay: R) -> Self {
65        Self {
66            relay,
67            expected: HashMap::new(),
68            in_buf: vec![],
69        }
70    }
71
72    /// Returns the underlying relay object.
73    ///
74    /// # Returns
75    /// The wrapped relay object
76    pub fn into_inner(self) -> R {
77        self.relay
78    }
79
80    /// Marks a message with the given ID as expected and associates it with a party ID and tag.
81    ///
82    /// # Arguments
83    /// * `id` - The message ID to expect
84    /// * `tag` - The expected message tag
85    /// * `party_id` - The ID of the party sending the message
86    /// * `ttl` - Time-to-live for the message request
87    ///
88    /// # Returns
89    /// `Ok(())` if successful, or an error if the message request fails
90    pub async fn expect_message(
91        &mut self,
92        id: MsgId,
93        tag: MessageTag,
94        party_id: usize,
95        ttl: u32,
96    ) -> Result<(), MessageSendError> {
97        self.relay.ask(&id, ttl).await?;
98        self.expected.insert(id, (party_id, tag));
99
100        Ok(())
101    }
102
103    /// Returns a message back to the expected messages queue.
104    ///
105    /// # Arguments
106    /// * `msg` - The message to put back
107    /// * `tag` - The message tag
108    /// * `party_id` - The ID of the party that sent the message
109    fn put_back(&mut self, msg: &[u8], tag: MessageTag, party_id: usize) {
110        self.expected
111            .insert(msg.try_into().unwrap(), (party_id, tag));
112    }
113
114    /// Receives an expected message with the given tag and returns the associated party ID.
115    ///
116    /// # Arguments
117    /// * `tag` - The expected message tag
118    ///
119    /// # Returns
120    /// A tuple containing:
121    /// - The received message
122    /// - The party ID of the sender
123    /// - A boolean indicating if this is an abort message
124    pub async fn recv(
125        &mut self,
126        tag: MessageTag,
127    ) -> Result<(Vec<u8>, usize, bool), Error> {
128        // flush output message messages.
129        self.relay.flush().await.map_err(|_| Error::Recv)?;
130
131        if let Some(idx) = self.in_buf.iter().position(|ent| ent.2 == tag) {
132            let (msg, p, _) = self.in_buf.swap_remove(idx);
133            return Ok((msg, p, false));
134        }
135
136        loop {
137            let msg = self.relay.next().await.ok_or(Error::Recv)?;
138
139            if let Ok(id) = <&MsgId>::try_from(msg.as_slice()) {
140                if let Some(&(p, t)) = self.expected.get(id) {
141                    self.expected.remove(id);
142                    match t {
143                        ABORT_MESSAGE_TAG => {
144                            return Ok((msg, p, true));
145                        }
146
147                        _ if t == tag => {
148                            return Ok((msg, p, false));
149                        }
150
151                        _ => {
152                            // some expected but not required right
153                            // now message.
154                            self.in_buf.push((msg, p, t));
155                        }
156                    }
157                }
158            }
159        }
160    }
161
162    /// Adds expected messages and asks the underlying relay to receive them.
163    ///
164    /// # Arguments
165    /// * `setup` - The protocol participant setup
166    /// * `tag` - The expected message tag
167    /// * `p2p` - Whether this is a peer-to-peer message
168    ///
169    /// # Returns
170    /// The number of messages with the same tag
171    pub async fn ask_messages<P: ProtocolParticipant>(
172        &mut self,
173        setup: &P,
174        tag: MessageTag,
175        p2p: bool,
176    ) -> Result<usize, MessageSendError> {
177        self.ask_messages_from_iter(
178            setup,
179            tag,
180            setup.all_other_parties(),
181            p2p,
182        )
183        .await
184    }
185
186    /// Asks for messages with a given tag from a set of parties.
187    ///
188    /// Filters out the current party's index from the list of parties.
189    ///
190    /// # Arguments
191    /// * `setup` - The protocol participant setup
192    /// * `tag` - The expected message tag
193    /// * `from_parties` - Iterator over party indices to receive from
194    /// * `p2p` - Whether this is a peer-to-peer message
195    ///
196    /// # Returns
197    /// The number of messages with the same tag
198    pub async fn ask_messages_from_iter<P, I>(
199        &mut self,
200        setup: &P,
201        tag: MessageTag,
202        from_parties: I,
203        p2p: bool,
204    ) -> Result<usize, MessageSendError>
205    where
206        P: ProtocolParticipant,
207        I: IntoIterator<Item = usize>,
208    {
209        let my_party_index = setup.participant_index();
210        let receiver = p2p.then_some(my_party_index);
211        let mut count = 0;
212        for sender_index in from_parties.into_iter() {
213            if sender_index == my_party_index {
214                continue;
215            }
216
217            count += 1;
218            self.expect_message(
219                setup.msg_id_from(sender_index, receiver, tag),
220                tag,
221                sender_index,
222                setup.message_ttl().as_secs() as _,
223            )
224            .await?;
225        }
226
227        Ok(count)
228    }
229
230    /// Similar to `ask_messages_from_iter` but accepts a slice of indices.
231    ///
232    /// # Arguments
233    /// * `setup` - The protocol participant setup
234    /// * `tag` - The expected message tag
235    /// * `from_parties` - Slice of party indices to receive from
236    /// * `p2p` - Whether this is a peer-to-peer message
237    ///
238    /// # Returns
239    /// The number of messages with the same tag
240    pub async fn ask_messages_from_slice<'a, P, I>(
241        &mut self,
242        setup: &P,
243        tag: MessageTag,
244        from_parties: I,
245        p2p: bool,
246    ) -> Result<usize, MessageSendError>
247    where
248        P: ProtocolParticipant,
249        I: IntoIterator<Item = &'a usize>,
250    {
251        self.ask_messages_from_iter(
252            setup,
253            tag,
254            from_parties.into_iter().copied(),
255            p2p,
256        )
257        .await
258    }
259
260    /// Creates a new round for receiving messages.
261    ///
262    /// # Arguments
263    /// * `count` - Number of messages to receive in this round
264    /// * `tag` - The expected message tag
265    ///
266    /// # Returns
267    /// A new `Round` instance
268    pub fn round(&mut self, count: usize, tag: MessageTag) -> Round<'_, R> {
269        Round::new(count, tag, self)
270    }
271}
272
273impl<R> Deref for FilteredMsgRelay<R> {
274    type Target = R;
275
276    fn deref(&self) -> &Self::Target {
277        &self.relay
278    }
279}
280
281impl<R> DerefMut for FilteredMsgRelay<R> {
282    fn deref_mut(&mut self) -> &mut Self::Target {
283        &mut self.relay
284    }
285}
286
287/// A structure for receiving a round of messages.
288///
289/// This struct manages the reception of a fixed number of messages with a specific tag
290/// in a single round of communication.
291///
292/// # Type Parameters
293/// * `'a` - The lifetime of the parent `FilteredMsgRelay`
294/// * `R` - The type of the underlying relay
295pub struct Round<'a, R> {
296    tag: MessageTag,
297    count: usize,
298    pub(crate) relay: &'a mut FilteredMsgRelay<R>,
299}
300
301impl<'a, R: Relay> Round<'a, R> {
302    /// Creates a new round with a given number of messages to receive.
303    ///
304    /// # Arguments
305    /// * `count` - Number of messages to receive in this round
306    /// * `tag` - The expected message tag
307    /// * `relay` - The parent message relay
308    ///
309    /// # Returns
310    /// A new `Round` instance
311    pub fn new(
312        count: usize,
313        tag: MessageTag,
314        relay: &'a mut FilteredMsgRelay<R>,
315    ) -> Self {
316        Self { count, tag, relay }
317    }
318
319    /// Receives the next message in the round.
320    ///
321    /// # Returns
322    /// - `Ok(Some(message, party_index, is_abort_flag))` on successful reception
323    /// - `Ok(None)` when the round is complete
324    /// - `Err(Error)` if an error occurs
325    pub async fn recv(
326        &mut self,
327    ) -> Result<Option<(Vec<u8>, usize, bool)>, Error> {
328        Ok(if self.count > 0 {
329            let msg = self.relay.recv(self.tag).await;
330            #[cfg(feature = "tracing")]
331            if msg.is_err() {
332                for (id, (p, t)) in &self.relay.expected {
333                    if t == &self.tag {
334                        tracing::debug!("waiting for {:X} {} {:?}", id, p, t);
335                    }
336                }
337            }
338            let msg = msg?;
339            self.count -= 1;
340            Some(msg)
341        } else {
342            None
343        })
344    }
345
346    /// Returns a message back to the expected messages queue.
347    ///
348    /// This is used when a message is received but found to be invalid.
349    ///
350    /// # Arguments
351    /// * `msg` - The message to put back
352    /// * `tag` - The message tag
353    /// * `party_id` - The ID of the party that sent the message
354    pub fn put_back(&mut self, msg: &[u8], tag: MessageTag, party_id: usize) {
355        self.relay.put_back(msg, tag, party_id);
356        self.count += 1;
357
358        // TODO Should we ASK it again?
359    }
360
361    /// Receives all messages in the round, verifies them, decodes them, and passes them to a handler.
362    ///
363    /// # Type Parameters
364    /// * `T` - The type of the message payload
365    /// * `F` - The handler function type
366    /// * `S` - The protocol participant type
367    /// * `E` - The error type
368    ///
369    /// # Arguments
370    /// * `setup` - The protocol participant setup
371    /// * `abort_err` - Function to create an error from an abort message
372    /// * `handler` - Function to handle each received message
373    ///
374    /// # Returns
375    /// `Ok(())` if all messages are successfully processed, or an error if any message fails
376    pub async fn of_signed_messages<T, F, S, E>(
377        mut self,
378        setup: &S,
379        abort_err: impl Fn(usize) -> E,
380        mut handler: F,
381    ) -> Result<(), E>
382    where
383        T: AnyBitPattern + NoUninit,
384        S: ProtocolParticipant,
385        F: FnMut(&T, usize) -> Result<(), E>,
386        E: From<Error>,
387    {
388        while let Some((msg, party_idx, is_abort)) = self.recv().await? {
389            if is_abort {
390                check_abort(setup, &msg, party_idx, &abort_err)?;
391                self.put_back(&msg, ABORT_MESSAGE_TAG, party_idx);
392                continue;
393            }
394
395            let vk = setup.verifier(party_idx);
396            let msg: &T = match SignedMessage::verify(&msg, vk) {
397                Some(refs) => refs,
398                _ => {
399                    self.put_back(&msg, self.tag, party_idx);
400                    continue;
401                }
402            };
403
404            handler(msg, party_idx)?;
405        }
406
407        Ok(())
408    }
409
410    /// Receives all encrypted messages in the round, decrypts them, and passes them to a handler.
411    ///
412    /// # Type Parameters
413    /// * `T` - The type of the message payload
414    /// * `F` - The handler function type
415    /// * `P` - The protocol participant type
416    /// * `E` - The error type
417    ///
418    /// # Arguments
419    /// * `setup` - The protocol participant setup
420    /// * `scheme` - The encryption scheme to use
421    /// * `trailer` - Size of the trailer data
422    /// * `err` - Function to create an error from an abort message
423    /// * `handler` - Function to handle each received message
424    ///
425    /// # Returns
426    /// `Ok(())` if all messages are successfully processed, or an error if any message fails
427    pub async fn of_encrypted_messages<T, F, P, E>(
428        mut self,
429        setup: &P,
430        scheme: &mut dyn EncryptionScheme,
431        trailer: usize,
432        err: impl Fn(usize) -> E,
433        mut handler: F,
434    ) -> Result<(), E>
435    where
436        T: AnyBitPattern + NoUninit,
437        P: ProtocolParticipant,
438        F: FnMut(
439            &T,
440            usize,
441            &[u8],
442            &mut dyn EncryptionScheme,
443        ) -> Result<Option<Vec<u8>>, E>,
444        E: From<Error>,
445    {
446        while let Some((msg, party_index, is_abort)) = self.recv().await? {
447            if is_abort {
448                check_abort(setup, &msg, party_index, &err)?;
449                self.put_back(&msg, ABORT_MESSAGE_TAG, party_index);
450                continue;
451            }
452
453            let mut msg = Zeroizing::new(msg);
454
455            let (msg, trailer) = match EncryptedMessage::<T>::decrypt(
456                &mut msg,
457                trailer,
458                scheme,
459                party_index,
460            ) {
461                Some(refs) => refs,
462                _ => {
463                    self.put_back(&msg, self.tag, party_index);
464                    continue;
465                }
466            };
467
468            if let Some(replay) = handler(msg, party_index, trailer, scheme)?
469            {
470                self.relay.send(replay).await.map_err(|_| Error::Send)?;
471            }
472        }
473
474        Ok(())
475    }
476
477    /// Broadcasts four different types of messages to all participants.
478    ///
479    /// # Type Parameters
480    /// * `P` - The protocol participant type
481    /// * `T1` - The type of the first message
482    /// * `T2` - The type of the second message
483    /// * `T3` - The type of the third message
484    /// * `T4` - The type of the fourth message
485    ///
486    /// # Arguments
487    /// * `setup` - The protocol participant setup
488    /// * `msg` - Tuple of four messages to broadcast
489    ///
490    /// # Returns
491    /// A tuple of four `Pairs` containing the broadcast messages and their senders
492    pub async fn broadcast_4<P, T1, T2, T3, T4>(
493        self,
494        setup: &P,
495        msg: (T1, T2, T3, T4),
496    ) -> Result<
497        (
498            Pairs<T1, usize>,
499            Pairs<T2, usize>,
500            Pairs<T3, usize>,
501            Pairs<T4, usize>,
502        ),
503        Error,
504    >
505    where
506        P: ProtocolParticipant,
507        T1: Wrap,
508        T2: Wrap,
509        T3: Wrap,
510        T4: Wrap,
511    {
512        #[cfg(feature = "tracing")]
513        tracing::debug!("enter broadcast {:?}", self.tag);
514
515        let my_party_id = setup.participant_index();
516
517        let sizes = [
518            msg.0.external_size(),
519            msg.1.external_size(),
520            msg.2.external_size(),
521            msg.3.external_size(),
522        ];
523        let trailer: usize = sizes.iter().sum();
524
525        let buffer = {
526            // Do not hold SignedMessage across an await point to avoid
527            // forcing ProtocolParticipant::MessageSignature to be Send
528            // in case if the future returned by run() have to be Send.
529            let mut buffer = SignedMessage::<(), _>::new(
530                &setup.msg_id(None, self.tag),
531                setup.message_ttl().as_secs() as _,
532                0,
533                trailer,
534            );
535
536            let (_, mut out) = buffer.payload();
537
538            out = msg.0.encode(out);
539            out = msg.1.encode(out);
540            out = msg.2.encode(out);
541            msg.3.encode(out);
542
543            buffer.sign(setup.signer())
544        };
545
546        self.relay.send(buffer).await.map_err(|_| Error::Send)?;
547
548        let (mut p0, mut p1, mut p2, mut p3) =
549            self.recv_broadcast_4(setup, &sizes).await?;
550
551        p0.push(my_party_id, msg.0);
552        p1.push(my_party_id, msg.1);
553        p2.push(my_party_id, msg.2);
554        p3.push(my_party_id, msg.3);
555
556        Ok((p0, p1, p2, p3))
557    }
558
559    /// Receives four different types of broadcast messages from all participants.
560    ///
561    /// # Type Parameters
562    /// * `P` - The protocol participant type
563    /// * `T1` - The type of the first message
564    /// * `T2` - The type of the second message
565    /// * `T3` - The type of the third message
566    /// * `T4` - The type of the fourth message
567    ///
568    /// # Arguments
569    /// * `setup` - The protocol participant setup
570    /// * `sizes` - Array of sizes for each message type
571    ///
572    /// # Returns
573    /// A tuple of four `Pairs` containing the received messages and their senders
574    pub async fn recv_broadcast_4<P, T1, T2, T3, T4>(
575        mut self,
576        setup: &P,
577        sizes: &[usize; 4],
578    ) -> Result<
579        (
580            Pairs<T1, usize>,
581            Pairs<T2, usize>,
582            Pairs<T3, usize>,
583            Pairs<T4, usize>,
584        ),
585        Error,
586    >
587    where
588        P: ProtocolParticipant,
589        T1: Wrap,
590        T2: Wrap,
591        T3: Wrap,
592        T4: Wrap,
593    {
594        let trailer: usize = sizes.iter().sum();
595
596        let mut p0 = Pairs::new();
597        let mut p1 = Pairs::new();
598        let mut p2 = Pairs::new();
599        let mut p3 = Pairs::new();
600
601        while let Some((msg, party_id, is_abort)) = self.recv().await? {
602            if is_abort {
603                check_abort(setup, &msg, party_id, Error::Abort)?;
604                self.put_back(&msg, ABORT_MESSAGE_TAG, party_id);
605                continue;
606            }
607
608            let buf = match SignedMessage::<(), _>::verify_with_trailer(
609                &msg,
610                trailer,
611                setup.verifier(party_id),
612            ) {
613                Some((_, msg)) => msg,
614                None => {
615                    // We got message with a right ID but with broken signature.
616                    self.put_back(&msg, self.tag, party_id);
617                    continue;
618                }
619            };
620
621            let (buf, v1) =
622                T1::decode(buf, sizes[0]).ok_or(Error::InvalidMessage)?;
623            let (buf, v2) =
624                T2::decode(buf, sizes[1]).ok_or(Error::InvalidMessage)?;
625            let (buf, v3) =
626                T3::decode(buf, sizes[2]).ok_or(Error::InvalidMessage)?;
627            let (_bu, v4) =
628                T4::decode(buf, sizes[3]).ok_or(Error::InvalidMessage)?;
629
630            p0.push(party_id, v1);
631            p1.push(party_id, v2);
632            p2.push(party_id, v3);
633            p3.push(party_id, v4);
634        }
635
636        Ok((p0, p1, p2, p3))
637    }
638}