Skip to main content

sinktools/
lazy_sink_source.rs

1//! [`LazySinkSource`], and related items.
2
3use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6use std::sync::Arc;
7use std::task::Wake;
8
9use futures_util::task::AtomicWaker;
10use futures_util::{Sink, Stream, ready};
11
/// Waker handed to the initialization future so that its progress wakes both halves.
///
/// The sink side and the stream side each register their current task waker in their
/// own slot before polling the shared initialization future with a `Waker` built from
/// this value; waking it (see the `Wake` impl) wakes every registered task.
#[derive(Default)]
struct DualWaker {
    /// Task waker registered by the `Sink` methods while initialization is pending.
    sink: AtomicWaker,
    /// Task waker registered by `Stream::poll_next` while initialization is pending.
    stream: AtomicWaker,
}
17
18impl Wake for DualWaker {
19    fn wake(self: Arc<Self>) {
20        self.sink.wake();
21        self.stream.wake();
22    }
23}
24
/// Initialization state machine shared by the sink side and the stream side of a
/// [`LazySinkSource`].
enum SharedState<Fut, St, Si, Item> {
    /// Initialization has not started; `future` has not been polled yet.
    Uninit {
        future: Pin<Box<Fut>>,
    },
    /// Initialization has been started (by `start_send` or `poll_next`) and `future`
    /// is being polled by whichever side is currently polled.
    Thunkulating {
        future: Pin<Box<Fut>>,
        /// Item passed to `start_send` before the inner sink existed, if any.
        item: Option<Item>,
        /// Shared waker passed to `future`; wakes every side waiting on it.
        dual_waker: Arc<DualWaker>,
    },
    /// Initialization finished; the inner stream and sink are available.
    Done {
        stream: Pin<Box<St>>,
        sink: Pin<Box<Si>>,
        /// The item carried over from `Thunkulating::item`, held until a later
        /// `poll_ready`/`poll_flush`/`poll_close` forwards it to `sink`.
        buf: Option<Item>,
    },
    /// Placeholder left in place while `future` is moved out of `Uninit` with
    /// `std::mem::replace`.
    Taken,
}
41
/// A lazily-initialized combined [`Sink`] and [`Stream`], built from a future that
/// resolves to a `(stream, sink)` pair.
///
/// The future is not polled until the first item is pulled from the stream side
/// (`poll_next`) or the first item is sent to the sink side (`start_send`). Calling
/// `poll_ready`, `poll_flush`, or `poll_close` before that completes immediately
/// without starting initialization.
///
/// The value can be split into a sink half and a source half (e.g. with
/// `StreamExt::split`). Whichever half is polled drives the initialization future,
/// and the halves waiting on it are woken when it makes progress. While initialization
/// is in progress the sink side holds at most one item, which is forwarded to the inner
/// sink once it exists.
///
/// `Error` is the initialization future's error type; errors from the inner sink are
/// converted into it via `From`.
pub struct LazySinkSource<Fut, St, Si, Item, Error> {
    /// Current initialization state.
    state: SharedState<Fut, St, Si, Item>,
    /// Marks the `Error` type parameter, which no field stores directly.
    _phantom: PhantomData<Error>,
}
47
48impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
49    /// Creates a new `LazySinkSource` with the given initialization future.
50    pub fn new(future: Fut) -> Self {
51        Self {
52            state: SharedState::Uninit {
53                future: Box::pin(future),
54            },
55            _phantom: PhantomData,
56        }
57    }
58}
59
60impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
61where
62    Self: Unpin,
63    Fut: Future<Output = Result<(St, Si), Error>>,
64    St: Stream,
65    Si: Sink<Item>,
66    Error: From<Si::Error>,
67{
68    type Error = Error;
69
70    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71        let state = &mut self.get_mut().state;
72
73        if let SharedState::Uninit { .. } = &*state {
74            return Poll::Ready(Ok(()));
75        }
76
77        if let SharedState::Thunkulating {
78            future,
79            item,
80            dual_waker,
81        } = &mut *state
82        {
83            dual_waker.sink.register(cx.waker());
84            let waker = Waker::from(Arc::clone(dual_waker));
85
86            let mut new_context = Context::from_waker(&waker);
87
88            match future.as_mut().poll(&mut new_context) {
89                Poll::Ready(Ok((stream, sink))) => {
90                    let buf = item.take();
91                    *state = SharedState::Done {
92                        stream: Box::pin(stream),
93                        sink: Box::pin(sink),
94                        buf,
95                    };
96                }
97                Poll::Ready(Err(e)) => {
98                    return Poll::Ready(Err(e));
99                }
100                Poll::Pending => {
101                    return Poll::Pending;
102                }
103            }
104        }
105
106        if let SharedState::Done { sink, buf, .. } = &mut *state {
107            if buf.is_some() {
108                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
109                sink.as_mut().start_send(buf.take().unwrap())?;
110            }
111            let result = sink.as_mut().poll_ready(cx).map_err(From::from);
112            return result;
113        }
114
115        panic!("LazySinkHalf in invalid state.");
116    }
117
118    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
119        let state = &mut self.get_mut().state;
120
121        if let SharedState::Uninit { .. } = &*state {
122            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
123            if let SharedState::Uninit { future } = old_state {
124                *state = SharedState::Thunkulating {
125                    future,
126                    item: Some(item),
127                    dual_waker: Default::default(),
128                };
129
130                return Ok(());
131            }
132        }
133
134        if let SharedState::Thunkulating { .. } = &mut *state {
135            panic!("LazySinkHalf not ready.");
136        }
137
138        if let SharedState::Done { sink, buf, .. } = &mut *state {
139            debug_assert!(buf.is_none());
140            let result = sink.as_mut().start_send(item).map_err(From::from);
141            return result;
142        }
143
144        panic!("LazySinkHalf not ready.");
145    }
146
147    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148        let state = &mut self.get_mut().state;
149
150        if let SharedState::Uninit { .. } = &*state {
151            return Poll::Ready(Ok(()));
152        }
153
154        if let SharedState::Thunkulating {
155            future,
156            item,
157            dual_waker,
158        } = &mut *state
159        {
160            dual_waker.sink.register(cx.waker());
161            let waker = Waker::from(Arc::clone(dual_waker));
162
163            let mut new_context = Context::from_waker(&waker);
164
165            match future.as_mut().poll(&mut new_context) {
166                Poll::Ready(Ok((stream, sink))) => {
167                    let buf = item.take();
168                    *state = SharedState::Done {
169                        stream: Box::pin(stream),
170                        sink: Box::pin(sink),
171                        buf,
172                    };
173                }
174                Poll::Ready(Err(e)) => {
175                    return Poll::Ready(Err(e));
176                }
177                Poll::Pending => {
178                    return Poll::Pending;
179                }
180            }
181        }
182
183        if let SharedState::Done { sink, buf, .. } = &mut *state {
184            if buf.is_some() {
185                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
186                sink.as_mut().start_send(buf.take().unwrap())?;
187            }
188            let result = sink.as_mut().poll_flush(cx).map_err(From::from);
189            return result;
190        }
191
192        panic!("LazySinkHalf in invalid state.");
193    }
194
195    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196        let state = &mut self.get_mut().state;
197
198        if let SharedState::Uninit { .. } = &*state {
199            return Poll::Ready(Ok(()));
200        }
201
202        if let SharedState::Thunkulating {
203            future,
204            item,
205            dual_waker,
206        } = &mut *state
207        {
208            dual_waker.sink.register(cx.waker());
209            let waker = Waker::from(Arc::clone(dual_waker));
210
211            let mut new_context = Context::from_waker(&waker);
212
213            match future.as_mut().poll(&mut new_context) {
214                Poll::Ready(Ok((stream, sink))) => {
215                    let buf = item.take();
216                    *state = SharedState::Done {
217                        stream: Box::pin(stream),
218                        sink: Box::pin(sink),
219                        buf,
220                    };
221                }
222                Poll::Ready(Err(e)) => {
223                    return Poll::Ready(Err(e));
224                }
225                Poll::Pending => {
226                    return Poll::Pending;
227                }
228            }
229        }
230
231        if let SharedState::Done { sink, buf, .. } = &mut *state {
232            if buf.is_some() {
233                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
234                sink.as_mut().start_send(buf.take().unwrap())?;
235            }
236            let result = sink.as_mut().poll_close(cx).map_err(From::from);
237            return result;
238        }
239
240        panic!("LazySinkHalf in invalid state.");
241    }
242}
243
244impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
245where
246    Self: Unpin,
247    Fut: Future<Output = Result<(St, Si), Error>>,
248    St: Stream,
249    Si: Sink<Item>,
250{
251    type Item = St::Item;
252
253    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254        let state = &mut self.get_mut().state;
255
256        if let SharedState::Uninit { .. } = &*state {
257            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
258            if let SharedState::Uninit { future } = old_state {
259                *state = SharedState::Thunkulating {
260                    future,
261                    item: None,
262                    dual_waker: Default::default(),
263                };
264            } else {
265                unreachable!();
266            }
267        }
268
269        if let SharedState::Thunkulating {
270            future,
271            item,
272            dual_waker,
273        } = &mut *state
274        {
275            dual_waker.stream.register(cx.waker());
276            let waker = Waker::from(Arc::clone(dual_waker));
277
278            let mut new_context = Context::from_waker(&waker);
279
280            match future.as_mut().poll(&mut new_context) {
281                Poll::Ready(Ok((stream, sink))) => {
282                    let buf = item.take();
283                    *state = SharedState::Done {
284                        stream: Box::pin(stream),
285                        sink: Box::pin(sink),
286                        buf,
287                    };
288                }
289
290                Poll::Ready(Err(_)) => {
291                    return Poll::Ready(None);
292                }
293
294                Poll::Pending => {
295                    return Poll::Pending;
296                }
297            }
298        }
299
300        if let SharedState::Done { stream, .. } = &mut *state {
301            let result = stream.as_mut().poll_next(cx);
302            match &result {
303                Poll::Ready(Some(_)) => {}
304                Poll::Ready(None) => {}
305                Poll::Pending => {}
306            }
307            return result;
308        }
309
310        panic!("LazySourceHalf in invalid state.");
311    }
312}
313
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};
    use tokio_util::sync::PollSendError;

    use super::*;

    /// The stream half polls first and so starts initialization; the sink half then
    /// waits on the same pending future. Both items must arrive once it is released.
    #[tokio::test(flavor = "current_thread")]
    async fn stream_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Gate that keeps the initialization future pending until released below.
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Ensure the stream half starts the lazy initialization: the sink task
                // does not send anything until the stream task has started running.
                let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    stream_init_send.send(()).unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    stream_init_recv.await.unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // Release the initialization future.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    /// The sink half sends first and so starts initialization; the stream half then
    /// waits on the same pending future. Both items must arrive once it is released.
    #[tokio::test(flavor = "current_thread")]
    async fn sink_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Gate that keeps the initialization future pending until released below.
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Ensure the sink half starts the lazy initialization: the stream task
                // does not poll until the sink task has started running.
                let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    sink_init_recv.await.unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    sink_init_send.send(()).unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // Release the initialization future.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    /// Over a real TCP connection: `stream.next()` starts initialization (which blocks in
    /// `accept`), the sink then waits on the same future, and both sides work once a
    /// client connects.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled from inside the initialization future once it has been polled.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                println!("Listening on {}", addr);

                let sink_source = LazySinkSource::new(async move {
                    // initialization is at least partially started now.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                initialization_rx.await.unwrap(); // ensure that the runtime starts driving initialization via the stream.next() call.

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // try to be really sure that the above sink_task is waiting on the same future to be resolved.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Connect a client so `listener.accept()` inside the initialization future can complete.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // try to be really sure that the effects of the above initialization completing are propagated.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(!stream_task.is_finished()); // We haven't sent anything yet, so the stream should definitely not be resolved now.

                // Now actually send an item so that the stream will wake up and have an item ready to pull from it.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    /// Over a real TCP connection: the sink's send starts initialization (which blocks in
    /// `accept`), the stream then waits on the same future, and the buffered item is
    /// delivered once a client connects.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled from inside the initialization future once it has been polled.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                println!("Listening on {}", addr);


                let sink_source = LazySinkSource::new(async move {
                    // initialization is at least partially started now.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                initialization_rx.await.unwrap(); // ensure that the runtime starts driving initialization via the sink's send call.

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Yield repeatedly so both the sink task and the newly spawned stream task are parked on the same pending initialization future.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // No client has connected yet, so `listener.accept()` (and thus initialization) is still pending.
                assert!(!sink_task.is_finished(), "We haven't sent anything yet, so the sink should definitely not be resolved now.");

                // Connect a client so `listener.accept()` inside the initialization future can complete.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // try to be really sure that the effects of the above initialization completing are propagated.
                // NOTE(review): a wall-clock sleep makes this timing-dependent; it may be flaky on slow machines.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                assert!(sink_task.is_finished()); // Sink should have sent its item.

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                // Now actually send an item so that the stream will wake up and have an item ready to pull from it.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}