use http::{StreamId, Header, HttpResult, HttpScheme, ErrorCode};
use http::frame::{HttpSetting, PingFrame, SettingsFrame};
use http::connection::{SendFrame, ReceiveFrame, HttpConnection, EndStream, SendStatus};
use http::session::{Session, SessionState, Stream, DefaultStream, DefaultSessionState};
use http::session::Server as ServerMarker;
use http::priority::SimplePrioritizer;
/// Produces the stream objects that back newly opened incoming streams.
///
/// `ServerSession` calls `create` when a HEADERS frame arrives for a stream id
/// that its session state is not tracking yet. It then inserts the returned
/// stream into that state as an incoming stream.
pub trait StreamFactory {
    /// Concrete stream type produced by this factory.
    ///
    /// When used with `ServerSession`, it must equal the session state's
    /// stream type.
    type Stream: Stream;

    /// Builds a fresh stream object for the stream identified by `id`.
    fn create(&mut self, id: StreamId) -> Self::Stream;
}
/// Server-side `Session` that applies the events reported by an
/// `HttpConnection` to a `SessionState`.
///
/// All three components are held by mutable borrow, so an instance is
/// short-lived. `ServerConnection` builds one for each `expect_settings` or
/// `handle_next_frame` call.
pub struct ServerSession<'a, State, F, S>
    where State: 'a + SessionState,
          S: 'a + SendFrame,
          F: 'a + StreamFactory<Stream = <State as SessionState>::Stream>
{
    /// Stream bookkeeping that incoming frames are applied to.
    state: &'a mut State,
    /// Creates streams for ids that `state` does not know yet.
    factory: &'a mut F,
    /// Frame sink used for replies such as SETTINGS and PING acknowledgements.
    sender: &'a mut S,
}
impl<'a, State, F, S> ServerSession<'a, State, F, S>
    where State: 'a + SessionState,
          S: 'a + SendFrame,
          F: 'a + StreamFactory<Stream = <State as SessionState>::Stream>
{
    /// Bundles the borrowed session state, stream factory and frame sender
    /// into a `ServerSession`.
    ///
    /// The result can be passed to `HttpConnection` as its `Session`.
    #[inline]
    pub fn new(state: &'a mut State, factory: &'a mut F, sender: &'a mut S) -> Self {
        ServerSession {
            sender: sender,
            state: state,
            factory: factory,
        }
    }
}
impl<'a, State, F, S> Session for ServerSession<'a, State, F, S>
    where State: SessionState + 'a,
          S: SendFrame + 'a,
          F: StreamFactory<Stream = State::Stream> + 'a
{
    /// Appends a DATA payload to the stream it belongs to.
    ///
    /// Chunks for a stream id that the session state does not track are logged
    /// and dropped; this is not treated as an error.
    fn new_data_chunk(&mut self,
                      stream_id: StreamId,
                      data: &[u8],
                      _: &mut HttpConnection)
                      -> HttpResult<()> {
        debug!("Data chunk for stream {}", stream_id);
        match self.state.get_stream_mut(stream_id) {
            Some(stream) => {
                stream.new_data_chunk(data);
            }
            None => {
                debug!("Received a frame for an unknown stream!");
            }
        }
        Ok(())
    }

    /// Attaches a header block to a stream.
    ///
    /// If the stream is already known, its headers are set in place. Otherwise
    /// a new stream is created by the factory, given the headers, and inserted
    /// into the session state as an incoming stream.
    fn new_headers<'n, 'v>(&mut self,
                           stream_id: StreamId,
                           headers: Vec<Header<'n, 'v>>,
                           _conn: &mut HttpConnection)
                           -> HttpResult<()> {
        debug!("Headers for stream {}", stream_id);
        if let Some(stream) = self.state.get_stream_mut(stream_id) {
            stream.set_headers(headers);
            return Ok(());
        }

        let mut stream = self.factory.create(stream_id);
        stream.set_headers(headers);
        // Previously this result was silently discarded, so a stream the state
        // refused (e.g. an invalid incoming id) vanished without a trace.
        // NOTE(review): RFC 7540 §5.1.1 treats an invalid client stream id as a
        // connection error (PROTOCOL_ERROR); consider surfacing one here.
        if self.state.insert_incoming(stream_id, stream).is_err() {
            debug!("Session state rejected incoming stream {}", stream_id);
        }
        Ok(())
    }

    /// Marks a stream as closed by the remote peer.
    ///
    /// Unknown stream ids are logged and ignored.
    fn end_of_stream(&mut self, stream_id: StreamId, _: &mut HttpConnection) -> HttpResult<()> {
        debug!("End of stream {}", stream_id);
        match self.state.get_stream_mut(stream_id) {
            Some(stream) => {
                stream.close_remote();
            }
            None => {
                debug!("Received a frame for an unknown stream!");
            }
        }
        Ok(())
    }

    /// Forwards a peer's RST_STREAM error code to the affected stream, if it
    /// is known.
    fn rst_stream(&mut self,
                  stream_id: StreamId,
                  error_code: ErrorCode,
                  _: &mut HttpConnection)
                  -> HttpResult<()> {
        debug!("RST_STREAM id={:?}, error={:?}", stream_id, error_code);
        if let Some(stream) = self.state.get_stream_mut(stream_id) {
            stream.on_rst_stream(error_code);
        }
        Ok(())
    }

    /// Acknowledges a SETTINGS frame from the peer.
    ///
    /// The received settings are not inspected here; only the ACK is sent.
    fn new_settings(&mut self,
                    _settings: Vec<HttpSetting>,
                    conn: &mut HttpConnection)
                    -> HttpResult<()> {
        debug!("Sending a SETTINGS ack");
        conn.sender(self.sender).send_settings_ack()
    }

    /// Answers a PING from the peer with a PING ack echoing its opaque data.
    fn on_ping(&mut self, ping: &PingFrame, conn: &mut HttpConnection) -> HttpResult<()> {
        debug!("Sending a PING ack");
        conn.sender(self.sender).send_ping_ack(ping.opaque_data())
    }

    /// Handles a PING ack from the peer; it is only logged.
    fn on_pong(&mut self, _ping: &PingFrame, _conn: &mut HttpConnection) -> HttpResult<()> {
        debug!("Received a PING ack");
        Ok(())
    }
}
/// Server end of an HTTP/2 connection.
///
/// Owns three things:
/// - the framing-level `HttpConnection`;
/// - the per-stream session state, publicly reachable as `state`;
/// - the factory that creates streams for newly opened incoming requests.
pub struct ServerConnection<F, State = DefaultSessionState<ServerMarker, DefaultStream>>
    where State: SessionState,
          F: StreamFactory<Stream = <State as SessionState>::Stream>
{
    /// Framing-level connection that frames are received and sent through.
    conn: HttpConnection,
    /// Session state holding the streams of this connection.
    pub state: State,
    /// Creates streams when the client opens new ones.
    factory: F,
}
impl<F, State> ServerConnection<F, State>
where State: SessionState,
F: StreamFactory<Stream = State::Stream>
{
pub fn with_connection(conn: HttpConnection,
state: State,
factory: F)
-> ServerConnection<F, State> {
ServerConnection {
conn: conn,
state: state,
factory: factory,
}
}
#[inline]
pub fn scheme(&self) -> HttpScheme {
self.conn.scheme
}
pub fn send_settings<S: SendFrame>(&mut self, sender: &mut S) -> HttpResult<()> {
self.conn.sender(sender).send_settings_ack()
}
pub fn expect_settings<Recv: ReceiveFrame, Sender: SendFrame>(&mut self,
rx: &mut Recv,
tx: &mut Sender)
-> HttpResult<()> {
let mut session = ServerSession::new(&mut self.state, &mut self.factory, tx);
self.conn.expect_settings(rx, &mut session)
}
#[inline]
pub fn handle_next_frame<Recv: ReceiveFrame, Sender: SendFrame>(&mut self,
rx: &mut Recv,
tx: &mut Sender)
-> HttpResult<()> {
let mut session = ServerSession::new(&mut self.state, &mut self.factory, tx);
self.conn.handle_next_frame(rx, &mut session)
}
#[inline]
pub fn start_response<'n, 'v, S: SendFrame>(&mut self,
headers: Vec<Header<'n, 'v>>,
stream_id: StreamId,
end_stream: EndStream,
sender: &mut S)
-> HttpResult<()> {
self.conn.sender(sender).send_headers(headers, stream_id, end_stream)
}
pub fn send_next_data<S: SendFrame>(&mut self, sender: &mut S) -> HttpResult<SendStatus> {
debug!("Sending next data...");
const MAX_CHUNK_SIZE: usize = 8 * 1024;
let mut buf = [0; MAX_CHUNK_SIZE];
let mut prioritizer = SimplePrioritizer::new(&mut self.state, &mut buf);
self.conn.sender(sender).send_next_data(&mut prioritizer)
}
}
#[cfg(test)]
mod tests {
    use super::ServerSession;
    use http::tests::common::{TestStream, TestStreamFactory, build_mock_http_conn, MockSendFrame};
    use http::{Header, ErrorCode, HttpError};
    use http::session::{DefaultSessionState, SessionState, Stream, Session};
    use http::session::Server as ServerMarker;

    /// Drives HEADERS, DATA and end-of-stream events through `ServerSession`
    /// and checks that each lands on the right stream in the session state.
    #[test]
    fn test_server_session() {
        let headers = vec![Header::new(b":method".to_vec(), b"GET".to_vec())];
        let mut state = DefaultSessionState::<ServerMarker, TestStream>::new();
        let mut conn = build_mock_http_conn();
        let mut sender = MockSendFrame::new();
        let mut factory = TestStreamFactory;

        // Headers for an unknown id open a new stream via the factory.
        {
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.new_headers(1, headers.clone(), &mut conn).unwrap();
        }
        assert!(state.get_stream_ref(1).is_some());
        assert_eq!(state.get_stream_ref(1).unwrap().headers.clone().unwrap(), headers);

        // Data chunks accumulate in the body in arrival order.
        {
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.new_data_chunk(1, &[1, 2, 3], &mut conn).unwrap();
        }
        assert_eq!(state.get_stream_ref(1).unwrap().body, vec![1, 2, 3]);
        {
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.new_data_chunk(1, &[4], &mut conn).unwrap();
        }
        assert_eq!(state.get_stream_ref(1).unwrap().body, vec![1, 2, 3, 4]);

        // A second stream can be opened and fed within a single session.
        {
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.new_headers(3, headers.clone(), &mut conn).unwrap();
            session.new_data_chunk(3, &[100], &mut conn).unwrap();
        }
        {
            let third = state.get_stream_ref(3);
            assert!(third.is_some());
            let third = third.unwrap();
            assert_eq!(third.headers.clone().unwrap(), headers);
            assert_eq!(third.body, vec![100]);
        }

        // Ending stream 1 remotely must not affect stream 3.
        {
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.end_of_stream(1, &mut conn).unwrap();
        }
        assert!(state.get_stream_ref(1).unwrap().is_closed_remote());
        assert!(!state.get_stream_ref(3).unwrap().is_closed_remote());
    }

    /// RST_STREAM records its error code on the targeted stream only.
    #[test]
    fn test_server_session_rst_stream() {
        let mut state = DefaultSessionState::<ServerMarker, TestStream>::new();
        let mut conn = build_mock_http_conn();
        let mut sender = MockSendFrame::new();
        for &id in &[1, 3, 5] {
            state.insert_incoming(id, TestStream::new()).unwrap();
        }

        {
            let mut factory = TestStreamFactory;
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.rst_stream(3, ErrorCode::Cancel, &mut conn).unwrap();
        }

        assert_eq!(state.get_stream_ref(1).unwrap().errors.len(), 0);
        {
            let reset = state.get_stream_ref(3).unwrap();
            assert_eq!(reset.errors.len(), 1);
            assert_eq!(reset.errors[0], ErrorCode::Cancel);
        }
        assert_eq!(state.get_stream_ref(5).unwrap().errors.len(), 0);
    }

    /// A GOAWAY from the peer surfaces as a `PeerConnectionError` that carries
    /// the peer's error code and debug data.
    #[test]
    fn test_server_session_on_goaway() {
        let mut state = DefaultSessionState::<ServerMarker, TestStream>::new();
        let mut conn = build_mock_http_conn();
        let mut sender = MockSendFrame::new();

        let res = {
            let mut factory = TestStreamFactory;
            let mut session = ServerSession::new(&mut state, &mut factory, &mut sender);
            session.on_goaway(0, ErrorCode::ProtocolError, None, &mut conn)
        };

        match res {
            Err(HttpError::PeerConnectionError(err)) => {
                assert_eq!(err.error_code(), ErrorCode::ProtocolError);
                assert_eq!(err.debug_data(), None);
            }
            _ => panic!("Expected a PeerConnectionError"),
        }
    }
}