1use std::fmt;
7
8use k256::{
9 elliptic_curve::{group::GroupEncoding, PrimeField},
10 NonZeroScalar, ProjectivePoint, Scalar,
11};
12
13use zeroize::Zeroize;
14
15use sl_mpc_mate::coord::*;
16
17use crate::keygen::utils::{get_birkhoff_coefficients, get_lagrange_coeff};
18
19use crate::{
20 keygen::{run_inner, KeyRefreshData, KeygenError, Keyshare},
21 proto::{tags::*, *},
22 setup::KeygenSetupMessage,
23 Seed,
24};
25
/// Input to [`run`] describing one party's state before a key refresh.
///
/// A party that still holds its keyshare has both `s_i` and `x_i_list` set
/// (see [`KeyshareForRefresh::from_keyshare`]). A party that lost its
/// keyshare has both set to `None` and an all-zero `root_chain_code` (see
/// [`KeyshareForRefresh::from_lost_keyshare`]).
///
/// NOTE(review): only `Zeroize` is derived, not `ZeroizeOnDrop`, so `s_i` is
/// not wiped automatically when a value is dropped; callers must invoke
/// `zeroize()` explicitly where that matters.
#[derive(Clone, Zeroize)]
pub struct KeyshareForRefresh {
    /// Per-party ranks. All zeros selects Lagrange interpolation in [`run`];
    /// any non-zero entry selects Birkhoff interpolation.
    pub rank_list: Vec<u8>,

    /// Threshold of the existing key (copied from `Keyshare::threshold` by
    /// `from_keyshare`).
    pub threshold: u8,

    /// Public key of the existing key; [`run`] passes it to the protocol as
    /// the expected public key.
    pub public_key: ProjectivePoint,

    /// Root chain code passed to the protocol by [`run`]; all zeros for a
    /// lost share.
    pub(crate) root_chain_code: [u8; 32],

    /// This party's share scalar, or `None` if the share was lost. [`run`]
    /// scales it by the party's interpolation coefficient.
    pub s_i: Option<Scalar>,

    /// Per-party `x_i` values indexed by party id, or `None` if the share
    /// was lost.
    pub x_i_list: Option<Vec<NonZeroScalar>>,

    /// Ids of parties whose keyshares are lost; [`run`] excludes them from
    /// the interpolation set.
    pub lost_keyshare_party_ids: Vec<u8>,

    /// Id of the party this value belongs to.
    pub party_id: u8,
}
54
55impl fmt::Debug for KeyshareForRefresh {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("KeyshareForRefresh")
58 .field("rank_list", &self.rank_list)
59 .field("threshold", &self.threshold)
60 .field("public_key", &self.public_key)
61 .field("lost_keyshare_party_ids", &self.lost_keyshare_party_ids)
62 .finish()
63 }
64}
65
66impl KeyshareForRefresh {
67 #[allow(clippy::too_many_arguments)]
68 pub fn new(
73 rank_list: Vec<u8>,
74 threshold: u8,
75 public_key: ProjectivePoint,
76 root_chain_code: [u8; 32],
77 s_i: Option<Scalar>,
78 x_i_list: Option<Vec<NonZeroScalar>>,
79 lost_keyshare_party_ids: Vec<u8>,
80 party_id: u8,
81 ) -> Self {
82 Self {
83 rank_list,
84 threshold,
85 public_key,
86 root_chain_code,
87 s_i,
88 x_i_list,
89 lost_keyshare_party_ids,
90 party_id,
91 }
92 }
93 pub fn from_keyshare(
95 keyshare: &Keyshare,
96 lost_keyshare_party_ids: Option<Vec<u8>>,
97 ) -> Self {
98 let lost_keyshare_party_ids =
99 lost_keyshare_party_ids.unwrap_or_default();
100 Self {
101 rank_list: keyshare.rank_list(),
102 threshold: keyshare.threshold,
103 public_key: keyshare.public_key(),
104 root_chain_code: keyshare.root_chain_code,
105 s_i: Some(keyshare.s_i()),
106 x_i_list: Some(keyshare.x_i_list()),
107 lost_keyshare_party_ids,
108 party_id: keyshare.party_id,
109 }
110 }
111
112 pub fn from_lost_keyshare(
114 rank_list: Vec<u8>,
115 threshold: u8,
116 public_key: ProjectivePoint,
117 lost_keyshare_party_ids: Vec<u8>,
118 party_id: u8,
119 ) -> Self {
120 Self {
121 rank_list,
122 threshold,
123 public_key,
124 root_chain_code: [0u8; 32],
125 s_i: None,
126 x_i_list: None,
127 lost_keyshare_party_ids,
128 party_id,
129 }
130 }
131
132 pub fn to_bytes(&self) -> Vec<u8> {
135 let mut bytes = Vec::with_capacity(self.size());
136
137 bytes.push(self.party_id);
138
139 bytes.push(self.rank_list.len() as u8);
140 bytes.extend_from_slice(&self.rank_list);
141
142 bytes.push(self.threshold);
143
144 bytes.extend_from_slice(&self.public_key.to_affine().to_bytes());
145
146 bytes.extend_from_slice(&self.root_chain_code);
147
148 if let Some(s_i) = self.s_i {
149 bytes.push(1);
150 bytes.extend_from_slice(&s_i.to_bytes());
151 } else {
152 bytes.push(0);
153 }
154
155 if let Some(x_i_list) = &self.x_i_list {
156 bytes.push(x_i_list.len() as u8);
157 for x_i in x_i_list {
158 bytes.extend_from_slice(&x_i.to_bytes());
159 }
160 } else {
161 bytes.push(0);
162 }
163
164 bytes.push(self.lost_keyshare_party_ids.len() as u8);
165 bytes.extend_from_slice(&self.lost_keyshare_party_ids);
166
167 bytes
168 }
169
170 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
172 let offset = std::cell::Cell::new(0usize);
173
174 let read_data = |num_bytes: u8| {
175 let len = num_bytes as usize;
176 let off = offset.replace(offset.get() + len);
177 bytes.get(off..off + len)
178 };
179
180 let read_byte = || read_data(1).map(|b| b[0]);
181
182 let party_id = read_byte()?;
183
184 let rank_list_len = read_byte()?;
185 let rank_list = read_data(rank_list_len)?;
186
187 let threshold = read_byte()?;
188
189 let public_key = read_data(33).and_then(|b| {
190 ProjectivePoint::from_bytes(b.into()).into_option()
191 })?;
192
193 let root_chain_code: [u8; 32] = read_data(32)?.try_into().ok()?;
194
195 let s_i = if read_byte()? == 1 {
196 let s_i_bytes: [u8; 32] = read_data(32)?.try_into().ok()?;
197 Scalar::from_repr(s_i_bytes.into()).into()
198 } else {
199 None
200 };
201
202 let x_i_list_len = read_byte()?;
203 let x_i_list = if x_i_list_len != 0 {
204 let mut x_i_list = Vec::with_capacity(x_i_list_len as usize);
205 for _ in 0..x_i_list_len {
206 let x_i_bytes: [u8; 32] = read_data(32)?.try_into().ok()?;
207 let x_i: NonZeroScalar =
208 Option::from(NonZeroScalar::from_repr(x_i_bytes.into()))?;
209 x_i_list.push(x_i);
210 }
211 Some(x_i_list)
212 } else {
213 None
214 };
215
216 let lost_keyshare_party_ids_len = read_byte()?;
217 let lost_keyshare_party_ids = read_data(lost_keyshare_party_ids_len)?;
218
219 Some(Self {
220 rank_list: rank_list.to_vec(),
221 threshold,
222 public_key,
223 root_chain_code,
224 s_i,
225 x_i_list,
226 lost_keyshare_party_ids: lost_keyshare_party_ids.to_vec(),
227 party_id,
228 })
229 }
230
231 fn size(&self) -> usize {
232 let mut size = 1 + self.rank_list.len();
233 size += 1; size += 1;
235 size += 33;
236 size += 32;
237 size += 1;
238 if self.s_i.is_some() {
239 size += 32;
240 }
241 if let Some(x_i_list) = &self.x_i_list {
242 size += 1;
243 size += x_i_list.len() * 32;
244 } else {
245 size += 1;
246 }
247 size += 1 + self.lost_keyshare_party_ids.len();
248 size
249 }
250}
251
252pub async fn run<R, S>(
254 setup: S,
255 seed: Seed,
256 relay: R,
257 old_keyshare: KeyshareForRefresh,
258) -> Result<Keyshare, KeygenError>
259where
260 S: KeygenSetupMessage,
261 R: Relay,
262{
263 let abort_msg = create_abort_message(&setup);
264 let mut relay = FilteredMsgRelay::new(relay);
265
266 let my_party_id = old_keyshare.party_id;
267 let n = setup.total_participants();
268
269 let mut s_i_0 = Scalar::ZERO;
270 if old_keyshare.s_i.is_some() && old_keyshare.x_i_list.is_some() {
271 let s_i = &old_keyshare.s_i.unwrap();
274 let rank_list = &old_keyshare.rank_list;
275 let x_i_list = &old_keyshare.x_i_list.unwrap();
276 let x_i = &x_i_list[my_party_id as usize];
277
278 let party_ids_with_keyshares = (0..n as u8)
279 .filter(|p| {
280 !old_keyshare.lost_keyshare_party_ids.contains(&{ *p })
281 })
282 .collect::<Vec<_>>();
283
284 let all_ranks_zero = rank_list.iter().all(|r| r == &0u8);
285
286 let lambda = if all_ranks_zero {
287 get_lagrange_coeff(x_i, x_i_list, &party_ids_with_keyshares)
288 } else {
289 get_birkhoff_coefficients(
290 rank_list,
291 x_i_list,
292 &party_ids_with_keyshares,
293 )
294 .get(&(my_party_id as usize))
295 .cloned()
296 .unwrap_or(Scalar::ZERO)
297 };
298
299 s_i_0 = lambda * s_i;
300 }
301
302 let key_refresh_data = KeyRefreshData {
303 s_i_0,
304 lost_keyshare_party_ids: old_keyshare.lost_keyshare_party_ids,
305 expected_public_key: old_keyshare.public_key,
306 root_chain_code: old_keyshare.root_chain_code,
307 };
308
309 let result: Result<Keyshare, KeygenError> =
310 run_inner(setup, seed, &mut relay, Some(&key_refresh_data)).await;
311
312 let new_keyshare = match result {
313 Ok(eph_keyshare) => eph_keyshare,
314
315 Err(KeygenError::AbortProtocol(p)) => {
316 return Err(KeygenError::AbortProtocol(p))
317 }
318
319 Err(KeygenError::SendMessage) => {
320 return Err(KeygenError::SendMessage)
321 }
322
323 Err(err_message) => {
324 #[cfg(feature = "tracing")]
325 tracing::debug!("sending abort message");
326
327 relay.send(abort_msg).await?;
328
329 return Err(err_message);
330 }
331 };
332
333 Ok(new_keyshare)
334}
335
/// Pairs freshly generated keygen setups with the shares to be refreshed.
///
/// Entry `i` of the result combines the `i`-th `(setup, seed)` pair from
/// `setup_keygen(None, t, n, n_i_list)` with `key_shares_for_refresh[i]`.
/// If the two sequences differ in length, the surplus entries of the longer
/// one are dropped.
#[cfg(any(test, feature = "test-support"))]
pub fn setup_key_refresh(
    t: u8,
    n: u8,
    n_i_list: Option<&[u8]>,
    key_shares_for_refresh: Vec<KeyshareForRefresh>,
) -> Vec<(
    crate::setup::keygen::SetupMessage,
    [u8; 32],
    KeyshareForRefresh,
)> {
    let keygen_setups = super::utils::setup_keygen(None, t, n, n_i_list);

    let mut triples = Vec::with_capacity(key_shares_for_refresh.len());
    for ((setup, seed), share) in
        keygen_setups.into_iter().zip(key_shares_for_refresh)
    {
        triples.push((setup, seed, share));
    }
    triples
}
354
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use k256::elliptic_curve::group::GroupEncoding;

    use super::*;

    use tokio::task::JoinSet;

    use crate::keygen::utils::gen_keyshares;

    use crate::sign::{run as run_dsg, setup_dsg};

    /// Refreshes a 2-of-3 key generated with ranks `[0, 0, 0]`, passing
    /// `[0, 1, 1]` as the new `n_i_list`. It then signs with the first two
    /// refreshed shares.
    #[tokio::test(flavor = "multi_thread")]
    async fn r1() {
        let mut old_shares = gen_keyshares(2, 3, Some(&[0, 0, 0])).await;
        let expected_public_key = old_shares[0].public_key();

        // Pair the old shares with the setups in a shuffled order.
        old_shares.swap(0, 2);

        let coord = SimpleMessageRelay::new();

        let mut parties = JoinSet::new();

        let key_shares_for_refresh = old_shares
            .iter()
            .map(|share| KeyshareForRefresh::from_keyshare(share, None))
            .collect();

        for (setup, seed, share) in
            setup_key_refresh(2, 3, Some(&[0, 1, 1]), key_shares_for_refresh)
        {
            parties.spawn(run(setup, seed, coord.connect(), share));
        }

        let mut new_shares = vec![];
        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {}", err);
            }

            assert!(fini.is_ok());

            let new_share = fini.unwrap();
            // A refresh must not change the public key.
            assert_eq!(new_share.public_key(), expected_public_key);
            let pk = hex::encode(new_share.public_key().to_bytes());

            new_shares.push(Arc::new(new_share));

            println!("PK {}", pk);
        }

        let coord = SimpleMessageRelay::new();

        new_shares.sort_by_key(|share| share.party_id);
        let subset = &new_shares[0..2_usize];

        let mut parties: JoinSet<Result<_, _>> = JoinSet::new();
        for (setup, seed) in setup_dsg(None, subset, "m") {
            parties.spawn(run_dsg(setup, seed, coord.connect()));
        }

        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {err:?}");
            }
            let _fini = fini.unwrap();
        }
    }

    /// Parties 0 and 1 of a 2-of-4 key have lost their shares. All four
    /// parties run the refresh, with parties 2 and 3 using their old shares.
    /// The two lowest party ids then sign with their recovered shares.
    #[tokio::test(flavor = "multi_thread")]
    async fn recover_lost_share() {
        let coord = SimpleMessageRelay::new();
        let mut parties = JoinSet::new();

        let t = 2;
        let n = 4;
        let rank_list = [0, 0, 0, 0];
        let old_keyshares = gen_keyshares(t, n, Some(&rank_list)).await;
        let public_key = old_keyshares[0].public_key();

        let lost_keyshare_party_ids = vec![0, 1];
        let rank_list = vec![0u8, 0u8, 0u8, 0u8];
        let mut key_shares_for_refresh = Vec::with_capacity(n as usize);
        key_shares_for_refresh.push(KeyshareForRefresh::from_lost_keyshare(
            rank_list.clone(),
            t,
            public_key,
            lost_keyshare_party_ids.clone(),
            0,
        ));
        key_shares_for_refresh.push(KeyshareForRefresh::from_lost_keyshare(
            rank_list,
            t,
            public_key,
            lost_keyshare_party_ids,
            1,
        ));
        key_shares_for_refresh.push(KeyshareForRefresh::from_keyshare(
            &old_keyshares[2],
            Some(vec![0, 1]),
        ));
        key_shares_for_refresh.push(KeyshareForRefresh::from_keyshare(
            &old_keyshares[3],
            Some(vec![0, 1]),
        ));

        for (setup, seed, share) in setup_key_refresh(
            t,
            n,
            Some(&[0, 0, 0, 0]),
            key_shares_for_refresh,
        ) {
            parties.spawn(run(setup, seed, coord.connect(), share));
        }

        let mut new_shares = vec![];
        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {}", err);
            }

            assert!(fini.is_ok());

            let new_share = fini.unwrap();
            // Recovery must reproduce the original public key.
            assert_eq!(new_share.public_key(), public_key);
            println!("PK {}", hex::encode(new_share.public_key().to_bytes()));

            new_shares.push(Arc::new(new_share));
        }

        let coord = SimpleMessageRelay::new();

        new_shares.sort_by_key(|share| share.party_id);
        let subset = &new_shares[0..2_usize];

        let mut parties: JoinSet<Result<_, _>> = JoinSet::new();
        for (setup, seed) in setup_dsg(None, subset, "m") {
            parties.spawn(run_dsg(setup, seed, coord.connect()));
        }

        while let Some(fini) = parties.join_next().await {
            let fini = fini.unwrap();

            if let Err(ref err) = fini {
                println!("error {err:?}");
            }
            let _fini = fini.unwrap();
        }
    }

    /// Builds a share with every optional field populated.
    fn sample_share() -> KeyshareForRefresh {
        KeyshareForRefresh::new(
            vec![0, 0, 0, 0],
            2,
            ProjectivePoint::GENERATOR * Scalar::ONE,
            [0u8; 32],
            Some(Scalar::ONE),
            Some(vec![NonZeroScalar::new(Scalar::ONE).unwrap()]),
            vec![0, 1],
            0,
        )
    }

    #[test]
    fn refresh_ser_de() {
        let share = sample_share();

        let bytes = share.to_bytes();
        let share2 = KeyshareForRefresh::from_bytes(&bytes).unwrap();

        assert_eq!(share.party_id, share2.party_id);
        assert_eq!(share.rank_list, share2.rank_list);
        assert_eq!(share.threshold, share2.threshold);
        assert_eq!(share.public_key, share2.public_key);
        assert_eq!(share.root_chain_code, share2.root_chain_code);
        assert_eq!(share.s_i, share2.s_i);
        assert_eq!(share.x_i_list.is_some(), share2.x_i_list.is_some());
        let x_i_list = share.x_i_list.unwrap();
        let x_i_list2 = share2.x_i_list.unwrap();

        assert_eq!(x_i_list.len(), x_i_list2.len());

        for (x_i, x_i2) in x_i_list.iter().zip(x_i_list2.iter()) {
            assert_eq!(x_i.to_bytes(), x_i2.to_bytes());
        }

        assert_eq!(
            share.lost_keyshare_party_ids,
            share2.lost_keyshare_party_ids
        );
    }

    /// A lost share (no `s_i`, no `x_i_list`) must round-trip as lost.
    #[test]
    fn refresh_ser_de_lost_share() {
        let share = KeyshareForRefresh::from_lost_keyshare(
            vec![0, 0, 0],
            2,
            ProjectivePoint::GENERATOR,
            vec![1],
            1,
        );

        let share2 =
            KeyshareForRefresh::from_bytes(&share.to_bytes()).unwrap();

        assert_eq!(share.party_id, share2.party_id);
        assert_eq!(share.rank_list, share2.rank_list);
        assert_eq!(share.threshold, share2.threshold);
        assert_eq!(share.public_key, share2.public_key);
        assert_eq!(share2.root_chain_code, [0u8; 32]);
        assert!(share2.s_i.is_none());
        assert!(share2.x_i_list.is_none());
        assert_eq!(
            share.lost_keyshare_party_ids,
            share2.lost_keyshare_party_ids
        );
    }

    /// Every strict prefix of a valid encoding must be rejected.
    #[test]
    fn refresh_from_bytes_rejects_truncated_input() {
        let bytes = sample_share().to_bytes();
        assert!(KeyshareForRefresh::from_bytes(&bytes).is_some());

        for len in 0..bytes.len() {
            assert!(
                KeyshareForRefresh::from_bytes(&bytes[..len]).is_none(),
                "prefix of length {len} was accepted"
            );
        }
    }
}