use std::io;
use http::{ErrorCode, StreamId};
use http::frame::{Frame, FrameIR, FrameBuilder, FrameHeader, RawFrame, NoFlag};
/// The payload length that a RST_STREAM frame must declare in its header.
/// `from_raw` rejects any frame whose header length differs from this value,
/// and `get_header` always reports it for frames built by this module.
pub const RST_STREAM_FRAME_LEN: u32 = 4;
/// The frame type identifier that marks a raw frame as RST_STREAM.
/// `from_raw` rejects any frame whose header carries a different type.
pub const RST_STREAM_FRAME_TYPE: u8 = 0x3;
/// A parsed or locally-constructed RST_STREAM frame.
///
/// The error code is kept in its raw `u32` form so that values which do not
/// correspond to a known `ErrorCode` survive a parse/serialize round trip.
#[derive(Clone, Debug, PartialEq)]
pub struct RstStreamFrame {
    /// The error code exactly as given to the constructor or read from the payload.
    raw_error_code: u32,
    /// The stream this frame refers to; `from_raw` never yields a frame with id 0.
    stream_id: StreamId,
    /// The flags byte from the frame header. Constructors set it to 0, while
    /// `from_raw` stores whatever the header carried, unmodified.
    flags: u8,
}
impl RstStreamFrame {
pub fn new(stream_id: StreamId, error_code: ErrorCode) -> RstStreamFrame {
RstStreamFrame {
raw_error_code: error_code.into(),
stream_id: stream_id,
flags: 0,
}
}
pub fn with_raw_error_code(stream_id: StreamId, raw_error_code: u32) -> RstStreamFrame {
RstStreamFrame {
raw_error_code: raw_error_code,
stream_id: stream_id,
flags: 0,
}
}
pub fn error_code(&self) -> ErrorCode {
self.raw_error_code.into()
}
pub fn raw_error_code(&self) -> u32 {
self.raw_error_code
}
}
impl<'a> Frame<'a> for RstStreamFrame {
    // This frame type uses `NoFlag` as its flag type.
    type FlagType = NoFlag;

    /// Attempts to interpret `raw_frame` as a RST_STREAM frame.
    ///
    /// Yields `None` when the header's payload length is not
    /// `RST_STREAM_FRAME_LEN`, when its type is not `RST_STREAM_FRAME_TYPE`,
    /// or when the stream id is 0. Otherwise the first four payload octets
    /// are unpacked into the raw error code and the header flags are kept
    /// as they were received.
    fn from_raw(raw_frame: &'a RawFrame<'a>) -> Option<Self> {
        let (payload_len, frame_type, flags, stream_id) = raw_frame.header();

        let acceptable = payload_len == RST_STREAM_FRAME_LEN
            && frame_type == RST_STREAM_FRAME_TYPE
            && stream_id != 0x0;
        if !acceptable {
            return None;
        }

        Some(RstStreamFrame {
            raw_error_code: unpack_octets_4!(raw_frame.payload(), 0, u32),
            stream_id: stream_id,
            flags: flags,
        })
    }

    /// Always `false`: the stored flags byte is never consulted here.
    fn is_set(&self, _: NoFlag) -> bool {
        false
    }

    /// Returns the id of the stream this frame refers to.
    fn get_stream_id(&self) -> StreamId {
        self.stream_id
    }

    /// Builds the header tuple: fixed length, fixed type, stored flags and
    /// stored stream id, in that order.
    fn get_header(&self) -> FrameHeader {
        let header = (RST_STREAM_FRAME_LEN,
                      RST_STREAM_FRAME_TYPE,
                      self.flags,
                      self.stream_id);
        header
    }
}
impl FrameIR for RstStreamFrame {
    /// Writes the frame into `builder`: first the header obtained from
    /// `get_header`, then the raw error code as a `u32`.
    ///
    /// The first failing write is propagated to the caller and nothing
    /// further is written.
    fn serialize_into<B>(self, builder: &mut B) -> io::Result<()>
        where B: FrameBuilder
    {
        let header = self.get_header();
        try!(builder.write_header(header));

        let code = self.raw_error_code;
        try!(builder.write_u32(code));

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::RstStreamFrame;
    use http::tests::common::serialize_frame;
    use http::ErrorCode;
    use http::frame::{pack_header, FrameHeader, Frame};

    /// Concatenates the packed form of `header` with `payload` to produce the
    /// bytes of a complete frame.
    fn prepare_frame_bytes(header: FrameHeader, payload: Vec<u8>) -> Vec<u8> {
        let mut bytes = pack_header(&header).to_vec();
        bytes.extend(payload);
        bytes
    }

    /// A well-formed frame parses and exposes its error code and stream id.
    #[test]
    fn test_parse_valid() {
        let bytes = prepare_frame_bytes((4, 0x3, 0, 1), vec![0, 0, 0, 1]);
        let frame = RstStreamFrame::from_raw(&bytes.into()).expect("Valid frame expected");

        assert_eq!(frame.error_code(), ErrorCode::ProtocolError);
        assert_eq!(frame.get_stream_id(), 1);
    }

    /// Flag bits in the header do not prevent parsing and are kept verbatim.
    #[test]
    fn test_parse_valid_with_unknown_flags() {
        let bytes = prepare_frame_bytes((4, 0x3, 0x80, 1), vec![0, 0, 0, 1]);
        let frame = RstStreamFrame::from_raw(&bytes.into()).expect("Valid frame expected");

        assert_eq!(frame.error_code(), ErrorCode::ProtocolError);
        assert_eq!(frame.get_stream_id(), 1);
        let (_, _, flags, _) = frame.get_header();
        assert_eq!(flags, 0x80);
    }

    /// An unrecognized raw code converts to `InternalError`, while the raw
    /// value itself stays available.
    #[test]
    fn test_parse_unknown_error_code() {
        let bytes = prepare_frame_bytes((4, 0x3, 0x80, 1), vec![1, 0, 0, 1]);
        let frame = RstStreamFrame::from_raw(&bytes.into()).expect("Valid frame expected");

        assert_eq!(frame.error_code(), ErrorCode::InternalError);
        assert_eq!(frame.raw_error_code(), 0x01000001);
    }

    /// Stream id 0 is rejected.
    #[test]
    fn test_parse_invalid_stream_id() {
        let bytes = prepare_frame_bytes((4, 0x3, 0x80, 0), vec![0, 0, 0, 1]);
        let parsed = RstStreamFrame::from_raw(&bytes.into());
        assert!(parsed.is_none());
    }

    /// A declared payload length other than 4 is rejected.
    #[test]
    fn test_parse_invalid_payload_size() {
        let bytes = prepare_frame_bytes((5, 0x3, 0x00, 2), vec![0, 0, 0, 1, 0]);
        let parsed = RstStreamFrame::from_raw(&bytes.into());
        assert!(parsed.is_none());
    }

    /// A frame type other than 0x3 is rejected.
    #[test]
    fn test_parse_invalid_id() {
        let bytes = prepare_frame_bytes((4, 0x1, 0x00, 2), vec![0, 0, 0, 1, 0]);
        let parsed = RstStreamFrame::from_raw(&bytes.into());
        assert!(parsed.is_none());
    }

    /// `ProtocolError` serializes as raw code 1.
    #[test]
    fn test_serialize_protocol_error() {
        let frame = RstStreamFrame::new(1, ErrorCode::ProtocolError);
        let expected = prepare_frame_bytes((4, 0x3, 0, 1), vec![0, 0, 0, 1]);

        let raw = serialize_frame(&frame);
        assert_eq!(raw, expected);
    }

    /// `StreamClosed` serializes as raw code 5.
    #[test]
    fn test_serialize_stream_closed() {
        let frame = RstStreamFrame::new(2, ErrorCode::StreamClosed);
        let expected = prepare_frame_bytes((4, 0x3, 0, 2), vec![0, 0, 0, 0x5]);

        let raw = serialize_frame(&frame);
        assert_eq!(raw, expected);
    }

    /// A raw error code is written out in big-endian order.
    #[test]
    fn test_serialize_raw_error_code() {
        let frame = RstStreamFrame::with_raw_error_code(3, 1024);
        let expected = prepare_frame_bytes((4, 0x3, 0, 3), vec![0, 0, 0x04, 0]);

        let raw = serialize_frame(&frame);
        assert_eq!(raw, expected);
    }

    /// Frames built from a raw code and from the matching `ErrorCode` compare equal.
    #[test]
    fn test_partial_eq() {
        let from_raw_code = RstStreamFrame::with_raw_error_code(3, 1);
        let from_enum = RstStreamFrame::new(3, ErrorCode::ProtocolError);
        assert_eq!(from_raw_code, from_enum);
    }
}