use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use futures::Stream;
use futures_util::stream::FuturesUnordered;
use pin_project::pin_project;
use tokio::sync::{Semaphore, TryAcquireError};
use vortex_error::VortexUnwrap;
#[pin_project]
struct SizedFut<Fut> {
#[pin]
inner: Fut,
size_in_bytes: usize,
}
impl<Fut: Future> Future for SizedFut<Fut> {
type Output = (Fut::Output, usize);
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let size_in_bytes = self.size_in_bytes;
let inner = ready!(self.project().inner.poll(cx));
Poll::Ready((inner, size_in_bytes))
}
}
/// A stream over in-flight futures whose combined declared size is bounded by a byte budget.
///
/// Every future is pushed together with a byte count. That count is reserved from the
/// budget (tracked as semaphore permits) and handed back once the future's output is
/// yielded by the stream. Outputs come out in completion order, as produced by the
/// underlying `FuturesUnordered`.
#[pin_project]
pub struct SizeLimitedStream<Fut> {
    #[pin]
    inflight: FuturesUnordered<SizedFut<Fut>>,
    bytes_available: Semaphore,
}

impl<Fut> SizeLimitedStream<Fut> {
    /// Creates an empty stream whose budget starts at `max_bytes`.
    ///
    /// # Panics
    ///
    /// Panics if `max_bytes` exceeds [`Semaphore::MAX_PERMITS`].
    pub fn new(max_bytes: usize) -> Self {
        let budget = Semaphore::new(max_bytes);
        let futures = FuturesUnordered::new();
        Self {
            inflight: futures,
            bytes_available: budget,
        }
    }

    /// Returns the number of bytes in the budget that are not currently reserved.
    pub fn bytes_available(&self) -> usize {
        let budget = &self.bytes_available;
        budget.available_permits()
    }
}
impl<Fut> SizeLimitedStream<Fut>
where
    Fut: Future,
{
    /// Reserves `bytes` from the budget, waiting until enough is available, then
    /// enqueues `fut`.
    ///
    /// In this module, reserved bytes are returned to the budget only when
    /// `poll_next` yields a completed future's output, and `poll_next` requires
    /// exclusive (`&mut`) access to the stream. The future returned here borrows
    /// `self`, so awaiting it while fewer than `bytes` are available cannot complete
    /// on its own. This includes any `bytes` larger than the stream's total budget.
    /// Prefer [`Self::try_push`] when the budget may be exhausted.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` does not fit in a `u32`, the permit type of tokio's
    /// `acquire_many`.
    pub async fn push(&self, fut: Fut, bytes: usize) {
        self.bytes_available
            .acquire_many(Self::permits_for(bytes))
            .await
            .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"))
            // The reservation is handed back manually in `poll_next`, not on guard drop.
            .forget();
        self.enqueue(fut, bytes);
    }

    /// Enqueues `fut` if `bytes` can be reserved from the budget right now.
    ///
    /// # Errors
    ///
    /// Returns `Err(fut)`, giving the future back to the caller, when fewer than
    /// `bytes` are currently available.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` does not fit in a `u32`, the permit type of tokio's
    /// `try_acquire_many`.
    pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
        match self
            .bytes_available
            .try_acquire_many(Self::permits_for(bytes))
        {
            Ok(permits) => {
                // The reservation is handed back manually in `poll_next`.
                permits.forget();
                self.enqueue(fut, bytes);
                Ok(())
            }
            Err(TryAcquireError::NoPermits) => Err(fut),
            // Nothing in this module closes the semaphore.
            Err(TryAcquireError::Closed) => unreachable!("try_pushing to closed semaphore"),
        }
    }

    /// Converts a byte count to a semaphore permit count, panicking if it exceeds `u32`.
    fn permits_for(bytes: usize) -> u32 {
        u32::try_from(bytes).vortex_unwrap()
    }

    /// Adds `fut` to the in-flight set, tagged with the `bytes` already reserved for it.
    fn enqueue(&self, fut: Fut, bytes: usize) {
        self.inflight.push(SizedFut {
            inner: fut,
            size_in_bytes: bytes,
        });
    }
}
/// Yields the outputs of in-flight futures as they complete.
///
/// Each yielded output returns the bytes reserved for its future to the budget.
/// Yields `None` when the underlying `FuturesUnordered` reports no futures in flight.
impl<Fut> Stream for SizeLimitedStream<Fut>
where
    Fut: Future,
{
    type Item = Fut::Output;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let budget = this.bytes_available;
        let completed = ready!(this.inflight.poll_next(cx));
        Poll::Ready(completed.map(|(output, reserved)| {
            // Hand the finished future's reservation back so later pushes can proceed.
            budget.add_permits(reserved);
            output
        }))
    }
}
#[cfg(test)]
mod tests {
    use std::{future, io};
    use futures_util::future::BoxFuture;
    use futures_util::{FutureExt, StreamExt};
    use vortex_buffer::Buffer;
    use crate::limit::SizeLimitedStream;

    /// Produces a buffer of `len` `b'a'` bytes once awaited.
    async fn make_future(len: usize) -> Buffer {
        "a".as_bytes().iter().copied().cycle().take(len).collect()
    }

    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;
        // The whole budget is reserved, so even a 1-byte push must be rejected.
        assert_eq!(size_limited.bytes_available(), 0);
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // Yielding one completed future releases exactly its 5-byte reservation.
        assert!(size_limited.next().await.is_some());
        assert_eq!(size_limited.bytes_available(), 5);
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
        assert_eq!(size_limited.bytes_available(), 4);
    }

    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Buffer>> =
            future::ready(Err(io::Error::new(io::ErrorKind::Other, "badness"))).boxed();
        let good_fut: BoxFuture<'static, io::Result<Buffer>> =
            future::ready(Ok(Buffer::from("aaaaa".as_bytes()))).boxed();
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;
        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");
        // An `Err` output must still release its reservation.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());
        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    #[tokio::test]
    async fn test_empty_stream_yields_none() {
        let mut size_limited = SizeLimitedStream::<future::Ready<()>>::new(1);
        assert!(size_limited.next().await.is_none());
        assert_eq!(size_limited.bytes_available(), 1);
    }
}