Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add session cache to SslConnector #1042

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openssl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cfg-if = "0.1"
foreign-types = "0.3.1"
lazy_static = "1"
libc = "0.2"
linked_hash_set = "0.1"

openssl-sys = { version = "0.9.40", path = "../openssl-sys" }

Expand Down
1 change: 1 addition & 0 deletions openssl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ extern crate foreign_types;
#[macro_use]
extern crate lazy_static;
extern crate libc;
extern crate linked_hash_set;
extern crate openssl_sys as ffi;

#[cfg(test)]
Expand Down
82 changes: 82 additions & 0 deletions openssl/src/ssl/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use linked_hash_set::LinkedHashSet;
use ssl::{SslSession, SslSessionRef};
use std::borrow::Borrow;
use std::collections::hash_map::{Entry, HashMap};
use std::hash::{Hash, Hasher};

#[derive(Hash, PartialEq, Eq, Clone)]
pub struct SessionKey {
pub host: String,
pub port: u16,
}

#[derive(Clone)]
struct HashSession(SslSession);

impl PartialEq for HashSession {
fn eq(&self, other: &HashSession) -> bool {
self.0.id() == other.0.id()
}
}

impl Eq for HashSession {}

impl Hash for HashSession {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
self.0.id().hash(state);
}
}

impl Borrow<[u8]> for HashSession {
fn borrow(&self) -> &[u8] {
self.0.id()
}
}

pub struct SessionCache {
sessions: HashMap<SessionKey, LinkedHashSet<HashSession>>,
reverse: HashMap<HashSession, SessionKey>,
}

impl SessionCache {
pub fn new() -> SessionCache {
SessionCache {
sessions: HashMap::new(),
reverse: HashMap::new(),
}
}

pub fn insert(&mut self, key: SessionKey, session: SslSession) {
let session = HashSession(session);

self.sessions
.entry(key.clone())
.or_insert_with(LinkedHashSet::new)
.insert(session.clone());
self.reverse.insert(session.clone(), key);
}

pub fn get(&mut self, key: &SessionKey) -> Option<SslSession> {
let sessions = self.sessions.get_mut(key)?;
let session = sessions.front().cloned()?;
sessions.refresh(&session);
Some(session.0)
}

pub fn remove(&mut self, session: &SslSessionRef) {
let key = match self.reverse.remove(session.id()) {
Some(key) => key,
None => return,
};

if let Entry::Occupied(mut sessions) = self.sessions.entry(key) {
sessions.get_mut().remove(session.id());
if sessions.get().is_empty() {
sessions.remove();
}
}
}
}
156 changes: 147 additions & 9 deletions openssl/src/ssl/connector.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
use foreign_types::{ForeignType, ForeignTypeRef};
use std::io::{Read, Write};
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};

use dh::Dh;
use error::ErrorStack;
use ex_data::Index;
use ssl::cache::{SessionCache, SessionKey};
use ssl::{
HandshakeError, Ssl, SslContext, SslContextBuilder, SslMethod, SslMode, SslOptions, SslRef,
SslStream, SslVerifyMode,
SslSessionCacheMode, SslStream, SslVerifyMode,
};
use version;

lazy_static! {
// The unwrap here isn't great but this only fails on OOM
static ref SESSION_CACHE_KEY_INDEX: Index<Ssl, SessionKey> = Ssl::new_ex_index().unwrap();
}

fn ctx(method: SslMethod) -> Result<SslContextBuilder, ErrorStack> {
let mut ctx = SslContextBuilder::new(method)?;

Expand Down Expand Up @@ -46,7 +55,10 @@ fn ctx(method: SslMethod) -> Result<SslContextBuilder, ErrorStack> {
/// OpenSSL's built in hostname verification is used when linking against OpenSSL 1.0.2 or 1.1.0,
/// and a custom implementation is used when linking against OpenSSL 1.0.1.
#[derive(Clone)]
pub struct SslConnector(SslContext);
pub struct SslConnector {
ctx: SslContext,
session_cache: Option<Arc<Mutex<SessionCache>>>,
}

impl SslConnector {
/// Creates a new builder for TLS connections.
Expand All @@ -60,7 +72,10 @@ impl SslConnector {
)?;
setup_verify(&mut ctx);

Ok(SslConnectorBuilder(ctx))
Ok(SslConnectorBuilder {
ctx,
session_cache: false,
})
}

/// Initiates a client-side TLS session on a stream.
Expand All @@ -75,43 +90,113 @@ impl SslConnector {

/// Returns a structure allowing for configuration of a single TLS session before connection.
pub fn configure(&self) -> Result<ConnectConfiguration, ErrorStack> {
Ssl::new(&self.0).map(|ssl| ConnectConfiguration {
Ssl::new(&self.ctx).map(|ssl| ConnectConfiguration {
ssl,
sni: true,
verify_hostname: true,
session_cache: self.session_cache.clone().map(|session_cache| {
SessionCacheConfiguration {
cache: session_cache,
ctx_ptr: self.ctx.as_ptr() as usize,
key: None,
}
}),
})
}
}

/// A builder for `SslConnector`s.
pub struct SslConnectorBuilder(SslContextBuilder);
pub struct SslConnectorBuilder {
ctx: SslContextBuilder,
session_cache: bool,
}

impl SslConnectorBuilder {
/// Consumes the builder, returning an `SslConnector`.
pub fn build(self) -> SslConnector {
SslConnector(self.0.build())
pub fn build(mut self) -> SslConnector {
if self.session_cache {
let session_cache = Arc::new(Mutex::new(SessionCache::new()));

let mode = self.ctx.set_session_cache_mode(SslSessionCacheMode::CLIENT);
self.ctx.set_session_cache_mode(mode | SslSessionCacheMode::CLIENT);

let cache = session_cache.clone();
self.ctx.set_new_session_callback(move |ssl, session| {
if let Some(key) = ssl.ex_data(*SESSION_CACHE_KEY_INDEX) {
if let Ok(mut cache) = cache.lock() {
cache.insert(key.clone(), session);
}
}
});

let cache = session_cache.clone();
self.ctx.set_remove_session_callback(move |_, session| {
if let Ok(mut cache) = cache.lock() {
cache.remove(session);
}
});

SslConnector {
ctx: self.ctx.build(),
session_cache: Some(session_cache),
}
} else {
SslConnector {
ctx: self.ctx.build(),
session_cache: None,
}
}
}

/// A builder-style version of `set_use_session_cache`.
pub fn use_session_cache(mut self, use_session_cache: bool) -> SslConnectorBuilder {
self.set_use_session_cache(use_session_cache);
self
}

/// Configures the use of session cache.
///
/// Setting `use_session_cache` to `true` causes `build` to set up the relevant session cache
/// callbacks for the connector. Previous callbacks may be invalidated.
///
/// After enabling session cache, configure each connection with
/// [`ConnectConfiguration::set_session_cache_key`].
///
/// Defaults to `false`.
///
/// [`ConnectConfiguration::set_session_cache_key`]:
/// struct.ConnectConfiguration.html#method.set_session_cache_key
pub fn set_use_session_cache(&mut self, use_session_cache: bool) {
self.session_cache = use_session_cache
}
}

impl Deref for SslConnectorBuilder {
type Target = SslContextBuilder;

fn deref(&self) -> &SslContextBuilder {
&self.0
&self.ctx
}
}

impl DerefMut for SslConnectorBuilder {
fn deref_mut(&mut self) -> &mut SslContextBuilder {
&mut self.0
&mut self.ctx
}
}

struct SessionCacheConfiguration {
cache: Arc<Mutex<SessionCache>>,
ctx_ptr: usize, // Used to verify that the context hasn't changed
key: Option<SessionKey>,
}

/// A type which allows for configuration of a client-side TLS session before connection.
pub struct ConnectConfiguration {
ssl: Ssl,
sni: bool,
verify_hostname: bool,
session_cache: Option<SessionCacheConfiguration>,
}

impl ConnectConfiguration {
Expand Down Expand Up @@ -147,6 +232,25 @@ impl ConnectConfiguration {
self.verify_hostname = verify_hostname;
}

/// A builder-style version of `set_session_cache_key`.
pub fn session_cache_key(mut self, host: String, port: u16) -> ConnectConfiguration {
self.set_session_cache_key(host, port);
self
}

/// Configures the session cache key for this connection.
///
/// To be effective, session cache needs to be enabled by
/// [`SslConnectorBuilder::set_use_session_cache`].
///
/// [`SslConnectorBuilder::set_use_session_cache`]:
/// struct.SslConnectorBuilder.html#method.set_use_session_cache
pub fn set_session_cache_key(&mut self, host: String, port: u16) {
if let Some(ref mut session_cache) = self.session_cache {
session_cache.key = Some(SessionKey { host, port });
}
}

/// Initiates a client-side TLS session on a stream.
///
/// The domain is used for SNI and hostname verification if enabled.
Expand All @@ -162,8 +266,42 @@ impl ConnectConfiguration {
setup_verify_hostname(&mut self.ssl, domain)?;
}

if let Some(session_cache) = self.session_cache {
if let Some(key) = session_cache.key {
Self::setup_session_cache(
&mut self.ssl,
session_cache.cache,
session_cache.ctx_ptr,
key,
)
.map_err(|e| HandshakeError::SetupFailure(e))?;
}
}

self.ssl.connect(stream)
}

fn setup_session_cache(
ssl: &mut Ssl,
cache: Arc<Mutex<SessionCache>>,
ctx_ptr: usize,
key: SessionKey,
) -> Result<(), ErrorStack> {
if ssl.ssl_context().as_ptr() as usize == ctx_ptr {
// Safety: the context hasn't changed.
if let Ok(mut cache) = cache.lock() {
if let Some(session) = cache.get(&key) {
unsafe {
ssl.set_session(&session)?;
}
}
}
}

ssl.set_ex_data(*SESSION_CACHE_KEY_INDEX, key);

Ok(())
}
}

impl Deref for ConnectConfiguration {
Expand Down
1 change: 1 addition & 0 deletions openssl/src/ssl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub use ssl::connector::{
pub use ssl::error::{Error, ErrorCode, HandshakeError};

mod bio;
mod cache;
mod callbacks;
mod connector;
mod error;
Expand Down
Loading