dkls23/sign/
dsg.rs

1// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
2// This software is licensed under the Silence Laboratories License Agreement.
3
4use std::collections::HashMap;
5
6use k256::{
7    ecdsa::{
8        signature::hazmat::PrehashVerifier, RecoveryId, Signature,
9        VerifyingKey,
10    },
11    elliptic_curve::{
12        group::GroupEncoding,
13        ops::Reduce,
14        point::AffineCoordinates,
15        scalar::IsHigh,
16        subtle::{Choice, ConstantTimeEq},
17        PrimeField,
18    },
19    sha2::{Digest, Sha256},
20    NonZeroScalar, ProjectivePoint, Scalar, Secp256k1, U256,
21};
22use rand::{Rng, SeedableRng};
23use rand_chacha::ChaCha20Rng;
24use zeroize::Zeroizing;
25
26use sl_mpc_mate::{coord::*, math::birkhoff_coeffs};
27
28use sl_oblivious::rvole::{RVOLEReceiver, RVOLESender};
29
30use crate::{
31    keygen::Keyshare,
32    proto::{
33        create_abort_message, tags::*, EncryptedMessage, SignedMessage, *,
34    },
35    setup::{
36        FinalSignSetupMessage, PreSignSetupMessage, ProtocolParticipant,
37        SignSetupMessage, ABORT_MESSAGE_TAG,
38    },
39    sign::constants::*,
40    sign::messages::*,
41    Seed,
42};
43
44use super::SignError;
45
46use crate::pairs::Pairs;
47
48/// Inner function for the pre-signature phase of the DSG protocol
49///
50/// This function implements the core logic of the pre-signature phase,
51/// where parties generate a pre-signature that can be used to sign any
52/// message later.
53///
54/// # Type Parameters
55///
56/// * `R`: Type implementing the `Relay` trait for message communication
57/// * `S`: Type implementing the `PreSignSetupMessage` trait for setup parameters
58///
59/// # Arguments
60///
61/// * `setup`: Setup parameters for the protocol
62/// * `seed`: Random seed for generating random values
63/// * `relay`: Message relay for communication between parties
64///
65/// # Returns
66///
67/// A `Result` containing either:
68/// * `Ok(PreSign)`: The pre-signature result
69/// * `Err(SignError)`: An error if the protocol fails
70async fn pre_signature_inner<R: Relay, S: PreSignSetupMessage>(
71    setup: &S,
72    seed: Seed,
73    relay: &mut FilteredMsgRelay<R>,
74) -> Result<PreSign, SignError> {
75    let mut rng = ChaCha20Rng::from_seed(seed);
76    let mut scheme = crate::proto::Scheme::new(&mut rng);
77
78    // For DKG part_id == part_idx.
79    //
80    // For DSG: party_idx is an index of the party in the setup messages.
81    //
82    // In the first message a party sends its part_id from Keyshare and
83    // its encryption public key
84    //
85    let my_party_id = setup.keyshare().party_id;
86    let my_party_idx = setup.participant_index();
87
88    let phi_i: Scalar = Scalar::generate_biased(&mut rng);
89    let r_i: Scalar = Scalar::generate_biased(&mut rng);
90    let blind_factor: [u8; 32] = rng.gen();
91
92    let big_r_i = ProjectivePoint::GENERATOR * r_i;
93
94    // TODO: Replace with SmallVec 2?
95    let mut commitments =
96        vec![([0; 32], [0; 32]); setup.total_participants()];
97
98    commitments[my_party_idx].0 = rng.gen(); // generate SessionId
99    commitments[my_party_idx].1 = hash_commitment_r_i(
100        &commitments[my_party_idx].0,
101        &big_r_i,
102        &blind_factor,
103    );
104
105    relay
106        .send(SignedMessage::build(
107            &setup.msg_id(None, DSG_MSG_R1),
108            setup.message_ttl().as_secs() as _,
109            0,
110            setup.signer(),
111            |msg: &mut SignMsg1, _| {
112                msg.session_id = commitments[my_party_idx].0;
113                msg.commitment_r_i = commitments[my_party_idx].1;
114                msg.party_id = my_party_id;
115                msg.enc_pk = scheme.public_key().try_into().unwrap();
116            },
117        ))
118        .await?;
119
120    // vector of pairs (party_idx, party_id)
121    let mut party_idx_to_id_map = vec![(my_party_idx, my_party_id)];
122
123    Round::new(setup.total_participants() - 1, DSG_MSG_R1, relay)
124        .of_signed_messages(
125            setup,
126            SignError::AbortProtocol,
127            |msg: &SignMsg1, party_idx| {
128                party_idx_to_id_map.push((party_idx, msg.party_id));
129                commitments[party_idx] = (msg.session_id, msg.commitment_r_i);
130                scheme
131                    .receiver_public_key(party_idx, &msg.enc_pk)
132                    .map_err(|_| SignError::InvalidMessage)?;
133
134                Ok(())
135            },
136        )
137        .await?;
138
139    party_idx_to_id_map.sort_by_key(|&(_, pid)| pid);
140
141    // there is no party-id duplicates
142    if !party_idx_to_id_map.windows(2).all(|w| w[0] != w[1]) {
143        return Err(SignError::InvalidMessage);
144    }
145
146    // all party-id are in range
147    if !party_idx_to_id_map
148        .iter()
149        .all(|&(_, pid)| pid < setup.keyshare().total_parties)
150    {
151        return Err(SignError::InvalidMessage);
152    }
153
154    // IDX -> ID
155    let find_party_id = |idx: usize| {
156        party_idx_to_id_map
157            .iter()
158            .find_map(|&(i, p)| (i == idx).then_some(p))
159            .unwrap()
160    };
161
162    let final_session_id: [u8; 32] = commitments
163        .iter()
164        .fold(Sha256::new(), |hash, (sid, _)| hash.chain_update(sid))
165        .chain_update(setup.keyshare().final_session_id)
166        .finalize()
167        .into();
168
169    let digest_i: [u8; 32] = commitments
170        .iter()
171        .enumerate()
172        .fold(
173            Sha256::new().chain_update(DSG_LABEL),
174            |hash, (key, (sid, commitment))| {
175                hash.chain_update((key as u32).to_be_bytes())
176                    .chain_update(sid)
177                    .chain_update(commitment)
178            },
179        )
180        .chain_update(DIGEST_I_LABEL)
181        .finalize()
182        .into();
183
184    let mut to_send = vec![];
185
186    let mut mta_receivers = Pairs::from(
187        setup
188            .all_other_parties()
189            .map(|party_idx| {
190                let sender_id = find_party_id(party_idx);
191
192                let sid =
193                    mta_session_id(&final_session_id, sender_id, my_party_id);
194
195                let sender_ot_results =
196                    setup.keyshare().sender_seed(sender_id);
197
198                let mut enc_msg = EncryptedMessage::<SignMsg2>::new(
199                    &setup.msg_id(Some(party_idx), DSG_MSG_R2),
200                    setup.message_ttl().as_secs() as u32,
201                    0,
202                    0,
203                    &scheme,
204                );
205
206                let (msg2, _) = enc_msg.payload(&scheme);
207                msg2.final_session_id = final_session_id;
208
209                let (mta_receiver, chi_i_j) = RVOLEReceiver::new(
210                    sid,
211                    sender_ot_results,
212                    &mut msg2.mta_msg1,
213                    &mut rng,
214                );
215
216                to_send.push(
217                    enc_msg
218                        .encrypt(&mut scheme, party_idx)
219                        .ok_or(SignError::SendMessage)?,
220                );
221
222                Ok((party_idx, (mta_receiver, chi_i_j)))
223            })
224            .collect::<Result<Vec<_>, SignError>>()?,
225    );
226
227    for msg in to_send {
228        relay.feed(msg).await.map_err(|_| SignError::SendMessage)?;
229    }
230
231    relay.flush().await?;
232
233    let zeta_i =
234        get_zeta_i(setup.keyshare(), &party_idx_to_id_map, &digest_i);
235
236    let coeff = if setup.keyshare().zero_ranks() {
237        get_lagrange_coeff(setup.keyshare(), &party_idx_to_id_map)
238    } else {
239        let betta_coeffs =
240            get_birkhoff_coefficients(setup.keyshare(), &party_idx_to_id_map);
241
242        *betta_coeffs
243            .get(&(my_party_id as usize))
244            .expect("betta_i not found") // FIXME
245    };
246
247    let (additive_offset, derived_public_key) = setup
248        .keyshare()
249        .derive_with_offset(setup.chain_path())
250        .unwrap(); // FIXME: report error
251    let threshold_inv = Scalar::from(setup.total_participants() as u32)
252        .invert()
253        .unwrap(); // threshold > 0 so it has an invert
254    let additive_offset = additive_offset * threshold_inv;
255
256    let sk_i = coeff * setup.keyshare().s_i() + additive_offset + zeta_i;
257    let pk_i = ProjectivePoint::GENERATOR * sk_i;
258
259    let mut sender_additive_shares = vec![];
260
261    let mut round =
262        Round::new(setup.total_participants() - 1, DSG_MSG_R2, relay);
263
264    while let Some((msg, party_idx, is_abort)) = round.recv().await? {
265        if is_abort {
266            check_abort(setup, &msg, party_idx, SignError::AbortProtocol)?;
267            round.put_back(&msg, ABORT_MESSAGE_TAG, party_idx);
268            continue;
269        }
270
271        let mut msg = Zeroizing::new(msg);
272        let msg2 = match EncryptedMessage::<SignMsg2>::decrypt(
273            &mut msg, 0, &scheme, party_idx,
274        ) {
275            Some((refs, _)) => refs,
276            _ => {
277                round.put_back(&msg, DSG_MSG_R2, party_idx);
278                continue;
279            }
280        };
281
282        // Check final_session_id
283        if msg2.final_session_id.ct_ne(&final_session_id).into() {
284            return Err(SignError::InvalidFinalSessionID);
285        }
286
287        let receiver_id = find_party_id(party_idx);
288
289        let sid = mta_session_id(&final_session_id, my_party_id, receiver_id);
290
291        let seed_ot_results = setup.keyshare().receiver_seed(receiver_id);
292
293        let mut enc_msg3 = EncryptedMessage::<SignMsg3>::new(
294            &setup.msg_id(Some(party_idx), DSG_MSG_R3),
295            setup.message_ttl().as_secs() as _,
296            0,
297            0,
298            &scheme,
299        );
300
301        let (msg3, _) = enc_msg3.payload(&scheme);
302
303        let [c_u, c_v] = RVOLESender::process(
304            &sid,
305            seed_ot_results,
306            &[r_i, sk_i],
307            &msg2.mta_msg1,
308            &mut msg3.mta_msg2,
309            &mut rng,
310        )
311        .map_err(|_| SignError::AbortProtocolAndBanParty(party_idx as u8))?;
312
313        let gamma_u = ProjectivePoint::GENERATOR * c_u;
314        let gamma_v = ProjectivePoint::GENERATOR * c_v;
315        let (_mta_receiver, chi_i_j) = mta_receivers.find_pair(party_idx);
316
317        let psi = phi_i - chi_i_j;
318
319        msg3.final_session_id = final_session_id;
320        msg3.digest_i = digest_i;
321        msg3.pk_i = encode_point(&pk_i);
322        msg3.big_r_i = encode_point(&big_r_i);
323        msg3.blind_factor = blind_factor;
324        msg3.gamma_v = encode_point(&gamma_v);
325        msg3.gamma_u = encode_point(&gamma_u);
326        msg3.psi = encode_scalar(&psi);
327
328        round
329            .relay
330            .send(
331                enc_msg3
332                    .encrypt(&mut scheme, party_idx)
333                    .ok_or(SignError::SendMessage)?,
334            )
335            .await?;
336
337        sender_additive_shares.push([c_u, c_v]);
338    }
339
340    let mut big_r_star = ProjectivePoint::IDENTITY;
341    let mut sum_pk_j = ProjectivePoint::IDENTITY;
342    let mut sum_psi_j_i = Scalar::ZERO;
343
344    let mut receiver_additive_shares = vec![];
345
346    let mut round =
347        Round::new(setup.total_participants() - 1, DSG_MSG_R3, relay);
348
349    while let Some((msg, party_idx, is_abort)) = round.recv().await? {
350        if is_abort {
351            check_abort(setup, &msg, party_idx, SignError::AbortProtocol)?;
352            round.put_back(&msg, ABORT_MESSAGE_TAG, party_idx);
353            continue;
354        }
355
356        let mut msg = Zeroizing::new(msg);
357        let msg3 = match EncryptedMessage::<SignMsg3>::decrypt(
358            &mut msg, 0, &scheme, party_idx,
359        ) {
360            Some((refs, _)) => refs,
361            _ => {
362                round.put_back(&msg, DSG_MSG_R3, party_idx);
363                continue;
364            }
365        };
366
367        // Check final_session_id
368        if msg3.final_session_id != final_session_id {
369            return Err(SignError::InvalidFinalSessionID);
370        }
371
372        let (mta_receiver, chi_i_j) = mta_receivers.pop_pair(party_idx);
373
374        let [d_u, d_v] =
375            mta_receiver.process(&msg3.mta_msg2).map_err(|_| {
376                SignError::AbortProtocolAndBanParty(party_idx as u8)
377            })?;
378
379        let (sid_i, commitment) = &commitments[party_idx];
380
381        let big_r_j =
382            decode_point(&msg3.big_r_i).ok_or(SignError::InvalidMessage)?;
383
384        if !verify_commitment_r_i(
385            sid_i,
386            &big_r_j,
387            &msg3.blind_factor,
388            commitment,
389        ) {
390            return Err(SignError::InvalidCommitment);
391        }
392
393        if digest_i.ct_ne(&msg3.digest_i).into() {
394            return Err(SignError::InvalidDigest);
395        }
396
397        let pk_j =
398            decode_point(&msg3.pk_i).ok_or(SignError::InvalidMessage)?;
399
400        big_r_star += big_r_j;
401        sum_pk_j += pk_j;
402        sum_psi_j_i +=
403            decode_scalar(&msg3.psi).ok_or(SignError::InvalidMessage)?;
404
405        let cond1 = (big_r_j * chi_i_j)
406            == (ProjectivePoint::GENERATOR * d_u
407                + decode_point(&msg3.gamma_u)
408                    .ok_or(SignError::InvalidMessage)?);
409        if !cond1 {
410            return Err(SignError::AbortProtocolAndBanParty(party_idx as u8));
411        }
412
413        let cond2 = (pk_j * chi_i_j)
414            == (ProjectivePoint::GENERATOR * d_v
415                + decode_point(&msg3.gamma_v)
416                    .ok_or(SignError::InvalidMessage)?);
417        if !cond2 {
418            return Err(SignError::AbortProtocolAndBanParty(party_idx as u8));
419        }
420
421        receiver_additive_shares.push([d_u, d_v]);
422    }
423
424    // new var
425    let big_r = big_r_star + big_r_i;
426    sum_pk_j += pk_i;
427
428    // Checks
429    if sum_pk_j != derived_public_key {
430        return Err(SignError::FailedCheck("Consistency check 3 failed"));
431    }
432
433    let mut sum_v = Scalar::ZERO;
434    let mut sum_u = Scalar::ZERO;
435
436    for i in 0..setup.total_participants() - 1 {
437        let sender_shares = &sender_additive_shares[i];
438        let receiver_shares = &receiver_additive_shares[i];
439        sum_u += sender_shares[0] + receiver_shares[0];
440        sum_v += sender_shares[1] + receiver_shares[1];
441    }
442
443    let r_point = big_r.to_affine();
444    let r_x = <Scalar as Reduce<U256>>::reduce_bytes(&r_point.x());
445    let phi_plus_sum_psi = phi_i + sum_psi_j_i;
446    let s_0 = r_x * (sk_i * phi_plus_sum_psi + sum_v);
447    let s_1 = r_i * phi_plus_sum_psi + sum_u;
448
449    let pre_sign_result = PreSign {
450        final_session_id,
451        public_key: encode_point(&derived_public_key),
452        s_0: encode_scalar(&s_0),
453        s_1: encode_scalar(&s_1),
454        phi_i: encode_scalar(&phi_i),
455        r: encode_point(&big_r),
456        party_id: my_party_id,
457    };
458
459    Ok(pre_sign_result)
460}
461
462/// Creates a partial signature from a pre-signature result
463///
464/// This function takes a pre-signature result and a message hash,
465/// and creates a partial signature that can be combined with other
466/// partial signatures to form the final signature.
467///
468/// # Arguments
469///
470/// * `pre_sign_result`: The pre-signature result from the pre-signature phase
471/// * `message_hash`: The hash of the message to be signed
472///
473/// # Returns
474///
475/// A `Result` containing either:
476/// * `Ok(PartialSignature)`: The partial signature
477/// * `Err(SignError)`: An error if the partial signature cannot be created
478fn create_partial_signature(
479    pre_sign_result: &PreSign,
480    message_hash: [u8; 32],
481) -> Result<PartialSignature, SignError> {
482    let m = Scalar::reduce(U256::from_be_slice(&message_hash));
483
484    let phi_i = decode_scalar(&pre_sign_result.phi_i)
485        .ok_or(SignError::InvalidPreSign)?;
486
487    let s_0 = decode_scalar(&pre_sign_result.s_0)
488        .ok_or(SignError::InvalidPreSign)?;
489
490    let s_0 = m * phi_i + s_0;
491
492    let s_1 = decode_scalar(&pre_sign_result.s_1)
493        .ok_or(SignError::InvalidPreSign)?;
494
495    let r =
496        decode_point(&pre_sign_result.r).ok_or(SignError::InvalidPreSign)?;
497
498    let public_key = decode_point(&pre_sign_result.public_key)
499        .ok_or(SignError::InvalidPreSign)?;
500
501    Ok(PartialSignature {
502        final_session_id: pre_sign_result.final_session_id,
503        public_key,
504        message_hash,
505        s_0,
506        s_1,
507        r,
508    })
509}
510
511/// Combines partial signatures into a final signature
512///
513/// This function takes a collection of partial signatures and combines
514/// them to produce the final ECDSA signature and recovery ID.
515///
516/// # Arguments
517///
518/// * `partial_signatures`: A slice of partial signatures to combine
519///
520/// # Returns
521///
522/// A `Result` containing either:
523/// * `Ok((Signature, RecoveryId))`: The final signature and recovery ID
524/// * `Err(SignError)`: An error if the signatures cannot be combined
525fn combine_partial_signature(
526    partial_signatures: &[PartialSignature],
527) -> Result<(Signature, RecoveryId), SignError> {
528    let p0 = &partial_signatures[0];
529
530    let mut check = Choice::from(0);
531
532    let mut sum_s_0 = p0.s_0;
533    let mut sum_s_1 = p0.s_1;
534
535    for pn in &partial_signatures[1..] {
536        check |= pn.final_session_id.ct_ne(&p0.final_session_id);
537        check |= pn.public_key.ct_ne(&p0.public_key);
538        check |= pn.r.ct_ne(&p0.r);
539        check |= pn.message_hash.ct_ne(&p0.message_hash);
540
541        sum_s_0 += pn.s_0;
542        sum_s_1 += pn.s_1;
543    }
544
545    if check.into() {
546        return Err(SignError::FailedCheck(
547            "Invalid list of partial signatures",
548        ));
549    }
550
551    let r = p0.r.to_affine();
552
553    let is_y_odd: bool = r.y_is_odd().into();
554
555    let r_x = <Scalar as Reduce<U256>>::reduce_bytes(&r.x());
556    let is_x_reduced = r_x.to_repr() != r.x();
557    let recid = RecoveryId::new(is_y_odd, is_x_reduced);
558
559    let sum_s_1_inv = sum_s_1.invert().unwrap();
560    let s = sum_s_0 * sum_s_1_inv;
561
562    let is_y_odd = recid.is_y_odd() ^ bool::from(s.is_high());
563    let recid = RecoveryId::new(is_y_odd, recid.is_x_reduced());
564
565    let sign = Signature::from_scalars(r_x, s)?;
566    let sign = sign.normalize_s().unwrap_or(sign);
567
568    VerifyingKey::from_affine(p0.public_key.to_affine())?
569        .verify_prehash(&p0.message_hash, &sign)?;
570
571    Ok((sign, recid))
572}
573
574/// Main entry point for the DSG protocol
575///
576/// This function executes the complete DSG protocol, including both
577/// the pre-signature and finish phases.
578///
579/// # Type Parameters
580///
581/// * `R`: Type implementing the `Relay` trait for message communication
582/// * `S`: Type implementing the `SignSetupMessage` trait for setup parameters
583///
584/// # Arguments
585///
586/// * `setup`: Setup parameters for the protocol
587/// * `seed`: Random seed for generating random values
588/// * `relay`: Message relay for communication between parties
589///
590/// # Returns
591///
592/// A `Result` containing either:
593/// * `Ok((Signature, RecoveryId))`: The final signature and recovery ID
594/// * `Err(SignError)`: An error if the protocol fails
595pub async fn run<R: Relay, S: SignSetupMessage>(
596    setup: S,
597    seed: Seed,
598    relay: R,
599) -> Result<(Signature, RecoveryId), SignError> {
600    let abort_msg = create_abort_message(&setup);
601    let mut relay = FilteredMsgRelay::new(relay);
602
603    relay.ask_messages(&setup, ABORT_MESSAGE_TAG, false).await?;
604    relay.ask_messages(&setup, DSG_MSG_R1, false).await?;
605    relay.ask_messages(&setup, DSG_MSG_R2, true).await?;
606    relay.ask_messages(&setup, DSG_MSG_R3, true).await?;
607    relay.ask_messages(&setup, DSG_MSG_R4, false).await?;
608
609    let result = match run_inner(setup, seed, &mut relay).await {
610        Ok(sign) => Ok(sign),
611        Err(SignError::AbortProtocol(p)) => Err(SignError::AbortProtocol(p)),
612        Err(SignError::SendMessage) => Err(SignError::SendMessage),
613        Err(err) => {
614            // ignore error of sending abort message
615            let _ = relay.send(abort_msg).await;
616            Err(err)
617        }
618    };
619
620    let _ = relay.close().await;
621
622    result
623}
624
625/// Inner function for the main DSG protocol execution
626///
627/// This function implements the core logic of the DSG protocol,
628/// handling both the pre-signature and finish phases.
629///
630/// # Type Parameters
631///
632/// * `R`: Type implementing the `Relay` trait for message communication
633/// * `S`: Type implementing the `SignSetupMessage` trait for setup parameters
634///
635/// # Arguments
636///
637/// * `setup`: Setup parameters for the protocol
638/// * `seed`: Random seed for generating random values
639/// * `relay`: Message relay for communication between parties
640///
641/// # Returns
642///
643/// A `Result` containing either:
644/// * `Ok((Signature, RecoveryId))`: The final signature and recovery ID
645/// * `Err(SignError)`: An error if the protocol fails
646async fn run_inner<R: Relay, S: SignSetupMessage>(
647    setup: S,
648    seed: Seed,
649    relay: &mut FilteredMsgRelay<R>,
650) -> Result<(Signature, RecoveryId), SignError> {
651    let t = setup.total_participants();
652
653    let pre_signature_result =
654        pre_signature_inner(&setup, seed, relay).await?;
655
656    let msg_hash = setup.message_hash();
657
658    run_final(&setup, relay, t, msg_hash, &pre_signature_result).await
659}
660
661/// Executes the pre-signature phase of the DSG protocol
662///
663/// This function runs only the pre-signature phase of the protocol,
664/// producing a pre-signature that can be used later to sign messages.
665///
666/// # Type Parameters
667///
668/// * `R`: Type implementing the `Relay` trait for message communication
669/// * `S`: Type implementing the `PreSignSetupMessage` trait for setup parameters
670///
671/// # Arguments
672///
673/// * `setup`: Setup parameters for the protocol
674/// * `seed`: Random seed for generating random values
675/// * `relay`: Message relay for communication between parties
676///
677/// # Returns
678///
679/// A `Result` containing either:
680/// * `Ok(PreSign)`: The pre-signature result
681/// * `Err(SignError)`: An error if the protocol fails
682pub async fn pre_signature<R: Relay, S: PreSignSetupMessage>(
683    setup: S,
684    seed: Seed,
685    relay: R,
686) -> Result<PreSign, SignError> {
687    let abort_msg = create_abort_message(&setup);
688    let mut relay = FilteredMsgRelay::new(relay);
689
690    relay.ask_messages(&setup, ABORT_MESSAGE_TAG, false).await?;
691    relay.ask_messages(&setup, DSG_MSG_R1, false).await?;
692    relay.ask_messages(&setup, DSG_MSG_R2, true).await?;
693    relay.ask_messages(&setup, DSG_MSG_R3, true).await?;
694
695    let result = match pre_signature_inner(&setup, seed, &mut relay).await {
696        Ok(result) => Ok(result),
697        Err(SignError::AbortProtocol(p)) => Err(SignError::AbortProtocol(p)),
698        Err(SignError::SendMessage) => Err(SignError::SendMessage),
699        Err(err) => {
700            relay.send(abort_msg).await?;
701            Err(err)
702        }
703    };
704
705    let _ = relay.close().await;
706
707    result
708}
709
710/// Executes the finish phase of the DSG protocol
711///
712/// This function runs the finish phase of the protocol, using a
713/// pre-signature to generate the final signature for a message.
714///
715/// # Type Parameters
716///
717/// * `R`: Type implementing the `Relay` trait for message communication
718/// * `S`: Type implementing the `FinalSignSetupMessage` trait for setup parameters
719///
720/// # Arguments
721///
722/// * `setup`: Setup parameters for the protocol
723/// * `relay`: Message relay for communication between parties
724///
725/// # Returns
726///
727/// A `Result` containing either:
728/// * `Ok((Signature, RecoveryId))`: The final signature and recovery ID
729/// * `Err(SignError)`: An error if the protocol fails
730pub async fn finish<R: Relay, S: FinalSignSetupMessage>(
731    setup: S,
732    relay: R,
733) -> Result<(Signature, RecoveryId), SignError> {
734    let pre_signature_result = setup.pre_signature();
735    let msg_hash = setup.message_hash();
736    let mut relay = FilteredMsgRelay::new(relay);
737
738    relay.ask_messages(&setup, ABORT_MESSAGE_TAG, false).await?;
739    relay.ask_messages(&setup, DSG_MSG_R4, false).await?;
740
741    let result = run_final(
742        &setup,
743        &mut relay,
744        setup.total_participants(),
745        msg_hash,
746        pre_signature_result,
747    )
748    .await;
749
750    let _ = relay.close().await;
751
752    result
753}
754
755/// Inner function for the finish phase of the DSG protocol
756///
757/// This function implements the core logic of the finish phase,
758/// where parties use a pre-signature to generate the final signature.
759///
760/// # Type Parameters
761///
762/// * `R`: Type implementing the `Relay` trait for message communication
763/// * `S`: Type implementing the `ProtocolParticipant` trait for participant information
764///
765/// # Arguments
766///
767/// * `setup`: Setup parameters for the protocol
768/// * `relay`: Message relay for communication between parties
769/// * `t`: Threshold value for the signature
770/// * `msg_hash`: Hash of the message to be signed
771/// * `pre_signature_result`: The pre-signature result from the pre-signature phase
772///
773/// # Returns
774///
775/// A `Result` containing either:
776/// * `Ok((Signature, RecoveryId))`: The final signature and recovery ID
777/// * `Err(SignError)`: An error if the protocol fails
778async fn run_final<R: Relay, S: ProtocolParticipant>(
779    setup: &S,
780    relay: &mut FilteredMsgRelay<R>,
781    t: usize,
782    msg_hash: [u8; 32],
783    pre_signature_result: &PreSign,
784) -> Result<(Signature, RecoveryId), SignError> {
785    let public_key = decode_point(&pre_signature_result.public_key).unwrap();
786    let r = decode_point(&pre_signature_result.r).unwrap();
787
788    let partial_signature =
789        create_partial_signature(pre_signature_result, msg_hash)?;
790
791    relay
792        .send(SignedMessage::build(
793            &setup.msg_id(None, DSG_MSG_R4),
794            setup.message_ttl().as_secs() as _,
795            0,
796            setup.signer(),
797            |msg4: &mut SignMsg4, _| {
798                msg4.session_id = partial_signature.final_session_id;
799                msg4.s_0 = encode_scalar(&partial_signature.s_0);
800                msg4.s_1 = encode_scalar(&partial_signature.s_1);
801            },
802        ))
803        .await?;
804
805    let mut partial_signatures: Vec<PartialSignature> = Vec::with_capacity(t);
806
807    partial_signatures.push(partial_signature);
808
809    Round::new(setup.total_participants() - 1, DSG_MSG_R4, relay)
810        .of_signed_messages(
811            setup,
812            SignError::AbortProtocol,
813            |msg: &SignMsg4, _party_idx| {
814                partial_signatures.push(PartialSignature {
815                    final_session_id: msg.session_id,
816                    public_key,
817                    message_hash: msg_hash,
818                    s_0: decode_scalar(&msg.s_0)
819                        .ok_or(SignError::InvalidMessage)?,
820                    s_1: decode_scalar(&msg.s_1)
821                        .ok_or(SignError::InvalidMessage)?,
822                    r,
823                });
824
825                Ok(())
826            },
827        )
828        .await?;
829
830    combine_partial_signature(&partial_signatures)
831}
832
833/// Computes the hash of a commitment value
834///
835/// This function computes the hash of a commitment value using the
836/// session ID, R point, and blind factor.
837///
838/// # Arguments
839///
840/// * `session_id`: The session identifier
841/// * `big_r_i`: The R point value
842/// * `blind_factor`: The blind factor value
843///
844/// # Returns
845///
846/// A 32-byte array containing the hash of the commitment
847fn hash_commitment_r_i(
848    session_id: &[u8],
849    big_r_i: &ProjectivePoint,
850    blind_factor: &[u8; 32],
851) -> [u8; 32] {
852    let mut hasher = Sha256::new();
853    hasher.update(DSG_LABEL);
854    hasher.update(session_id.as_ref());
855    hasher.update(big_r_i.to_bytes());
856    hasher.update(blind_factor);
857    hasher.update(COMMITMENT_LABEL);
858
859    hasher.finalize().into()
860}
861
862/// Computes the zeta_i value for a party
863///
864/// This function computes the zeta_i value used in the signature
865/// generation process.
866///
867/// # Arguments
868///
869/// * `keyshare`: The key share of the party
870/// * `party_id_list`: List of party IDs participating in the protocol
871/// * `sig_id`: The signature identifier
872///
873/// # Returns
874///
875/// The computed zeta_i scalar value
876fn get_zeta_i(
877    keyshare: &Keyshare,
878    party_id_list: &[(usize, u8)],
879    sig_id: &[u8],
880) -> Scalar {
881    let mut sum_p_0 = Scalar::ZERO;
882    for &(_, p_0_party) in party_id_list {
883        if p_0_party >= keyshare.party_id {
884            continue;
885        }
886
887        let seed_j_i = keyshare.each(p_0_party).zeta_seed;
888
889        let mut hasher = Sha256::new();
890        hasher.update(DSG_LABEL);
891        hasher.update(seed_j_i);
892        hasher.update(sig_id);
893        hasher.update(PAIRWISE_RANDOMIZATION_LABEL);
894        sum_p_0 += Scalar::reduce(U256::from_be_slice(&hasher.finalize()));
895    }
896
897    let mut sum_p_1 = Scalar::ZERO;
898    for &(_, p_1_party) in party_id_list {
899        if p_1_party <= keyshare.party_id {
900            continue;
901        }
902
903        let seed_i_j = keyshare.each(p_1_party - 1).zeta_seed;
904
905        let mut hasher = Sha256::new();
906        hasher.update(DSG_LABEL);
907        hasher.update(seed_i_j);
908        hasher.update(sig_id);
909        hasher.update(PAIRWISE_RANDOMIZATION_LABEL);
910        sum_p_1 += Scalar::reduce(U256::from_be_slice(&hasher.finalize()));
911    }
912
913    sum_p_0 - sum_p_1
914}
915
916/// Computes the Birkhoff coefficients for the protocol
917///
918/// This function computes the Birkhoff coefficients used in the
919/// signature generation process.
920///
921/// # Arguments
922///
923/// * `keyshare`: The key share of the party
924/// * `sign_party_ids`: List of party IDs participating in the protocol
925///
926/// # Returns
927///
928/// A map of party indices to their corresponding Birkhoff coefficients
929fn get_birkhoff_coefficients(
930    keyshare: &Keyshare,
931    sign_party_ids: &[(usize, u8)],
932) -> HashMap<usize, Scalar> {
933    let params = sign_party_ids
934        .iter()
935        .map(|&(_, pid)| {
936            (keyshare.get_x_i(pid), keyshare.get_rank(pid) as usize)
937        })
938        .collect::<Vec<_>>();
939
940    let betta_vec = birkhoff_coeffs::<Secp256k1>(&params);
941
942    sign_party_ids
943        .iter()
944        .zip(betta_vec.iter())
945        .map(|((_, pid), w_i)| (*pid as usize, *w_i))
946        .collect::<HashMap<_, _>>()
947}
948
949/// Computes the Lagrange coefficient for a party
950///
951/// This function computes the Lagrange coefficient used in the
952/// signature generation process.
953///
954/// # Arguments
955///
956/// * `keyshare`: The key share of the party
957/// * `sign_party_ids`: List of party IDs participating in the protocol
958///
959/// # Returns
960///
961/// The computed Lagrange coefficient
962fn get_lagrange_coeff(
963    keyshare: &Keyshare,
964    sign_party_ids: &[(usize, u8)],
965) -> Scalar {
966    let mut coeff = Scalar::from(1u64);
967    let pid = keyshare.party_id;
968    let x_i = &keyshare.get_x_i(pid) as &Scalar;
969
970    for &(_, party_id) in sign_party_ids {
971        let x_j = &keyshare.get_x_i(party_id) as &Scalar;
972        if x_i.ct_ne(x_j).into() {
973            let sub = x_j - x_i;
974            coeff *= x_j * &sub.invert().unwrap();
975        }
976    }
977
978    coeff
979}
980
981/// Computes a list of Lagrange coefficients
982///
983/// This function computes a list of Lagrange coefficients for a set
984/// of party points.
985///
986/// # Type Parameters
987///
988/// * `K`: Type of the key function
989/// * `T`: Type of the party points
990///
991/// # Arguments
992///
993/// * `party_points`: List of party points
994/// * `k`: Function to extract the key from a party point
995///
996/// # Returns
997///
998/// An iterator over the computed Lagrange coefficients
999pub(crate) fn get_lagrange_coeff_list<'a, K, T>(
1000    party_points: &'a [T],
1001    k: K,
1002) -> impl Iterator<Item = Scalar> + 'a
1003where
1004    K: Fn(&T) -> &NonZeroScalar + 'a,
1005{
1006    party_points.iter().map(move |x_i| {
1007        let x_i = k(x_i);
1008        let mut coeff = Scalar::ONE;
1009        for x_j in party_points {
1010            let x_j = k(x_j);
1011            if x_i.ct_ne(x_j).into() {
1012                let sub = x_j.sub(x_i);
1013                // SAFETY: Invert is safe because we check x_j != x_i, so sub is not zero.
1014                coeff *= x_j.as_ref() * &sub.invert().unwrap();
1015            }
1016        }
1017        coeff
1018    })
1019}
1020
1021/// Verifies a commitment value
1022///
1023/// This function verifies that a commitment value matches the expected
1024/// hash of the session ID, R point, and blind factor.
1025///
1026/// # Arguments
1027///
1028/// * `sid`: The session identifier
1029/// * `big_r_i`: The R point value
1030/// * `blind_factor`: The blind factor value
1031/// * `commitment`: The commitment value to verify
1032///
1033/// # Returns
1034///
1035/// `true` if the commitment is valid, `false` otherwise
1036fn verify_commitment_r_i(
1037    sid: &[u8],
1038    big_r_i: &ProjectivePoint,
1039    blind_factor: &[u8; 32],
1040    commitment: &[u8],
1041) -> bool {
1042    let compare_commitment = hash_commitment_r_i(sid, big_r_i, blind_factor);
1043
1044    commitment.ct_eq(&compare_commitment).into()
1045}
1046
1047/// Generates a session ID for the MtA protocol
1048///
1049/// This function generates a unique session ID for the MtA protocol
1050/// based on the final session ID and the sender/receiver IDs.
1051///
1052/// # Arguments
1053///
1054/// * `final_session_id`: The final session identifier
1055/// * `sender_id`: The ID of the sender party
1056/// * `receiver_id`: The ID of the receiver party
1057///
1058/// # Returns
1059///
1060/// A 32-byte array containing the generated session ID
1061fn mta_session_id(
1062    final_session_id: &[u8],
1063    sender_id: u8,
1064    receiver_id: u8,
1065) -> [u8; 32] {
1066    let mut h = Sha256::new();
1067    h.update(DSG_LABEL);
1068    h.update(final_session_id);
1069    h.update(b"sender");
1070    h.update([sender_id]);
1071    h.update(b"receiver");
1072    h.update([receiver_id]);
1073    h.update(PAIRWISE_MTA_LABEL);
1074    h.finalize().into()
1075}
1076
/// Test module for the DSG protocol
///
/// Runs the full signing protocol (`run`) for several threshold / party
/// count / rank configurations and checks every party's signature against
/// the group public key. It also exercises the split pre-signature +
/// finish flow.
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::task::{JoinError, JoinSet};

    use sl_mpc_mate::coord::SimpleMessageRelay;

    use crate::{
        keygen::utils::gen_keyshares,
        sign::{setup_dsg, setup_finish_sign},
    };

    /// Derivation path passed to `setup_dsg`. The tests verify signatures
    /// against the keyshare's own public key, so this path must not derive
    /// a child key.
    const CHAIN_PATH: &str = "m";

    /// Message hash the produced signatures are checked against.
    /// NOTE(review): assumed to be the hash `setup_dsg` configures the
    /// parties to sign.
    const MESSAGE_HASH: [u8; 32] = [1u8; 32];

    /// Unwraps the result of a joined party task.
    ///
    /// # Panics
    ///
    /// Panics if the task did not complete, or if the protocol returned
    /// an error. The error is included in the panic message.
    fn unwrap_party<T, E: std::fmt::Debug>(
        joined: Result<Result<T, E>, JoinError>,
    ) -> T {
        match joined.expect("party task did not complete") {
            Ok(value) => value,
            Err(err) => panic!("error {err:?}"),
        }
    }

    /// Waits for every signing party and asserts that each produced a
    /// signature over `MESSAGE_HASH` that verifies under `vk`. The recovery
    /// id the party reported must also match the recovered one.
    async fn assert_valid_signatures<E>(
        mut parties: JoinSet<Result<(Signature, RecoveryId), E>>,
        vk: &VerifyingKey,
    ) where
        E: std::fmt::Debug + 'static,
    {
        while let Some(joined) = parties.join_next().await {
            let (sign, recid) = unwrap_party(joined);

            let recid2 = RecoveryId::trial_recovery_from_prehash(
                vk,
                &MESSAGE_HASH,
                &sign,
            )
            .unwrap();

            assert_eq!(recid, recid2);
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn s2x2() {
        let coord = SimpleMessageRelay::new();

        let shares = gen_keyshares(2, 2, Some(&[0, 0])).await;

        let vk =
            VerifyingKey::from_affine(shares[0].public_key().to_affine())
                .unwrap();

        let mut parties = JoinSet::new();
        for (setup, seed) in setup_dsg(None, &shares, CHAIN_PATH) {
            parties.spawn(run(setup, seed, coord.connect()));
        }

        assert_valid_signatures(parties, &vk).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn s2x3() {
        let coord = SimpleMessageRelay::new();

        let shares = gen_keyshares(2, 3, Some(&[0, 1, 1])).await;

        let vk =
            VerifyingKey::from_affine(shares[0].public_key().to_affine())
                .unwrap();

        let mut parties = JoinSet::new();
        for (setup, seed) in setup_dsg(None, &shares[0..2], CHAIN_PATH) {
            parties.spawn(run(setup, seed, coord.connect()));
        }

        assert_valid_signatures(parties, &vk).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn s2x3_all_shares() {
        let coord = SimpleMessageRelay::new();

        let shares = gen_keyshares(2, 3, Some(&[0, 1, 1])).await;

        let vk =
            VerifyingKey::from_affine(shares[0].public_key().to_affine())
                .unwrap();

        let mut parties = JoinSet::new();
        for (setup, seed) in setup_dsg(None, &shares, CHAIN_PATH) {
            parties.spawn(run(setup, seed, coord.connect()));
        }

        assert_valid_signatures(parties, &vk).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn s3x5() {
        let coord = SimpleMessageRelay::new();

        let shares = gen_keyshares(3, 5, Some(&[0, 1, 1, 1, 1])).await;

        let vk =
            VerifyingKey::from_affine(shares[0].public_key().to_affine())
                .unwrap();

        let mut parties = JoinSet::new();
        for (setup, seed) in setup_dsg(None, &shares[0..3], CHAIN_PATH) {
            parties.spawn(run(setup, seed, coord.connect()));
        }

        assert_valid_signatures(parties, &vk).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pre2x3() {
        let shares = gen_keyshares(2, 3, Some(&[0, 1, 1])).await;

        // Phase 1: every signer produces a pre-signature.
        let coord = SimpleMessageRelay::new();
        let mut parties = JoinSet::new();

        for (setup, seed) in setup_dsg(None, &shares[0..2], CHAIN_PATH) {
            parties.spawn(pre_signature(setup, seed, coord.connect()));
        }

        let mut pre_sign = vec![];
        while let Some(joined) = parties.join_next().await {
            pre_sign.push(unwrap_party(joined));
        }

        // Phase 2: finish signing from the collected pre-signatures over a
        // fresh relay.
        let coord = SimpleMessageRelay::new();
        let mut parties = JoinSet::new();

        for setup in setup_finish_sign(pre_sign) {
            parties.spawn(finish(setup, coord.connect()));
        }

        while let Some(joined) = parties.join_next().await {
            let _ = unwrap_party(joined);
        }
    }
}