1use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6use std::sync::Arc;
7use std::task::Wake;
8
9use futures_util::task::AtomicWaker;
10use futures_util::{Sink, Stream, ready};
11
12#[derive(Default)]
13struct DualWaker {
14 sink: AtomicWaker,
15 stream: AtomicWaker,
16}
17
18impl Wake for DualWaker {
19 fn wake(self: Arc<Self>) {
20 self.sink.wake();
21 self.stream.wake();
22 }
23}
24
/// State machine shared by the sink and stream halves of a [`LazySinkSource`].
///
/// NOTE: the field order inside each variant also fixes drop order (e.g. `stream` is dropped
/// before `sink` in `Done`), so fields should not be rearranged casually.
enum SharedState<Fut, St, Si, Item> {
    /// Initialization has not been requested yet; `future` has never been polled.
    Uninit {
        /// Future that resolves to the inner `(stream, sink)` pair, or an error.
        future: Pin<Box<Fut>>,
    },
    /// Initialization was requested (by a `start_send` or a `poll_next`) and `future` is being
    /// driven by whichever half is polled.
    Thunkulating {
        /// The initialization future.
        future: Pin<Box<Fut>>,
        /// Item accepted by the sink half before the inner sink existed.
        item: Option<Item>,
        /// Waker passed to `future`; wakes the tasks registered by both halves.
        dual_waker: Arc<DualWaker>,
    },
    /// Initialization succeeded and the inner halves are available.
    Done {
        /// Inner stream produced by the initialization future.
        stream: Pin<Box<St>>,
        /// Inner sink produced by the initialization future.
        sink: Pin<Box<Si>>,
        /// Item carried over from `Thunkulating::item` that has not been forwarded to `sink` yet.
        buf: Option<Item>,
    },
    /// Empty placeholder, e.g. held while `future` is moved out of `Uninit` via `mem::replace`.
    Taken,
}
41
42pub struct LazySinkSource<Fut, St, Si, Item, Error> {
44 state: SharedState<Fut, St, Si, Item>,
45 _phantom: PhantomData<Error>,
46}
47
48impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
49 pub fn new(future: Fut) -> Self {
51 Self {
52 state: SharedState::Uninit {
53 future: Box::pin(future),
54 },
55 _phantom: PhantomData,
56 }
57 }
58}
59
60impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
61where
62 Self: Unpin,
63 Fut: Future<Output = Result<(St, Si), Error>>,
64 St: Stream,
65 Si: Sink<Item>,
66 Error: From<Si::Error>,
67{
68 type Error = Error;
69
70 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71 let state = &mut self.get_mut().state;
72
73 if let SharedState::Uninit { .. } = &*state {
74 return Poll::Ready(Ok(()));
75 }
76
77 if let SharedState::Thunkulating {
78 future,
79 item,
80 dual_waker,
81 } = &mut *state
82 {
83 dual_waker.sink.register(cx.waker());
84 let waker = Waker::from(Arc::clone(dual_waker));
85
86 let mut new_context = Context::from_waker(&waker);
87
88 match future.as_mut().poll(&mut new_context) {
89 Poll::Ready(Ok((stream, sink))) => {
90 let buf = item.take();
91 *state = SharedState::Done {
92 stream: Box::pin(stream),
93 sink: Box::pin(sink),
94 buf,
95 };
96 }
97 Poll::Ready(Err(e)) => {
98 return Poll::Ready(Err(e));
99 }
100 Poll::Pending => {
101 return Poll::Pending;
102 }
103 }
104 }
105
106 if let SharedState::Done { sink, buf, .. } = &mut *state {
107 if buf.is_some() {
108 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
109 sink.as_mut().start_send(buf.take().unwrap())?;
110 }
111 let result = sink.as_mut().poll_ready(cx).map_err(From::from);
112 return result;
113 }
114
115 panic!("LazySinkHalf in invalid state.");
116 }
117
118 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
119 let state = &mut self.get_mut().state;
120
121 if let SharedState::Uninit { .. } = &*state {
122 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
123 if let SharedState::Uninit { future } = old_state {
124 *state = SharedState::Thunkulating {
125 future,
126 item: Some(item),
127 dual_waker: Default::default(),
128 };
129
130 return Ok(());
131 }
132 }
133
134 if let SharedState::Thunkulating { .. } = &mut *state {
135 panic!("LazySinkHalf not ready.");
136 }
137
138 if let SharedState::Done { sink, buf, .. } = &mut *state {
139 debug_assert!(buf.is_none());
140 let result = sink.as_mut().start_send(item).map_err(From::from);
141 return result;
142 }
143
144 panic!("LazySinkHalf not ready.");
145 }
146
147 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148 let state = &mut self.get_mut().state;
149
150 if let SharedState::Uninit { .. } = &*state {
151 return Poll::Ready(Ok(()));
152 }
153
154 if let SharedState::Thunkulating {
155 future,
156 item,
157 dual_waker,
158 } = &mut *state
159 {
160 dual_waker.sink.register(cx.waker());
161 let waker = Waker::from(Arc::clone(dual_waker));
162
163 let mut new_context = Context::from_waker(&waker);
164
165 match future.as_mut().poll(&mut new_context) {
166 Poll::Ready(Ok((stream, sink))) => {
167 let buf = item.take();
168 *state = SharedState::Done {
169 stream: Box::pin(stream),
170 sink: Box::pin(sink),
171 buf,
172 };
173 }
174 Poll::Ready(Err(e)) => {
175 return Poll::Ready(Err(e));
176 }
177 Poll::Pending => {
178 return Poll::Pending;
179 }
180 }
181 }
182
183 if let SharedState::Done { sink, buf, .. } = &mut *state {
184 if buf.is_some() {
185 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
186 sink.as_mut().start_send(buf.take().unwrap())?;
187 }
188 let result = sink.as_mut().poll_flush(cx).map_err(From::from);
189 return result;
190 }
191
192 panic!("LazySinkHalf in invalid state.");
193 }
194
195 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196 let state = &mut self.get_mut().state;
197
198 if let SharedState::Uninit { .. } = &*state {
199 return Poll::Ready(Ok(()));
200 }
201
202 if let SharedState::Thunkulating {
203 future,
204 item,
205 dual_waker,
206 } = &mut *state
207 {
208 dual_waker.sink.register(cx.waker());
209 let waker = Waker::from(Arc::clone(dual_waker));
210
211 let mut new_context = Context::from_waker(&waker);
212
213 match future.as_mut().poll(&mut new_context) {
214 Poll::Ready(Ok((stream, sink))) => {
215 let buf = item.take();
216 *state = SharedState::Done {
217 stream: Box::pin(stream),
218 sink: Box::pin(sink),
219 buf,
220 };
221 }
222 Poll::Ready(Err(e)) => {
223 return Poll::Ready(Err(e));
224 }
225 Poll::Pending => {
226 return Poll::Pending;
227 }
228 }
229 }
230
231 if let SharedState::Done { sink, buf, .. } = &mut *state {
232 if buf.is_some() {
233 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
234 sink.as_mut().start_send(buf.take().unwrap())?;
235 }
236 let result = sink.as_mut().poll_close(cx).map_err(From::from);
237 return result;
238 }
239
240 panic!("LazySinkHalf in invalid state.");
241 }
242}
243
244impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
245where
246 Self: Unpin,
247 Fut: Future<Output = Result<(St, Si), Error>>,
248 St: Stream,
249 Si: Sink<Item>,
250{
251 type Item = St::Item;
252
253 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254 let state = &mut self.get_mut().state;
255
256 if let SharedState::Uninit { .. } = &*state {
257 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
258 if let SharedState::Uninit { future } = old_state {
259 *state = SharedState::Thunkulating {
260 future,
261 item: None,
262 dual_waker: Default::default(),
263 };
264 } else {
265 unreachable!();
266 }
267 }
268
269 if let SharedState::Thunkulating {
270 future,
271 item,
272 dual_waker,
273 } = &mut *state
274 {
275 dual_waker.stream.register(cx.waker());
276 let waker = Waker::from(Arc::clone(dual_waker));
277
278 let mut new_context = Context::from_waker(&waker);
279
280 match future.as_mut().poll(&mut new_context) {
281 Poll::Ready(Ok((stream, sink))) => {
282 let buf = item.take();
283 *state = SharedState::Done {
284 stream: Box::pin(stream),
285 sink: Box::pin(sink),
286 buf,
287 };
288 }
289
290 Poll::Ready(Err(_)) => {
291 return Poll::Ready(None);
292 }
293
294 Poll::Pending => {
295 return Poll::Pending;
296 }
297 }
298 }
299
300 if let SharedState::Done { stream, .. } = &mut *state {
301 let result = stream.as_mut().poll_next(cx);
302 match &result {
303 Poll::Ready(Some(_)) => {}
304 Poll::Ready(None) => {}
305 Poll::Pending => {}
306 }
307 return result;
308 }
309
310 panic!("LazySourceHalf in invalid state.");
311 }
312}
313
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};
    use tokio_util::sync::PollSendError;

    use super::*;

    #[tokio::test(flavor = "current_thread")]
    async fn stream_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // The sink task waits for this signal, which the stream task sends right before
                // its first `next()`, so the stream half touches the shared state first.
                let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    stream_init_send.send(()).unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    stream_init_recv.await.unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn sink_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // The stream task waits for this signal, which the sink task sends right before
                // its first send, so the sink half touches the shared state first.
                let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    sink_init_recv.await.unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    sink_init_send.send(()).unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                println!("Listening on {}", addr);

                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // The lazy future signals once it has been polled, i.e. once the stream half
                // has started initialization; only then is the sink task spawned.
                initialization_rx.await.unwrap();
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // The client has not sent anything yet.
                assert!(!stream_task.is_finished());

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                println!("Listening on {}", addr);

                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // The lazy future signals once it has been polled, i.e. once the sink half has
                // started initialization; only then is the stream task spawned.
                initialization_rx.await.unwrap();
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(
                    !sink_task.is_finished(),
                    "We haven't sent anything yet, so the sink should definitely not be resolved now."
                );

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // Receiving the frame proves the sink half initialized and delivered its item.
                // Then wait, with an upper bound, for the sink task itself instead of sleeping and
                // checking `is_finished`, which depends on timing.
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
                tokio::time::timeout(std::time::Duration::from_secs(5), sink_task)
                    .await
                    .expect("sink task did not finish after the client connected")
                    .unwrap();

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
            })
            .await;
    }
}