dkls23/keygen/
migration.rs1use crate::keygen::{KeyRefreshData, KeygenError, Keyshare};
9use crate::proto::{create_abort_message, FilteredMsgRelay};
10use crate::setup::KeygenSetupMessage;
11use crate::{keygen, Seed};
12use futures_util::SinkExt;
13use k256::{ProjectivePoint, Scalar};
14use sl_mpc_mate::coord::Relay;
15
16pub async fn run<R, S>(
49 setup: S,
50 seed: Seed,
51 relay: R,
52 s_i_0: Scalar,
53 public_key: ProjectivePoint,
54 root_chain_code: [u8; 32],
55) -> Result<Keyshare, KeygenError>
56where
57 S: KeygenSetupMessage,
58 R: Relay,
59{
60 let abort_msg = create_abort_message(&setup);
61
62 let mut relay = FilteredMsgRelay::new(relay);
63
64 let key_refresh_data = KeyRefreshData {
65 s_i_0,
66 lost_keyshare_party_ids: vec![],
67 expected_public_key: public_key,
68 root_chain_code, };
70
71 let result: Result<Keyshare, KeygenError> =
72 keygen::run_inner(setup, seed, &mut relay, Some(&key_refresh_data))
73 .await;
74
75 let new_keyshare = match result {
76 Ok(eph_keyshare) => eph_keyshare,
77
78 Err(KeygenError::AbortProtocol(p)) => {
79 return Err(KeygenError::AbortProtocol(p))
80 }
81
82 Err(KeygenError::SendMessage) => {
83 return Err(KeygenError::SendMessage)
84 }
85
86 Err(err_message) => {
87 #[cfg(feature = "tracing")]
88 tracing::debug!("sending abort message");
89
90 relay.send(abort_msg).await?;
91
92 return Err(err_message);
93 }
94 };
95
96 Ok(new_keyshare)
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keygen::utils::setup_keygen;
    use crate::sign::{run as run_dsg, setup_dsg};
    use k256::elliptic_curve::group::GroupEncoding;
    use k256::elliptic_curve::ops::MulByGenerator;
    use k256::{CompressedPoint, NonZeroScalar, U256};
    use sl_mpc_mate::coord::SimpleMessageRelay;
    use std::collections::VecDeque;
    use std::sync::Arc;
    use tokio::task::JoinSet;

    /// Compressed SEC1 encoding (lowercase hex) of the legacy public key that the
    /// migration is expected to preserve.
    const EXPECTED_PUBLIC_KEY_HEX: &str =
        "02eba32793892022121314aed023df242292d313cb657f6f69016d90b6cfc92d33";

    /// Big-endian hex encodings of the legacy additive shares `s_i_0`, one per
    /// party, in the order `setup_keygen` yields the parties.
    const LEGACY_SHARES_HEX: [&str; 3] = [
        "3B6661CC3A28C174AF9D0FDD966E9F9D9D2A96682A504E1E9165D700BDC47809",
        "3361D26EBB452DDA716E38F20405B42E3ABDC890CAEE1150AB0D019D45091DC4",
        "71FDD4E9358DB270FA0EF15F4D72A6267B012781D154D2A380ECFCA86E85BEA2",
    ];

    /// Root chain code: the 32 ASCII bytes of this literal, used verbatim (not
    /// hex-decoded). The byte-string literal makes the length compile-checked.
    const ROOT_CHAIN_CODE: [u8; 32] = *b"253453627f65463253453627f6546321";

    /// Migrates a legacy key given as additive shares into new keyshares, checks
    /// that every new keyshare carries the legacy public key, and then signs with
    /// two of the migrated shares.
    #[tokio::test(flavor = "multi_thread")]
    async fn migration_test() {
        let pk_bytes = hex::decode(EXPECTED_PUBLIC_KEY_HEX).unwrap();
        let public_key =
            ProjectivePoint::from_bytes(CompressedPoint::from_slice(&pk_bytes)).unwrap();

        let mut s_i_0: VecDeque<Scalar> = LEGACY_SHARES_HEX
            .iter()
            .map(|hex_share| *NonZeroScalar::from_uint(U256::from_be_hex(hex_share)).unwrap())
            .collect();

        // Fixture sanity check: the additive shares must sum to the private key
        // behind `public_key`, otherwise the test data itself is inconsistent.
        let sk = s_i_0.iter().fold(Scalar::ZERO, |sum, val| sum.add(val));
        assert_eq!(
            ProjectivePoint::mul_by_generator(&sk),
            public_key,
            "legacy shares do not reconstruct the expected public key"
        );

        let coord = SimpleMessageRelay::new();
        let mut parties = JoinSet::new();
        for (setupmsg, seed) in setup_keygen(None, 2, 3, None) {
            let share = s_i_0
                .pop_front()
                .expect("setup_keygen produced more parties than legacy shares");
            parties.spawn(run(
                setupmsg,
                seed,
                coord.connect(),
                share,
                public_key,
                ROOT_CHAIN_CODE,
            ));
        }
        assert!(
            s_i_0.is_empty(),
            "setup_keygen produced fewer parties than legacy shares"
        );

        let mut new_shares = Vec::with_capacity(LEGACY_SHARES_HEX.len());
        while let Some(joined) = parties.join_next().await {
            let new_share = joined
                .expect("migration task panicked")
                .unwrap_or_else(|err| panic!("migration failed: {err}"));

            // The whole point of migration: the public key must be unchanged.
            assert_eq!(
                hex::encode(new_share.public_key().to_bytes()),
                EXPECTED_PUBLIC_KEY_HEX
            );
            new_shares.push(Arc::new(new_share));
        }
        assert_eq!(new_shares.len(), LEGACY_SHARES_HEX.len());

        // Sign with the first two migrated shares, ordered by party id.
        new_shares.sort_by_key(|share| share.party_id);
        let subset = &new_shares[..2];

        let coord = SimpleMessageRelay::new();
        let mut parties: JoinSet<Result<_, _>> = JoinSet::new();
        for (setup, seed) in setup_dsg(None, subset, "m") {
            parties.spawn(run_dsg(setup, seed, coord.connect()));
        }

        while let Some(joined) = parties.join_next().await {
            if let Err(err) = joined.expect("signing task panicked") {
                panic!("signing with migrated shares failed: {err:?}");
            }
        }
    }
}