by Gui Andrade

Implementing specialization in Rust by matching marker types

Or, creating a higher order 'mapping' type

While hacking on some performance-critical Rust code recently, I found myself wanting a way to construct an enum whose variant is chosen at compile time by type reflection.

Here's an example of my ideal type (syntax aside):

struct UseSmallNum {}
struct UseNormalNum {}
struct UseBigNum {}
struct UseHugeNum {}

tymap Foo<K> for K {
    UseSmallNum  => u32,
    UseNormalNum => u64,
    UseBigNum    => u128,
    UseHugeNum   => (u128, u128)
};

impl<K> Foo<K> {
    fn positive(&self) -> bool {
        match self {
            UseSmallNum(ref val: u32)         => *val >= 0,
            UseNormalNum(ref val: u64)        => *val >= 0,
            UseBigNum(ref val: u128)          => *val >= 0,
            UseHugeNum(ref val: (u128, u128)) => val.0 >= 0,
        }
    }
}

fn main() {
    let foo = Foo::<UseSmallNum>::new(0u32);
    assert!(foo.positive());
}

But why?

I wanted to specialize a struct or function based on a marker type.

Specifically, in my llama emulator project I wanted to compile two versions of the main interpreter. The Nintendo 3DS uses two (actually three) CPU families, ARM9 and ARM11, with significant differences between them, and I wanted to express those differences through specialization. So I wanted to be able to write code like this:

struct Arm9;
struct Arm11;

struct StatusReg9 { ... }
struct StatusReg11 { ... }
tymap StatusReg<V> for V {
    Arm9 => StatusReg9,
    Arm11 => StatusReg11,
};

struct Cpu<Version> {
    cpsr: StatusReg<Version>,
}
impl<V> Cpu<V> {
    fn step_instruction(&mut self) {
        ...

        // Update status reg
        match self.cpsr {
            Arm9(ref mut val: StatusReg9) => val.update(...),
            Arm11(ref mut val: StatusReg11) => val.update(...),
        }
    }
}

You can think of it like a nicer version of C++'s std::enable_if.

Benefits over wrapping an enum

In the hot path, every last branch instruction is a potential performance pitfall. Because the code that runs is selected at compile time, the optimizer can erase all of the reflection, just as it does with ordinary generics.
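As a minimal illustration of compile-time selection, here is a generic function that picks its result from a marker type using only `std::any::TypeId`. The function name `version_name` and its return strings are hypothetical. The language does not guarantee that the comparisons are folded, but each monomorphized copy compares two constant `TypeId`s, so an optimizing build can typically reduce it to a single return.

```rust
use std::any::TypeId;

// Marker types, as in the emulator example.
struct Arm9;
struct Arm11;

// Hypothetical helper: behaviour is chosen by the marker type `V`.
// For a given V, both comparisons involve only constants, so the
// optimizer can usually drop the branches entirely.
fn version_name<V: 'static>() -> &'static str {
    if TypeId::of::<V>() == TypeId::of::<Arm9>() {
        "ARM9"
    } else if TypeId::of::<V>() == TypeId::of::<Arm11>() {
        "ARM11"
    } else {
        unreachable!("unsupported CPU marker type")
    }
}
```

Calling `version_name::<Arm11>()` returns `"ARM11"`.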

Is this idiomatic? What about traits?

Traits would definitely be the idiomatic way to solve this problem. But a tymap can provide both the exhaustiveness guarantees of an enum and the monomorphization of traits.

And, most importantly, the tymap can be stored in structs without any type erasure or runtime penalty.
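For comparison, here is a minimal sketch of the trait-based approach applied to the CPU example. The `CpuVersion` trait, the `StatusUpdate` trait with its `update` method, and the register fields are all hypothetical. The associated type acts as the map from marker to register type, and dispatch is resolved at compile time. What it does not give is the enum-style exhaustiveness mentioned above: a marker without a `CpuVersion` impl is only rejected when someone instantiates `Cpu` with it.

```rust
// Marker types from the example.
struct Arm9;
struct Arm11;

// Hypothetical register types; fields are illustrative only.
#[derive(Default)]
struct StatusReg9 {
    flags: u32,
}
#[derive(Default)]
struct StatusReg11 {
    flags: u32,
    writes: u32, // made-up extra state to show the types can differ
}

// Hypothetical shared interface for the register types.
trait StatusUpdate: Default {
    fn update(&mut self, flags: u32);
}
impl StatusUpdate for StatusReg9 {
    fn update(&mut self, flags: u32) {
        self.flags = flags;
    }
}
impl StatusUpdate for StatusReg11 {
    fn update(&mut self, flags: u32) {
        self.flags = flags;
        self.writes += 1;
    }
}

// The trait-based "type map": each marker names its register type.
trait CpuVersion {
    type StatusReg: StatusUpdate;
}
impl CpuVersion for Arm9 {
    type StatusReg = StatusReg9;
}
impl CpuVersion for Arm11 {
    type StatusReg = StatusReg11;
}

// The register is stored inline: no Box<dyn Trait>, no enum tag.
struct Cpu<V: CpuVersion> {
    cpsr: V::StatusReg,
}

impl<V: CpuVersion> Cpu<V> {
    fn new() -> Self {
        Cpu { cpsr: Default::default() }
    }

    fn step_instruction(&mut self, flags: u32) {
        // Statically dispatched to StatusReg9::update or StatusReg11::update.
        self.cpsr.update(flags);
    }
}
```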

Implementation

Turns out, we can implement something very similar in stable Rust with the help of macros, enums, and a little unsafe code. See this GitHub repo if you'd like to browse a complete implementation.

Storage size and alignment

We need our tymap type to be at least as large as its largest variant, with the correct alignment as well. But we won't use a struct containing every variant, because that would waste space on values that are never live at the same time. A union wouldn't work either, because at the time of writing Rust unions only accepted Copy fields (since Rust 1.49, ManuallyDrop<T> fields are allowed as well).
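To make the space trade-off concrete, here is a sketch comparing a struct that holds every value type of the `Number` example with a `#[repr(u8)]` enum that holds just one. The type names and the `sizes` helper are hypothetical. Exact byte counts depend on the target, since the alignment of `u128` varies between platforms, so the comments give lower bounds rather than exact figures.

```rust
use std::mem::size_of;

// Every value type stored side by side: at least 4 + 8 + 16 + 32 = 60 bytes.
#[allow(dead_code)]
struct AllVariants {
    small: u32,
    normal: u64,
    big: u128,
    huge: (u128, u128),
}

// One value at a time: the largest payload (32 bytes) plus a one-byte
// tag and alignment padding.
#[allow(dead_code)]
#[repr(u8)]
enum OneVariant {
    Small(u32),
    Normal(u64),
    Big(u128),
    Huge((u128, u128)),
}

// Hypothetical helper reporting both sizes on the current target.
fn sizes() -> (usize, usize) {
    (size_of::<AllVariants>(), size_of::<OneVariant>())
}
```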

Fortunately, Rust's enum type already gives us the size and alignment we need. And #[repr(u8)] gives a data-carrying enum a defined layout, in which every variant is laid out as a #[repr(C)] struct beginning with the u8 discriminant, so we can access a variant's data directly without caring about the discriminant. This effectively gives us a union without union's restrictions:

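#[repr(u8)]
enum $type_name {
    $( $key ( $val ) ),*
}

impl $type_name {
    // Safety: the caller must guarantee that the active variant holds a `V`.
    unsafe fn inner<V: 'static>(&self) -> *const V {
        // With #[repr(u8)], every variant is laid out like this struct:
        // the u8 discriminant followed by the variant's field.
        #[repr(C)]
        struct Repr<V> {
            _discriminant: u8,
            _inner: V
        }
        let repr_ptr = self as *const Self as *const Repr<V>;
        // addr_of! computes the field address without an intermediate reference.
        ::core::ptr::addr_of!((*repr_ptr)._inner)
    }

    // Mutable counterpart of `inner`, with the same safety requirement.
    unsafe fn inner_mut<V: 'static>(&mut self) -> *mut V {
        #[repr(C)]
        struct Repr<V> {
            _discriminant: u8,
            _inner: V
        }
        let repr_ptr = self as *mut Self as *mut Repr<V>;
        ::core::ptr::addr_of_mut!((*repr_ptr)._inner)
    }
}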
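Written out by hand for a concrete two-variant enum (names follow the example; this is a sketch, not macro output), the same trick looks like this. It relies on the layout that RFC 2195 specifies for `#[repr(u8)]` enums with fields. The caller must know which variant is active: reading through the wrong type is undefined behavior.

```rust
use std::ptr;

// A concrete, hand-written instance of the macro's enum.
#[allow(dead_code)]
#[repr(u8)]
enum Number {
    UseSmallNum(u32),
    UseHugeNum((u128, u128)),
}

// With #[repr(u8)], each variant is laid out like this #[repr(C)] struct.
#[repr(C)]
struct Repr<V> {
    _discriminant: u8,
    inner: V,
}

impl Number {
    // Caller must ensure the active variant's field has type `V`.
    unsafe fn inner<V>(&self) -> *const V {
        let repr_ptr = self as *const Self as *const Repr<V>;
        // addr_of! computes the field address without creating a reference.
        unsafe { ptr::addr_of!((*repr_ptr).inner) }
    }
}
```

For example, for `Number::UseHugeNum((1, 2))`, dereferencing `inner::<(u128, u128)>()` yields `(1, 2)`.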

Declaring a tymap

We'll have the following syntax for our tymap macro:

struct UseSmallNum {}
struct UseNormalNum {}
struct UseBigNum {}
struct UseHugeNum {}

tymap!(Number {
    UseSmallNum  { val1: u32 },
    UseNormalNum { val2: u64 },
    UseBigNum    { val3: u128 },
    UseHugeNum   { val4: (u128, u128) }
});

valX, here, is an ident, not a type, and is only used for variable binding and type parameters; it doesn't provide any additional "information" in the type declaration.

Each key (UseSmallNum and friends) is used both as an ident, for the enum variant name, and as a type, for the TypeId checks. Unfortunately, that means only type names that are also valid idents can be keys, although a type alias (type A = B;) works around this.

It would be nice if these tricks weren't necessary, but these are limitations of macro_rules!. A future iteration could use a procedural macro instead for an even nicer syntax.
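Here is a cut-down sketch of how a `macro_rules!` matcher could accept this syntax. The macro name `tymap_decl` and the `key_is_declared` helper are hypothetical. It only emits the enum (including the `_hidden` variant introduced in the next section) and shows each key being used once as a variant name and once as a type.

```rust
// Hypothetical, cut-down macro: parses the declaration syntax and emits
// only the enum plus one helper.
macro_rules! tymap_decl {
    ($type_name:ident {
        $( $key:ident { $VAL:ident : $val:ty } ),* $(,)?
    }) => {
        // Each key is used as a variant name (an ident) here...
        #[allow(dead_code, non_camel_case_types)]
        #[repr(u8)]
        enum $type_name<K> {
            $( $key($val), )*
            _hidden(::std::marker::PhantomData<K>),
        }

        impl<K: 'static> $type_name<K> {
            // ...and as a type here. The valX idents are captured but unused
            // in this sketch; the full macro reuses them as type parameter names.
            #[allow(dead_code)]
            fn key_is_declared() -> bool {
                false $( || ::std::any::TypeId::of::<K>() == ::std::any::TypeId::of::<$key>() )*
            }
        }
    };
}

struct UseSmallNum {}
struct UseNormalNum {}
struct UseBigNum {}
struct UseHugeNum {}

tymap_decl!(Number {
    UseSmallNum  { val1: u32 },
    UseNormalNum { val2: u64 },
    UseBigNum    { val3: u128 },
    UseHugeNum   { val4: (u128, u128) }
});
```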

Constructing a new tymap

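First, let's add the type parameter K to $type_name, as well as the enum variant _hidden(PhantomData<K>). K will be the currently active "key" of our type mapping. PhantomData<K> is there because Rust rejects a type parameter that no field mentions (error E0392); it mentions K without storing anything.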

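impl<K: 'static> $type_name<K> {
    pub fn new<V: 'static>(v: V) -> Self {
        // Assumes std::any::TypeId, std::mem::ManuallyDrop and std::ptr are in scope.
        $(
            if TypeId::of::<K>() == TypeId::of::<$key>()
                && TypeId::of::<V>() == TypeId::of::<$val>() {
                // The check above proves V and $val are the same type, so this
                // read is just a move. Building the real variant keeps the
                // discriminant (and therefore drop glue) in sync with the payload.
                let v = ManuallyDrop::new(v);
                let val: $val = unsafe { ptr::read(&*v as *const V as *const $val) };
                return $type_name::$key(val);
            }
        )*

        // Don't allow the creation of a tymap with an invalid key-value pair.
        // Would be nice if this could be a static assertion, but TypeId::of couldn't
        // be used in a const context at the time of writing. In optimized builds
        // the TypeId comparisons are typically constant-folded.
        panic!("invalid key/value pair for tymap");
    }
}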
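Here is a hand-expanded sketch of what `new` could look like for a two-entry map. The names follow the example, and the panic message is made up. Once the `TypeId` checks prove that `V` is the declared value type, it moves `v` into the matching variant, so the enum's discriminant and drop glue agree with what it actually holds. Requesting a pair that isn't declared panics at run time.

```rust
use std::any::TypeId;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ptr;

// Marker (key) types from the example.
struct UseSmallNum {}
struct UseBigNum {}

// Roughly what the macro would generate for a two-entry map.
#[allow(dead_code, non_camel_case_types)]
#[repr(u8)]
enum Number<K> {
    UseSmallNum(u32),
    UseBigNum(u128),
    _hidden(PhantomData<K>),
}

impl<K: 'static> Number<K> {
    fn new<V: 'static>(v: V) -> Self {
        if TypeId::of::<K>() == TypeId::of::<UseSmallNum>()
            && TypeId::of::<V>() == TypeId::of::<u32>()
        {
            // SAFETY: V is u32 (checked above), so this read is a plain move.
            // ManuallyDrop stops the original from being dropped a second time.
            let v = ManuallyDrop::new(v);
            let val = unsafe { ptr::read(&*v as *const V as *const u32) };
            return Number::UseSmallNum(val);
        }
        if TypeId::of::<K>() == TypeId::of::<UseBigNum>()
            && TypeId::of::<V>() == TypeId::of::<u128>()
        {
            // SAFETY: as above, with V proven to be u128.
            let v = ManuallyDrop::new(v);
            let val = unsafe { ptr::read(&*v as *const V as *const u128) };
            return Number::UseBigNum(val);
        }
        // No declared key/value pair matched this K and V.
        panic!("invalid key/value pair for Number");
    }
}
```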

Implementing match

// match_ref and match_move follow pretty clearly from the following implementation
pub fn match_ref_mut
    <$($VAL: 'static),*, Out>
    (&mut self, $($val_func: impl FnOnce(&mut $VAL) -> Out),*)
    -> Out
{
    // Every condition compares constant TypeIds, so in optimized builds
    // these if-statements are typically folded away.
    $(
        if TypeId::of::<K>() == TypeId::of::<$key>()
            && TypeId::of::<$VAL>() == TypeId::of::<$val>() {

            let cast_self = unsafe { &mut *self.inner_mut() };
            return $val_func(cast_self);
        }
    ) else *

    // Again, unfortunate that this can't be a static assertion. It is reached
    // only if the closures' argument types don't match the declared value types.
    unreachable!();
}
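Since `match_ref` is only described as following from `match_ref_mut`, here is a hand-expanded sketch of it for a two-entry map. The closure parameter names `small` and `big` are hypothetical, and `A` and `B` play the role of `$VAL`. Like the macro version, it assumes that the active variant always matches `K`, which holds for values built through `new`.

```rust
use std::any::TypeId;
use std::marker::PhantomData;
use std::ptr;

// Marker (key) types from the example.
struct UseSmallNum {}
struct UseBigNum {}

// Roughly what the macro would generate for a two-entry map.
#[allow(dead_code, non_camel_case_types)]
#[repr(u8)]
enum Number<K> {
    UseSmallNum(u32),
    UseBigNum(u128),
    _hidden(PhantomData<K>),
}

// Layout of one variant of a #[repr(u8)] enum: a u8 tag, then the field.
#[repr(C)]
struct Repr<V> {
    _discriminant: u8,
    inner: V,
}

impl<K: 'static> Number<K> {
    // Caller must ensure the active variant holds a `V`.
    unsafe fn inner<V>(&self) -> *const V {
        let repr_ptr = self as *const Self as *const Repr<V>;
        unsafe { ptr::addr_of!((*repr_ptr).inner) }
    }

    // Shared-reference version of match_ref_mut, hand-expanded for two keys.
    fn match_ref<A: 'static, B: 'static, Out>(
        &self,
        small: impl FnOnce(&A) -> Out,
        big: impl FnOnce(&B) -> Out,
    ) -> Out {
        if TypeId::of::<K>() == TypeId::of::<UseSmallNum>()
            && TypeId::of::<A>() == TypeId::of::<u32>()
        {
            // SAFETY: K says UseSmallNum is active, and A is u32.
            return small(unsafe { &*self.inner::<A>() });
        }
        if TypeId::of::<K>() == TypeId::of::<UseBigNum>()
            && TypeId::of::<B>() == TypeId::of::<u128>()
        {
            // SAFETY: K says UseBigNum is active, and B is u128.
            return big(unsafe { &*self.inner::<B>() });
        }
        unreachable!("closure argument types don't match the declared value types");
    }
}
```

For a `Number<UseSmallNum>` holding 7, `match_ref(|v: &u32| *v as u64 * 2, |v: &u128| *v as u64)` runs only the first closure and returns 14.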

How would you use it?

fn positive<NTy: 'static>(num: &Number<NTy>) -> bool {
    // One closure per declared value type, in declaration order.
    num.match_ref(
        |val: &u32|          *val >= 0,
        |val: &u64|          *val >= 0,
        |val: &u128|         *val >= 0,
        |val: &(u128, u128)| val.0 >= 0,
    )
}