dkls23/keygen/
key_refresh.rs

1// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
2// This software is licensed under the Silence Laboratories License Agreement.
3
4//! Protocol for refreshing existing keyshares without changing the corresponding public key
5
6use std::fmt;
7
8use k256::{
9    elliptic_curve::{group::GroupEncoding, PrimeField},
10    NonZeroScalar, ProjectivePoint, Scalar,
11};
12
13use zeroize::Zeroize;
14
15use sl_mpc_mate::coord::*;
16
17use crate::keygen::utils::{get_birkhoff_coefficients, get_lagrange_coeff};
18
19use crate::{
20    keygen::{run_inner, KeyRefreshData, KeygenError, Keyshare},
21    proto::{tags::*, *},
22    setup::KeygenSetupMessage,
23    Seed,
24};
25
26/// Keyshare for refresh of a party.
27#[derive(Clone, Zeroize)]
28pub struct KeyshareForRefresh {
29    /// Rank of each party
30    pub rank_list: Vec<u8>,
31
32    /// Threshold value
33    pub threshold: u8,
34
35    /// Public key of the generated key.
36    pub public_key: ProjectivePoint,
37
38    // Root chain code (used to derive child public keys)
39    pub(crate) root_chain_code: [u8; 32],
40
41    /// set s_i to None if party_i lost their key_share
42    pub s_i: Option<Scalar>,
43
44    /// set s_i to None if party_i lost their key_share
45    pub x_i_list: Option<Vec<NonZeroScalar>>,
46
47    /// list of participants ids who lost their key_shares,
48    /// should be in range [0, n-1]
49    pub lost_keyshare_party_ids: Vec<u8>,
50
51    /// Part ID from key share
52    pub party_id: u8,
53}
54
55impl fmt::Debug for KeyshareForRefresh {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("KeyshareForRefresh")
58            .field("rank_list", &self.rank_list)
59            .field("threshold", &self.threshold)
60            .field("public_key", &self.public_key)
61            .field("lost_keyshare_party_ids", &self.lost_keyshare_party_ids)
62            .finish()
63    }
64}
65
66impl KeyshareForRefresh {
67    #[allow(clippy::too_many_arguments)]
68    /// Create new KeyshareForRefresh object.
69    /// # Warning
70    /// It is recommended to use `KeyshareForRefresh::from_keyshare()` and `KeyshareForRefresh::from_lost_keyshare()` instead.
71    /// This is for advanced usecases only.
72    pub fn new(
73        rank_list: Vec<u8>,
74        threshold: u8,
75        public_key: ProjectivePoint,
76        root_chain_code: [u8; 32],
77        s_i: Option<Scalar>,
78        x_i_list: Option<Vec<NonZeroScalar>>,
79        lost_keyshare_party_ids: Vec<u8>,
80        party_id: u8,
81    ) -> Self {
82        Self {
83            rank_list,
84            threshold,
85            public_key,
86            root_chain_code,
87            s_i,
88            x_i_list,
89            lost_keyshare_party_ids,
90            party_id,
91        }
92    }
93    /// Create KeyshareForRefresh struct from Keyshare
94    pub fn from_keyshare(
95        keyshare: &Keyshare,
96        lost_keyshare_party_ids: Option<Vec<u8>>,
97    ) -> Self {
98        let lost_keyshare_party_ids =
99            lost_keyshare_party_ids.unwrap_or_default();
100        Self {
101            rank_list: keyshare.rank_list(),
102            threshold: keyshare.threshold,
103            public_key: keyshare.public_key(),
104            root_chain_code: keyshare.root_chain_code,
105            s_i: Some(keyshare.s_i()),
106            x_i_list: Some(keyshare.x_i_list()),
107            lost_keyshare_party_ids,
108            party_id: keyshare.party_id,
109        }
110    }
111
112    /// Create KeyshareForRefresh struct for the participant who lost their keyshare
113    pub fn from_lost_keyshare(
114        rank_list: Vec<u8>,
115        threshold: u8,
116        public_key: ProjectivePoint,
117        lost_keyshare_party_ids: Vec<u8>,
118        party_id: u8,
119    ) -> Self {
120        Self {
121            rank_list,
122            threshold,
123            public_key,
124            root_chain_code: [0u8; 32],
125            s_i: None,
126            x_i_list: None,
127            lost_keyshare_party_ids,
128            party_id,
129        }
130    }
131
132    ///  Serialize KeyshareForRefresh to bytes
133    ///  Used to send KeyshareForRefresh to other parties, for key-import
134    pub fn to_bytes(&self) -> Vec<u8> {
135        let mut bytes = Vec::with_capacity(self.size());
136
137        bytes.push(self.party_id);
138
139        bytes.push(self.rank_list.len() as u8);
140        bytes.extend_from_slice(&self.rank_list);
141
142        bytes.push(self.threshold);
143
144        bytes.extend_from_slice(&self.public_key.to_affine().to_bytes());
145
146        bytes.extend_from_slice(&self.root_chain_code);
147
148        if let Some(s_i) = self.s_i {
149            bytes.push(1);
150            bytes.extend_from_slice(&s_i.to_bytes());
151        } else {
152            bytes.push(0);
153        }
154
155        if let Some(x_i_list) = &self.x_i_list {
156            bytes.push(x_i_list.len() as u8);
157            for x_i in x_i_list {
158                bytes.extend_from_slice(&x_i.to_bytes());
159            }
160        } else {
161            bytes.push(0);
162        }
163
164        bytes.push(self.lost_keyshare_party_ids.len() as u8);
165        bytes.extend_from_slice(&self.lost_keyshare_party_ids);
166
167        bytes
168    }
169
170    /// Deserialize KeyshareForRefresh from bytes
171    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
172        let offset = std::cell::Cell::new(0usize);
173
174        let read_data = |num_bytes: u8| {
175            let len = num_bytes as usize;
176            let off = offset.replace(offset.get() + len);
177            bytes.get(off..off + len)
178        };
179
180        let read_byte = || read_data(1).map(|b| b[0]);
181
182        let party_id = read_byte()?;
183
184        let rank_list_len = read_byte()?;
185        let rank_list = read_data(rank_list_len)?;
186
187        let threshold = read_byte()?;
188
189        let public_key = read_data(33).and_then(|b| {
190            ProjectivePoint::from_bytes(b.into()).into_option()
191        })?;
192
193        let root_chain_code: [u8; 32] = read_data(32)?.try_into().ok()?;
194
195        let s_i = if read_byte()? == 1 {
196            let s_i_bytes: [u8; 32] = read_data(32)?.try_into().ok()?;
197            Scalar::from_repr(s_i_bytes.into()).into()
198        } else {
199            None
200        };
201
202        let x_i_list_len = read_byte()?;
203        let x_i_list = if x_i_list_len != 0 {
204            let mut x_i_list = Vec::with_capacity(x_i_list_len as usize);
205            for _ in 0..x_i_list_len {
206                let x_i_bytes: [u8; 32] = read_data(32)?.try_into().ok()?;
207                let x_i: NonZeroScalar =
208                    Option::from(NonZeroScalar::from_repr(x_i_bytes.into()))?;
209                x_i_list.push(x_i);
210            }
211            Some(x_i_list)
212        } else {
213            None
214        };
215
216        let lost_keyshare_party_ids_len = read_byte()?;
217        let lost_keyshare_party_ids = read_data(lost_keyshare_party_ids_len)?;
218
219        Some(Self {
220            rank_list: rank_list.to_vec(),
221            threshold,
222            public_key,
223            root_chain_code,
224            s_i,
225            x_i_list,
226            lost_keyshare_party_ids: lost_keyshare_party_ids.to_vec(),
227            party_id,
228        })
229    }
230
231    fn size(&self) -> usize {
232        let mut size = 1 + self.rank_list.len();
233        size += 1; // party_id
234        size += 1;
235        size += 33;
236        size += 32;
237        size += 1;
238        if self.s_i.is_some() {
239            size += 32;
240        }
241        if let Some(x_i_list) = &self.x_i_list {
242            size += 1;
243            size += x_i_list.len() * 32;
244        } else {
245            size += 1;
246        }
247        size += 1 + self.lost_keyshare_party_ids.len();
248        size
249    }
250}
251
252/// Execute Key Refresh protocol.
253pub async fn run<R, S>(
254    setup: S,
255    seed: Seed,
256    relay: R,
257    old_keyshare: KeyshareForRefresh,
258) -> Result<Keyshare, KeygenError>
259where
260    S: KeygenSetupMessage,
261    R: Relay,
262{
263    let abort_msg = create_abort_message(&setup);
264    let mut relay = FilteredMsgRelay::new(relay);
265
266    let my_party_id = old_keyshare.party_id;
267    let n = setup.total_participants();
268
269    let mut s_i_0 = Scalar::ZERO;
270    if old_keyshare.s_i.is_some() && old_keyshare.x_i_list.is_some() {
271        // calculate additive share s_i_0 of participant_i,
272        // \sum_{i=0}^{n-1} s_i_0 = private_key
273        let s_i = &old_keyshare.s_i.unwrap();
274        let rank_list = &old_keyshare.rank_list;
275        let x_i_list = &old_keyshare.x_i_list.unwrap();
276        let x_i = &x_i_list[my_party_id as usize];
277
278        let party_ids_with_keyshares = (0..n as u8)
279            .filter(|p| {
280                !old_keyshare.lost_keyshare_party_ids.contains(&{ *p })
281            })
282            .collect::<Vec<_>>();
283
284        let all_ranks_zero = rank_list.iter().all(|r| r == &0u8);
285
286        let lambda = if all_ranks_zero {
287            get_lagrange_coeff(x_i, x_i_list, &party_ids_with_keyshares)
288        } else {
289            get_birkhoff_coefficients(
290                rank_list,
291                x_i_list,
292                &party_ids_with_keyshares,
293            )
294            .get(&(my_party_id as usize))
295            .cloned()
296            .unwrap_or(Scalar::ZERO)
297        };
298
299        s_i_0 = lambda * s_i;
300    }
301
302    let key_refresh_data = KeyRefreshData {
303        s_i_0,
304        lost_keyshare_party_ids: old_keyshare.lost_keyshare_party_ids,
305        expected_public_key: old_keyshare.public_key,
306        root_chain_code: old_keyshare.root_chain_code,
307    };
308
309    let result: Result<Keyshare, KeygenError> =
310        run_inner(setup, seed, &mut relay, Some(&key_refresh_data)).await;
311
312    let new_keyshare = match result {
313        Ok(eph_keyshare) => eph_keyshare,
314
315        Err(KeygenError::AbortProtocol(p)) => {
316            return Err(KeygenError::AbortProtocol(p))
317        }
318
319        Err(KeygenError::SendMessage) => {
320            return Err(KeygenError::SendMessage)
321        }
322
323        Err(err_message) => {
324            #[cfg(feature = "tracing")]
325            tracing::debug!("sending abort message");
326
327            relay.send(abort_msg).await?;
328
329            return Err(err_message);
330        }
331    };
332
333    Ok(new_keyshare)
334}
335
/// Generate ValidatedSetup and seed for Key refresh
///
/// Pairs each `(setup, seed)` produced by `setup_keygen(None, t, n, n_i_list)`
/// with the `KeyshareForRefresh` at the same position.
///
/// # Panics
/// Panics if `key_shares_for_refresh` does not contain exactly `n` entries.
/// Without this check, `zip` would silently truncate the result.
#[cfg(any(test, feature = "test-support"))]
pub fn setup_key_refresh(
    t: u8,
    n: u8,
    n_i_list: Option<&[u8]>,
    key_shares_for_refresh: Vec<KeyshareForRefresh>,
) -> Vec<(
    crate::setup::keygen::SetupMessage,
    [u8; 32],
    KeyshareForRefresh,
)> {
    assert_eq!(
        key_shares_for_refresh.len(),
        n as usize,
        "expected exactly one KeyshareForRefresh per participant"
    );

    super::utils::setup_keygen(None, t, n, n_i_list)
        .into_iter()
        .zip(key_shares_for_refresh)
        .map(|((setup, seed), share)| (setup, seed, share))
        .collect()
}
354
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use k256::elliptic_curve::group::GroupEncoding;

    use super::*;

    use tokio::task::JoinSet;

    use crate::keygen::utils::gen_keyshares;

    use crate::sign::{run as run_dsg, setup_dsg};

    /// Refresh a 2-of-3 key where every party still holds its share, check
    /// that the public key is unchanged, then sign with two refreshed shares.
    #[tokio::test(flavor = "multi_thread")]
    async fn r1() {
        let mut old_shares = gen_keyshares(2, 3, Some(&[0, 0, 0])).await;
        let expected_public_key = old_shares[0].public_key();

        old_shares.swap(0, 2);

        let coord = SimpleMessageRelay::new();

        let mut parties = JoinSet::new();

        let key_shares_for_refresh = old_shares
            .iter()
            .map(|share| KeyshareForRefresh::from_keyshare(share, None))
            .collect();

        for (setup, seed, share) in
            setup_key_refresh(2, 3, Some(&[0, 1, 1]), key_shares_for_refresh)
        {
            parties.spawn(run(setup, seed, coord.connect(), share));
        }

        let mut new_shares = vec![];
        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {}", err);
            }

            assert!(fini.is_ok());

            let new_share = fini.unwrap();
            // Refresh must not change the joint public key.
            assert_eq!(new_share.public_key(), expected_public_key);
            let pk = hex::encode(new_share.public_key().to_bytes());

            new_shares.push(Arc::new(new_share));

            println!("PK {}", pk);
        }

        // sign with new key_shares
        let coord = SimpleMessageRelay::new();

        new_shares.sort_by_key(|share| share.party_id);
        let subset = &new_shares[0..2_usize];

        let mut parties: JoinSet<Result<_, _>> = JoinSet::new();
        for (setup, seed) in setup_dsg(None, subset, "m") {
            parties.spawn(run_dsg(setup, seed, coord.connect()));
        }

        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {err:?}");
            }
            let _fini = fini.unwrap();
        }
    }

    /// Recover the shares of parties 0 and 1 from parties 2 and 3, check the
    /// public key is preserved, then sign with the recovered shares.
    #[tokio::test(flavor = "multi_thread")]
    async fn recover_lost_share() {
        let coord = SimpleMessageRelay::new();
        let mut parties = JoinSet::new();

        let t = 2;
        let n = 4;
        let rank_list = [0, 0, 0, 0];
        let old_keyshares = gen_keyshares(t, n, Some(&rank_list)).await;
        let public_key = old_keyshares[0].public_key();

        // party_0 and party_1 key_shares was lost
        let lost_keyshare_party_ids = vec![0, 1];
        let rank_list = vec![0u8, 0u8, 0u8, 0u8];
        let mut key_shares_for_refresh = Vec::with_capacity(n as usize);
        key_shares_for_refresh.push(KeyshareForRefresh::from_lost_keyshare(
            rank_list.clone(),
            t,
            public_key,
            lost_keyshare_party_ids.clone(),
            0,
        ));
        key_shares_for_refresh.push(KeyshareForRefresh::from_lost_keyshare(
            rank_list,
            t,
            public_key,
            lost_keyshare_party_ids,
            1,
        ));
        key_shares_for_refresh.push(KeyshareForRefresh::from_keyshare(
            &old_keyshares[2],
            Some(vec![0, 1]),
        ));
        key_shares_for_refresh.push(KeyshareForRefresh::from_keyshare(
            &old_keyshares[3],
            Some(vec![0, 1]),
        ));

        // recover lost key_share
        for (setup, seed, share) in setup_key_refresh(
            t,
            n,
            Some(&[0, 0, 0, 0]),
            key_shares_for_refresh,
        ) {
            parties.spawn(run(setup, seed, coord.connect(), share));
        }

        let mut new_shares = vec![];
        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {}", err);
            }

            assert!(fini.is_ok());

            let new_share = fini.unwrap();
            // Recovery must reproduce the original joint public key.
            assert_eq!(new_share.public_key(), public_key);
            println!("PK {}", hex::encode(new_share.public_key().to_bytes()));

            new_shares.push(Arc::new(new_share));
        }

        // sign with party_0 and party_1 new key_shares
        let coord = SimpleMessageRelay::new();

        new_shares.sort_by_key(|share| share.party_id);
        let subset = &new_shares[0..2_usize];

        let mut parties: JoinSet<Result<_, _>> = JoinSet::new();
        for (setup, seed) in setup_dsg(None, subset, "m") {
            parties.spawn(run_dsg(setup, seed, coord.connect()));
        }

        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {err:?}");
            }
            let _fini = fini.unwrap();
        }
    }

    /// Round-trip a fully populated share through `to_bytes`/`from_bytes`.
    #[test]
    fn refresh_ser_de() {
        let share = KeyshareForRefresh::new(
            vec![0, 0, 0, 0],
            2,
            ProjectivePoint::GENERATOR * Scalar::ONE,
            [0u8; 32],
            Some(Scalar::ONE),
            Some(vec![NonZeroScalar::new(Scalar::ONE).unwrap()]),
            vec![0, 1],
            0,
        );

        let bytes = share.to_bytes();
        assert_eq!(bytes.len(), share.size());

        let share2 = KeyshareForRefresh::from_bytes(&bytes).unwrap();

        assert_eq!(share.party_id, share2.party_id);
        assert_eq!(share.rank_list, share2.rank_list);
        assert_eq!(share.threshold, share2.threshold);
        assert_eq!(share.public_key, share2.public_key);
        assert_eq!(share.root_chain_code, share2.root_chain_code);
        assert_eq!(share.s_i, share2.s_i);
        assert_eq!(share.x_i_list.is_some(), share2.x_i_list.is_some());
        let x_i_list = share.x_i_list.unwrap();
        let x_i_list2 = share2.x_i_list.unwrap();

        assert_eq!(x_i_list.len(), x_i_list2.len());

        for (x_i, x_i2) in x_i_list.iter().zip(x_i_list2.iter()) {
            assert_eq!(x_i.to_bytes(), x_i2.to_bytes());
        }

        assert_eq!(
            share.lost_keyshare_party_ids,
            share2.lost_keyshare_party_ids
        );

        // Truncated or empty input must be rejected, not partially decoded.
        assert!(KeyshareForRefresh::from_bytes(&bytes[..bytes.len() - 1])
            .is_none());
        assert!(KeyshareForRefresh::from_bytes(&[]).is_none());
    }

    /// Round-trip a share created for a party that lost its key share.
    #[test]
    fn refresh_ser_de_lost_share() {
        let share = KeyshareForRefresh::from_lost_keyshare(
            vec![0, 0, 0],
            2,
            ProjectivePoint::GENERATOR,
            vec![1],
            1,
        );

        let bytes = share.to_bytes();
        assert_eq!(bytes.len(), share.size());

        let share2 = KeyshareForRefresh::from_bytes(&bytes).unwrap();

        assert_eq!(share.party_id, share2.party_id);
        assert_eq!(share.rank_list, share2.rank_list);
        assert_eq!(share.threshold, share2.threshold);
        assert_eq!(share.public_key, share2.public_key);
        assert_eq!(share2.root_chain_code, [0u8; 32]);
        assert!(share2.s_i.is_none());
        assert!(share2.x_i_list.is_none());
        assert_eq!(
            share.lost_keyshare_party_ids,
            share2.lost_keyshare_party_ids
        );
    }
}