a blog

by Gui Andrade

Crafting a Rust enum with variants decided at compile-time

Or, creating a higher order 'mapping' type

While hacking on some performance-critical Rust code recently, I found myself wishing for a way to construct an enum whose variants are decided by type reflection.

Here's an example of my ideal type (syntax aside):

struct KeyType1;
struct KeyType2;
struct KeyType3;
struct KeyType4;

tymap Foo<K> for K {
    KeyType1 => u32,
    KeyType2 => u64,
    KeyType3 => u128,
    KeyType4 => (u128, u128)
};

impl<K> Foo<K> {
    fn positive(&self) -> bool {
        match K {
            KeyType1(ref val: u32)          => val >= 0,
            KeyType2(ref val: u64)          => val >= 0,
            KeyType3(ref val: u128)         => val >= 0,
            KeyType4(ref val: (u128, u128)) => val.0 >= 0,
        }
    }
}

fn main() {
    let foo = Foo::<KeyType1>::new(0u32);
    assert!(foo.positive());
}

But why?

Think of it like a nicer version of C++'s std::enable_if.

Benefits over wrapping an enum

In the hot path, every last branch instruction is a potential performance pitfall. Because all the code that runs here is selected at compile-time, the optimizer can erase all reflection, just like using generics.

Is this idiomatic? What about traits?

Traits would definitely be the idiomatic way to solve this problem. But using a tymap can provide both an enum's exhaustiveness guarantees and the code monomorphization of traits.

And, most importantly, the tymap can be stored in structs without any type erasure or runtime penalty.
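
For comparison, here is a rough sketch of what the trait-based route might look like, using an associated type to map each key to its value type. The trait name Key, the method positive, and the struct layout are assumptions made for illustration, and the check uses > 0 so that the result is not trivially true for unsigned values. Note that the set of keys stays open here: any type can implement the trait, so there is no exhaustiveness check.

```rust
// Hypothetical key types, mirroring two of the keys from the example above.
struct KeyType1;
struct KeyType4;

// Each key names its value type through an associated type.
trait Key {
    type Value;
    fn positive(val: &Self::Value) -> bool;
}

impl Key for KeyType1 {
    type Value = u32;
    fn positive(val: &u32) -> bool {
        *val > 0
    }
}

impl Key for KeyType4 {
    type Value = (u128, u128);
    fn positive(val: &(u128, u128)) -> bool {
        val.0 > 0
    }
}

// The field type is fixed by `K` at compile time, so nothing is erased.
struct Foo<K: Key> {
    val: K::Value,
}

impl<K: Key> Foo<K> {
    fn new(val: K::Value) -> Self {
        Foo { val }
    }

    // Statically dispatched to the impl for `K`; no runtime branch on the key.
    fn positive(&self) -> bool {
        K::positive(&self.val)
    }
}
```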

Implementation

Turns out, we can implement something very similar in stable Rust with the help of macros, enums, and a little unsafe code. See this GitHub repo if you'd like to browse a complete implementation.

Storage size and alignment

We need our tymap type to be at least as large as its largest variant, and aligned at least as strictly as its most-aligned variant. A struct holding every variant would waste too much space, so we won't use one. A union wouldn't work either: when this was written, Rust unions only accepted Copy fields (later releases also allow ManuallyDrop<T> fields).

Fortunately, Rust's enum type already gives us size and alignment guarantees. With #[repr(u8)], the layout is well-defined: a one-byte discriminant followed by the variant's data, so we can access the variant data directly without caring about the discriminant. This effectively gives us a union without a union's restrictions:

#[repr(u8)]
enum $type_name {
    $( $key ( $val ) ),*
}

impl $type_name {
    unsafe fn inner<V: 'static>(&self) -> *const V {
        #[repr(C)]
        struct Repr<V> {
            _discriminant: u8,
            _inner: V
        }
        let repr_ptr = self as *const Self as *const Repr<V>;
        &(*repr_ptr)._inner
    }

    unsafe fn inner_mut<V: 'static>(&mut self) -> *mut V {
        #[repr(C)]
        struct Repr<V> {
            _discriminant: u8,
            _inner: V
        }
        let repr_ptr = self as *mut Self as *mut Repr<V>;
        &mut (*repr_ptr)._inner
    }
}
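
To make the layout trick concrete, here is a minimal sketch with the macro variables replaced by a hand-written enum. The names Storage and read_inner are hypothetical, and the value is built as an ordinary variant so that the payload is known to be initialized; the read then goes through the same Repr cast as above.

```rust
// A concrete stand-in for the macro-generated enum (hypothetical names).
#[repr(u8)]
#[allow(dead_code)]
enum Storage {
    Small(u32),
    Large(u128),
}

// Mirrors one variant of a #[repr(u8)] enum: a u8 tag followed by the
// payload, placed at the payload's natural alignment.
#[repr(C)]
struct Repr<V> {
    _discriminant: u8,
    inner: V,
}

// Reads the payload as `V` without inspecting the tag.
// Safety: the caller must ensure `V` is the payload type of the active variant.
unsafe fn read_inner<V: Copy>(s: &Storage) -> V {
    let repr_ptr = s as *const Storage as *const Repr<V>;
    unsafe { (*repr_ptr).inner }
}
```

Calling read_inner::<u32> on a Storage::Small value returns the stored u32; asking for the wrong type is undefined behavior, which is why the real implementation guards every access with a type check.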

Declaring a tymap

We'll have the following syntax for our tymap macro:

tymap!(Foo {
    KeyType1 { val1: u32 },
    KeyType2 { val2: u64 },
    KeyType3 { val3: u128 },
    KeyType4 { val4: (u128, u128) }
});

valX, here, is an ident, not a type, and is only used for variable binding and type parameters; it doesn't provide any additional "information" in the type declaration.

KeyTypeX is both an ident and a type, which unfortunately means that we can only use type names which are also valid idents (but type A = B mitigates this problem).

It would be nice if these tricks weren't necessary, but these are limitations of macro_rules!. A future iteration could use a procedural macro instead for an even nicer syntax.
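
The ident-as-type restriction and the alias workaround can be illustrated with a toy macro. This is only a sketch: declare_keys, key_sizes, and Bytes are hypothetical names, not part of the tymap implementation.

```rust
// `Vec<u8>` is a valid type but not a valid identifier,
// so give it an alias that is one (hypothetical name).
type Bytes = Vec<u8>;

// Like the tymap syntax above, this toy macro captures each key as an
// `ident` and then uses that same token in type position.
macro_rules! declare_keys {
    ($($key:ident),*) => {
        // Returns each key's name together with the size of the type it names.
        fn key_sizes() -> Vec<(&'static str, usize)> {
            vec![$((stringify!($key), std::mem::size_of::<$key>())),*]
        }
    };
}

// Passing `Vec<u8>` directly would not match `$key:ident`; the alias does.
declare_keys!(u32, Bytes);
```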

Constructing a new tymap

First, let's add the type parameter K to $type_name, as well as the enum variant _hidden(PhantomData<K>). K will be the currently active "key" of our type mapping.

impl<K: 'static> $type_name<K> {
    pub fn new<V: 'static>(v: V) -> Self {
        // Don't allow the creation of a tymap with an invalid key-value pair.
        // Would be nice if this could be a static assertion, but TypeId::of isn't a const fn.
        // After monomorphization every TypeId here is a constant, so an
        // optimized build should fold this assertion away entirely.
        assert!(
            $( (TypeId::of::<K>() == TypeId::of::<$key>()
                && TypeId::of::<V>() == TypeId::of::<$val>()) )||*
        );

        let mut out = $type_name::_hidden(PhantomData);
        unsafe { ptr::write(out.inner_mut(), v) }
        out
    }
}
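
Written out by hand for two keys, the condition inside that assertion might look like the following sketch. The function name is_valid_pair is an assumption for illustration; the macro inlines the expression rather than calling a helper.

```rust
use std::any::TypeId;

// Hypothetical key types, as in the declaration above.
struct KeyType1;
struct KeyType2;

// True when (K, V) is one of the declared key-value pairs:
// KeyType1 => u32 or KeyType2 => u64.
fn is_valid_pair<K: 'static, V: 'static>() -> bool {
    (TypeId::of::<K>() == TypeId::of::<KeyType1>()
        && TypeId::of::<V>() == TypeId::of::<u32>())
        || (TypeId::of::<K>() == TypeId::of::<KeyType2>()
            && TypeId::of::<V>() == TypeId::of::<u64>())
}
```

Every operand depends only on the type parameters, so each monomorphized copy reduces to a constant true or false.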

Implementing match

// match_ref and match_move follow pretty clearly from the following implementation
pub fn match_ref_mut
    <$($VAL: 'static),*, Out>
    (&mut self, $($val_func: impl FnOnce(&mut $VAL) -> Out),*)
    -> Out
{
    $(
        // Both sides of each comparison are constants after monomorphization,
        // so an optimized build should keep only the matching branch.
        if TypeId::of::<K>() == TypeId::of::<$key>()
            && TypeId::of::<$VAL>() == TypeId::of::<$val>() {

            let cast_self = unsafe { &mut *self.inner_mut() };
            return $val_func(cast_self);
        }
    ) else *

    // Again, unfortunate that this can't be a static assertion; optimized away
    unreachable!();
}

// Usage
let is_positive = foo.match_ref(
    |val: &u32|          *val >= 0,
    |val: &u64|          *val >= 0,
    |val: &u128|         *val >= 0,
    |val: &(u128, u128)| val.0 >= 0,
);
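
For readers who want to see the pieces fit together, here is a hand-expanded sketch of match_ref for a two-key map. It is not the macro's actual output: the per-key new constructors are an assumption that replaces the generic new shown earlier, so that each value is built as the variant matching its key, and the check uses > 0 to avoid a trivially true comparison.

```rust
use std::any::TypeId;
use std::marker::PhantomData;

// Hypothetical key types.
struct KeyType1;
struct KeyType2;

// Hand-written stand-in for what the macro might generate for two keys.
#[repr(u8)]
#[allow(dead_code, non_camel_case_types)]
enum Foo<K> {
    KeyType1(u32),
    KeyType2(u64),
    _hidden(PhantomData<K>),
}

// Layout of a single variant: u8 tag, then the payload.
#[repr(C)]
struct Repr<V> {
    _discriminant: u8,
    inner: V,
}

// Assumption: values are only built through these constructors, so the
// active variant always agrees with `K`.
impl Foo<KeyType1> {
    fn new(v: u32) -> Self {
        Foo::KeyType1(v)
    }
}

impl Foo<KeyType2> {
    fn new(v: u64) -> Self {
        Foo::KeyType2(v)
    }
}

impl<K: 'static> Foo<K> {
    fn match_ref<V1: 'static, V2: 'static, Out>(
        &self,
        f1: impl FnOnce(&V1) -> Out,
        f2: impl FnOnce(&V2) -> Out,
    ) -> Out {
        if TypeId::of::<K>() == TypeId::of::<KeyType1>()
            && TypeId::of::<V1>() == TypeId::of::<u32>()
        {
            // Sound only because `K` selects the active variant and `V1`
            // has been checked to be that variant's payload type.
            let repr = self as *const Self as *const Repr<V1>;
            return f1(unsafe { &(*repr).inner });
        } else if TypeId::of::<K>() == TypeId::of::<KeyType2>()
            && TypeId::of::<V2>() == TypeId::of::<u64>()
        {
            let repr = self as *const Self as *const Repr<V2>;
            return f2(unsafe { &(*repr).inner });
        }
        unreachable!("closure argument types do not match the declared value types");
    }
}
```

With this in place, Foo::<KeyType1>::new(5).match_ref(|v: &u32| *v > 0, |v: &u64| *v > 0) runs only the first closure; the second is type-checked but never called for that key.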