dkls23/keygen/
keyshare.rs1use core::{mem, ops::Deref};
15
16use derivation_path::DerivationPath;
17use k256::{NonZeroScalar, ProjectivePoint, Scalar};
18use zeroize::ZeroizeOnDrop;
19
20use sl_oblivious::soft_spoken::{ReceiverOTSeed, SenderOTSeed};
21
22use sl_mpc_mate::bip32::{
23 derive_child_pubkey, derive_xpub, get_finger_print, BIP32Error,
24 KeyFingerPrint, Prefix, XPubKey,
25};
26
27use crate::proto::*;
28
29use self::details::KeyshareInfo;
30
31mod details;
32
33#[derive(Clone, ZeroizeOnDrop)]
46pub struct Keyshare {
47 buffer: Vec<u8>,
48}
49
50impl Keyshare {
51 pub const MAGIC: [u8; 4] = [0, 0, 0, 1];
55}
56
57impl Keyshare {
58 const INFO: usize = mem::size_of::<details::KeyshareInfo>();
59 const OTHER: usize = mem::size_of::<details::OtherParty>();
60 const EACH: usize = mem::size_of::<details::EachParty>();
61
62 fn calculate_size(n: u8, extra: usize) -> usize {
65 assert!(n > 1);
66
67 Self::INFO
68 + (n as usize) * Self::EACH
69 + (n as usize - 1) * Self::OTHER
70 + extra
71 }
72
73 pub fn new(n: u8, t: u8, id: u8, extra: &[u8]) -> Keyshare {
84 let size = Self::calculate_size(n, extra.len());
85 let mut buffer = vec![0u8; size];
86
87 buffer[size - extra.len()..].copy_from_slice(extra);
88
89 let mut share = Self { buffer };
90
91 let info = share.info_mut();
92
93 info.magic = Self::MAGIC;
94 info.total_parties = n;
95 info.threshold = t;
96 info.party_id = id;
97 info.extra = (extra.len() as u32).to_be_bytes();
98
99 share
100 }
101
102 fn is_valid_buffer(buffer: &[u8]) -> bool {
103 if buffer.len() <= Self::INFO {
104 return false;
105 }
106
107 let info = match bytemuck::try_from_bytes::<KeyshareInfo>(
108 &buffer[..Self::INFO],
109 )
110 .ok()
111 {
112 Some(info) => info,
113 _ => return false,
114 };
115
116 if info.magic != Self::MAGIC {
118 return false;
119 }
120
121 if info.threshold < 2 || info.threshold > info.total_parties {
122 return false;
123 }
124
125 if decode_point(&info.public_key).is_none() {
126 return false;
127 }
128
129 if decode_scalar(&info.s_i).is_none() {
130 return false;
131 }
132
133 let extra: usize = u32::from_be_bytes(info.extra) as usize;
134 let size = Self::calculate_size(info.total_parties, extra);
135
136 if size != buffer.len() {
137 return false;
138 }
139
140 true
141 }
142
143 pub fn from_bytes(buffer: &[u8]) -> Option<Self> {
151 if Self::is_valid_buffer(buffer) {
152 Some(Self {
153 buffer: buffer.to_vec(),
154 })
155 } else {
156 None
157 }
158 }
159
160 pub fn from_vec(buffer: Vec<u8>) -> Result<Self, Vec<u8>> {
168 if Self::is_valid_buffer(&buffer) {
169 Ok(Self { buffer })
170 } else {
171 Err(buffer)
172 }
173 }
174
175 pub fn as_slice(&self) -> &[u8] {
177 &self.buffer
178 }
179
180 pub fn public_key(&self) -> ProjectivePoint {
182 decode_point(&self.info().public_key).unwrap()
183 }
184
185 pub fn rank_list(&self) -> Vec<u8> {
187 (0..self.info().total_parties)
188 .map(|p| self.each(p).rank)
189 .collect()
190 }
191
192 pub fn x_i_list(&self) -> Vec<NonZeroScalar> {
194 (0..self.info().total_parties)
195 .map(|p| decode_nonzero(&self.each(p).x_i).unwrap())
196 .collect()
197 }
198
199 pub(crate) fn get_x_i(&self, party_id: u8) -> NonZeroScalar {
200 NonZeroScalar::new(decode_scalar(&self.each(party_id).x_i).unwrap())
201 .unwrap()
202 }
203
204 pub fn zero_ranks(&self) -> bool {
206 for p in 0..self.info().total_parties {
207 if self.each(p).rank != 0 {
208 return false;
209 }
210 }
211
212 true
213 }
214
215 pub fn get_rank(&self, party_id: u8) -> u8 {
220 self.each(party_id).rank
221 }
222
223 pub fn s_i(&self) -> Scalar {
225 decode_scalar(&self.info().s_i).unwrap()
226 }
227
228 pub fn extra_data(&self) -> &[u8] {
232 let n = self.info().total_parties as usize;
233 let offset = Self::INFO + Self::OTHER * (n - 1) + Self::EACH * n;
234
235 &self.buffer[offset..]
236 }
237
238 pub(crate) fn info(&self) -> &details::KeyshareInfo {
239 let bytes = &self.buffer[..Self::INFO];
240 bytemuck::from_bytes(bytes)
241 }
242
243 pub(crate) fn info_mut(&mut self) -> &mut details::KeyshareInfo {
244 let bytes = &mut self.buffer[..Self::INFO];
245
246 bytemuck::from_bytes_mut(bytes)
247 }
248
249 pub(crate) fn other_mut(
250 &mut self,
251 party_id: u8,
252 ) -> &mut details::OtherParty {
253 assert!(party_id < self.info().total_parties);
254
255 let n = self.info().total_parties as usize;
256 let offset = Self::INFO + Self::EACH * n;
257
258 let idx = self.get_idx_from_id(party_id);
259 let bytes = &mut self.buffer[offset..][..Self::OTHER * (n - 1)];
260
261 let others: &mut [details::OtherParty] =
262 bytemuck::cast_slice_mut(bytes);
263
264 &mut others[idx]
265 }
266
267 pub(crate) fn other(&self, party_id: u8) -> &details::OtherParty {
268 assert!(party_id < self.info().total_parties);
269
270 let n = self.info().total_parties as usize;
271 let offset = Self::INFO + Self::EACH * n;
272
273 let bytes = &self.buffer[offset..][..Self::OTHER * (n - 1)];
274
275 let others: &[details::OtherParty] = bytemuck::cast_slice(bytes);
276
277 &others[self.get_idx_from_id(party_id)]
278 }
279
280 pub(crate) fn each_mut(
281 &mut self,
282 party_id: u8,
283 ) -> &mut details::EachParty {
284 assert!(party_id < self.info().total_parties);
285
286 let n = self.info().total_parties as usize;
287
288 let bytes = &mut self.buffer[Self::INFO..][..Self::EACH * n];
289 let each: &mut [details::EachParty] = bytemuck::cast_slice_mut(bytes);
290
291 &mut each[party_id as usize]
292 }
293
294 pub(crate) fn each(&self, party_id: u8) -> &details::EachParty {
295 assert!(party_id < self.info().total_parties);
296
297 let n = self.info().total_parties as usize;
298
299 let bytes = &self.buffer[Self::INFO..][..Self::EACH * n];
300 let each: &[details::EachParty] = bytemuck::cast_slice(bytes);
301
302 &each[party_id as usize]
303 }
304
305 pub fn big_s(&self, party_id: u8) -> ProjectivePoint {
310 decode_point(&self.each(party_id).big_s).unwrap()
311 }
312
313 fn get_idx_from_id(&self, party_id: u8) -> usize {
314 assert!(self.info().party_id != party_id);
315 let idx = if party_id > self.info().party_id {
316 party_id - 1
317 } else {
318 party_id
319 };
320
321 idx as _
322 }
323
324 pub(crate) fn sender_seed(&self, party_id: u8) -> &SenderOTSeed {
325 &self.other(party_id).send_ot_seed
326 }
327
328 pub(crate) fn receiver_seed(&self, party_id: u8) -> &ReceiverOTSeed {
329 &self.other(party_id).recv_ot_seed
330 }
331}
332
333impl Deref for Keyshare {
334 type Target = details::KeyshareInfo;
335
336 fn deref(&self) -> &Self::Target {
337 let bytes = &self.buffer[..Self::INFO];
338
339 bytemuck::from_bytes(bytes)
340 }
341}
342
343impl Keyshare {
344 pub fn root_chain_code(&self) -> [u8; 32] {
346 self.info().root_chain_code
347 }
348
349 pub fn root_public_key(&self) -> ProjectivePoint {
351 self.public_key()
352 }
353}
354
355impl Keyshare {
356 pub fn get_finger_print(&self) -> KeyFingerPrint {
358 get_finger_print(&self.root_public_key())
359 }
360
361 pub fn derive_with_offset(
369 &self,
370 chain_path: &DerivationPath,
371 ) -> Result<(Scalar, ProjectivePoint), BIP32Error> {
372 let mut pubkey = self.root_public_key();
373 let mut chain_code = self.root_chain_code();
374 let mut additive_offset = Scalar::ZERO;
375 for child_num in chain_path {
376 let (il_int, child_pubkey, child_chain_code) =
377 derive_child_pubkey(&pubkey, chain_code, child_num)?;
378 pubkey = child_pubkey;
379 chain_code = child_chain_code;
380 additive_offset += il_int;
381 }
382
383 Ok((additive_offset, pubkey))
385 }
386
387 pub fn derive_child_pubkey(
395 &self,
396 chain_path: &DerivationPath,
397 ) -> Result<ProjectivePoint, BIP32Error> {
398 let (_, child_pubkey) = self.derive_with_offset(chain_path)?;
399
400 Ok(child_pubkey)
401 }
402
403 pub fn derive_xpub(
412 &self,
413 prefix: Prefix,
414 chain_path: DerivationPath,
415 ) -> Result<XPubKey, BIP32Error> {
416 derive_xpub(
417 prefix,
418 &self.root_public_key(),
419 self.root_chain_code(),
420 chain_path,
421 )
422 }
423}