Skip to main content

sinktools/
demux_map_lazy.rs

1//! [`LazyDemuxSink`] and related items.
2use core::fmt::Debug;
3use core::hash::Hash;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use std::collections::HashMap;
7
8use crate::{Sink, ready_both};
9
10/// Sink which receives keys paired with items `(Key, Item)`, and lazily creates sinks on first use.
11pub struct LazyDemuxSink<Key, Si, Func> {
12    sinks: HashMap<Key, Si>,
13    func: Func,
14}
15
16impl<Key, Si, Func> LazyDemuxSink<Key, Si, Func> {
17    /// Create with the given initialization function.
18    pub fn new<Item>(func: Func) -> Self
19    where
20        Self: Sink<(Key, Item)>,
21    {
22        Self {
23            sinks: HashMap::new(),
24            func,
25        }
26    }
27}
28
29impl<Key, Si, Item, Func> Sink<(Key, Item)> for LazyDemuxSink<Key, Si, Func>
30where
31    Key: Eq + Hash + Debug + Unpin,
32    Si: Sink<Item> + Unpin,
33    Func: FnMut(&Key) -> Si + Unpin,
34{
35    type Error = Si::Error;
36
37    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38        #[expect(
39            clippy::disallowed_methods,
40            reason = "nondeterministic iteration order, the `try_fold` is not order-dependent"
41        )]
42        self.get_mut()
43            .sinks
44            .values_mut()
45            .try_fold(Poll::Ready(()), |poll, sink| {
46                ready_both!(poll, Pin::new(sink).poll_ready(cx)?);
47                Poll::Ready(Ok(()))
48            })
49    }
50
51    fn start_send(self: Pin<&mut Self>, item: (Key, Item)) -> Result<(), Self::Error> {
52        let this = self.get_mut();
53        let sink = this
54            .sinks
55            .entry(item.0)
56            .or_insert_with_key(|k| (this.func)(k));
57        Pin::new(sink).start_send(item.1)
58    }
59
60    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61        #[expect(
62            clippy::disallowed_methods,
63            reason = "nondeterministic iteration order, the `try_fold` is not order-dependent"
64        )]
65        self.get_mut()
66            .sinks
67            .values_mut()
68            .try_fold(Poll::Ready(()), |poll, sink| {
69                ready_both!(poll, Pin::new(sink).poll_flush(cx)?);
70                Poll::Ready(Ok(()))
71            })
72    }
73
74    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        #[expect(
76            clippy::disallowed_methods,
77            reason = "nondeterministic iteration order, the `try_fold` is not order-dependent"
78        )]
79        self.get_mut()
80            .sinks
81            .values_mut()
82            .try_fold(Poll::Ready(()), |poll, sink| {
83                ready_both!(poll, Pin::new(sink).poll_close(cx)?);
84                Poll::Ready(Ok(()))
85            })
86    }
87}
88
89/// Creates a `LazyDemuxSink` that lazily creates sinks on first use for each key.
90///
91/// This requires sinks `Si` to be `Unpin`. If your sinks are not `Unpin`, first wrap them in `Box::pin` to make them `Unpin`.
92pub fn demux_map_lazy<Key, Si, Item, Func>(func: Func) -> LazyDemuxSink<Key, Si, Func>
93where
94    Key: Eq + Hash + Debug + Unpin,
95    Si: Sink<Item> + Unpin,
96    Func: FnMut(&Key) -> Si + Unpin,
97{
98    LazyDemuxSink::new(func)
99}
100
#[cfg(test)]
mod test {
    use core::cell::RefCell;
    use core::pin::pin;
    use std::collections::HashMap;
    use std::rc::Rc;

    use futures_util::SinkExt;

    use super::*;
    use crate::for_each::ForEach;

    /// Drives the sink through the async `SinkExt` helpers and checks that the bytes sent
    /// under each key end up concatenated, in send order, under that key.
    #[tokio::test]
    async fn test_lazy_demux_sink() {
        let collected: Rc<RefCell<HashMap<String, Vec<u8>>>> = Rc::default();
        let factory_handle = Rc::clone(&collected);

        // Each lazily created inner sink appends what it receives to its key's buffer.
        let mut sink = demux_map_lazy(move |key: &String| {
            let owned_key = key.clone();
            let buffers = Rc::clone(&factory_handle);
            ForEach::new(move |chunk: &[u8]| {
                buffers
                    .borrow_mut()
                    .entry(owned_key.clone())
                    .or_default()
                    .extend_from_slice(chunk);
            })
        });

        // Key "a" is used twice so that the second send reuses the existing inner sink.
        let sends = [
            ("a", b"test1".as_slice()),
            ("b", b"test2".as_slice()),
            ("a", b"test3".as_slice()),
        ];
        for (key, bytes) in sends {
            sink.send((key.to_owned(), bytes)).await.unwrap();
        }
        sink.flush().await.unwrap();
        sink.close().await.unwrap();

        let recorded = collected.borrow();
        assert_eq!(
            recorded.get("a").map(Vec::as_slice),
            Some(b"test1test3".as_slice())
        );
        assert_eq!(
            recorded.get("b").map(Vec::as_slice),
            Some(b"test2".as_slice())
        );
    }

    /// Same scenario as above, but calling the `Sink` poll methods by hand with a no-op waker
    /// and asserting that every step completes immediately.
    #[test]
    fn test_lazy_demux_sink_good() {
        use core::task::Context;

        let collected: Rc<RefCell<HashMap<String, Vec<u8>>>> = Rc::default();
        let factory_handle = Rc::clone(&collected);

        let mut sink = pin!(demux_map_lazy(move |key: &String| {
            let buffers = Rc::clone(&factory_handle);
            let owned_key = key.clone();
            ForEach::new(move |chunk: &[u8]| {
                buffers
                    .borrow_mut()
                    .entry(owned_key.clone())
                    .or_default()
                    .extend_from_slice(chunk);
            })
        }));

        let cx = &mut Context::from_waker(futures_task::noop_waker_ref());

        // No inner sinks exist yet, so readiness is reported straight away.
        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_ready(cx));

        let sends = [
            ("a", b"test1".as_slice()),
            ("b", b"test2".as_slice()),
            ("a", b"test3".as_slice()),
        ];
        for (key, bytes) in sends {
            assert_eq!(Ok(()), sink.as_mut().start_send((key.to_owned(), bytes)));
        }

        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_flush(cx));
        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_close(cx));

        let recorded = collected.borrow();
        assert_eq!(
            recorded.get("a").map(Vec::as_slice),
            Some(b"test1test3".as_slice())
        );
        assert_eq!(
            recorded.get("b").map(Vec::as_slice),
            Some(b"test2".as_slice())
        );
    }
}