// A reworking of the Thunk32 technique described on CodeProject:
// http://www.codeproject.com/KB/cpp/thunk32.aspx

#define WIN32_LEAN_AND_MEAN
#include <windows.h>

#include <cstring>

// Reinterprets the object representation of `s` as a value of type D.
//
// Unlike reinterpret_cast, this works between any two types of equal size
// (for example a pointer-to-member-function and a void*), because it copies
// raw bytes instead of converting the value.
//
//   D  destination type; must be the same size as S and default-constructible
//   S  source type (deduced)
//   s  value whose bytes are copied
//
// Returns a D holding exactly the bytes of `s`. Instantiating the template
// with types of different sizes is a compile-time error.
template <typename D, typename S>
D really_reinterpret_cast(S s) {
    // Pre-C++11 static assertion: a negative array bound is rejected by every
    // compiler, whereas a zero-length array is accepted by some as an
    // extension and would let a size mismatch slip through.
    typedef char source_and_destination_must_have_same_size[
        sizeof(S) == sizeof(D) ? 1 : -1];

    // Copy the bytes with memcpy rather than reading the inactive member of a
    // union, which is undefined behaviour in C++.
    D d;
    std::memcpy(&d, &s, sizeof(d));
    return d;
}

template <typename C, typename M>
void* createThunk(C* instance, M method) {
    char code[] = {
        0xB9, 0, 0, 0, 0,   // mov ecx, 0
        0xB8, 0, 0, 0, 0,   // mov eax, 0
        0xFF, 0xE0          // jmp eax
    };

    // YEEHAW
    *((I**)(code + 1)) = instance;
    *((void**)(code + 6)) = really_reinterpret_cast<void*>(method);

    void* thunk = VirtualAlloc(0, sizeof(code), MEM_COMMIT, PAGE_EXECUTE_READWRITE);
    memcpy(thunk, &code, sizeof(code));
    FlushInstructionCache(GetCurrentProcess(), thunk, sizeof(code));

    return thunk;
}

void releaseThunk(void* thunk) {
    VirtualFree(thunk, 0, MEM_RELEASE);
}

// Maps a pointer-to-member-function type to:
//   class_type     the class the member belongs to
//   function_type  the __stdcall free-function pointer type with the same
//                  return type and parameters, i.e. the signature under which
//                  a thunk for that method is called
//
// Only methods taking zero or one argument are supported, in both their
// non-const and const-qualified forms. Any other type leaves the primary
// template undefined, so its use is a compile-time error. Add further
// specializations following the same pattern to support more parameters.
template <typename F>
struct methodptr_traits;

// R (T::*)()
template <typename ReturnType, typename T>
struct methodptr_traits<ReturnType (T::*)()> {
    typedef T class_type;
    typedef ReturnType (__stdcall *function_type)();
};

// R (T::*)() const
template <typename ReturnType, typename T>
struct methodptr_traits<ReturnType (T::*)() const> {
    typedef T class_type;
    typedef ReturnType (__stdcall *function_type)();
};

// R (T::*)(Arg1)
template <typename ReturnType, typename T, typename Arg1>
struct methodptr_traits<ReturnType (T::*)(Arg1)> {
    typedef T class_type;
    typedef ReturnType (__stdcall *function_type)(Arg1);
};

// R (T::*)(Arg1) const
template <typename ReturnType, typename T, typename Arg1>
struct methodptr_traits<ReturnType (T::*)(Arg1) const> {
    typedef T class_type;
    typedef ReturnType (__stdcall *function_type)(Arg1);
};

template <typename M>
struct Thunk {
    typedef M Method;
    typedef typename methodptr_traits<M>::class_type Class;
    typedef typename methodptr_traits<M>::function_type Function;

    Thunk(Class* instance, Method method)
        : ptr(createThunk(instance, method))
    { }

    ~Thunk() {
        releaseThunk(ptr);
    }

    Function get() const {
        return reinterpret_cast<Function>(ptr);
    }

private:
    Thunk(const Thunk&);

    void* ptr;
};

//// Example follows

#include <cstdio>
using std::printf;

// Example interface whose virtual methods are invoked through thunks.
struct I {
    // Virtual so that deleting an implementation through an I* is well
    // defined and runs the derived destructor.
    virtual ~I() { }

    // Prints the implementation's state.
    virtual void print() = 0;

    // Prints the implementation's state combined with `y`.
    virtual void printSum(double y) = 0;
};

// Example implementation of I that holds a single integer.
struct C : I {
    int x;   // value reported by print() and added to by printSum()

    C(int initial)
        : x(initial)
    { }

    // Prints the stored integer on its own line.
    void print() {
        printf("My x is %i\n", this->x);
    }

    // Prints the stored integer, `y`, and their sum as a double.
    void printSum(double y) {
        const double total = this->x + y;
        printf("%i + %f = %f\n", this->x, y, total);
    }
};

// Free-function pointer type that methodptr_traits derives for a
// zero-argument, void-returning method of I.
// NOTE(review): Function0 is not referenced anywhere else in the visible source.
typedef methodptr_traits<void (I::*)()>::function_type Function0;

void main() {
    I* instance = new C(4);

    Thunk<void (I::*)()> thunk(instance, &I::print);
    printf("\n");
    thunk.get()();

    Thunk<void (I::*)(double)> thunk2(instance, &I::printSum);
    printf("\n");
    thunk2.get()(3.14);

    delete instance;
}