\(\DeclareMathOperator{\bind}{bind} \DeclareMathOperator{\mac}{mac} \DeclareMathOperator{\fold}{fold} \DeclareMathOperator{\product}{product} \DeclareMathOperator{\transpose}{transpose} \DeclareMathOperator{\dotp}{\boldsymbol{\cdot}}\)

Tuple Matrix Multiplication

Instead of the normal loop-based definition, matrix multiplication can be defined as a composition of higher-order functions. This can be implemented almost literally in CFL, as a generic function.

When applied to tuples, the resulting function can be evaluated at compile time, which makes it a case where regular programming and metaprogramming coincide in CFL.

See Also

bind, tuple_fold, tuple_product, enable_if_tuple_rank


Definition

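For an \(r \times n\) matrix \(A\) and an \(n \times c\) matrix \(B\), matrix multiplication \(A \times B = C\) is defined as follows, where each sum runs over \(k = 1, \dots, n\):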

$$ \begin{equation*} C = \begin{pmatrix} \sum{a_{1k}b_{k1}} & \cdots & \sum{a_{1k}b_{kc}} \\ \vdots & \ddots & \vdots \\ \sum{a_{rk}b_{k1}} & \cdots & \sum{a_{rk}b_{kc}} \end{pmatrix}_{r \times c} \end{equation*} $$

Element \(c_{ij}\) can be recognized as the dot product \(\dotp\) between row \(A_i\) and column \(B_j\). This dot product can be viewed as a fold with a ternary multiply-and-accumulate operator \(\mac\) and a zero initial value.

$$ \begin{align*} \mac (x,y,z) &= x + y * z\\ A_i \dotp B_j = \sum_{k}{a_{ik} b_{kj}} &= \fold (\mac, 0, A_i, B_j) \end{align*} $$
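
For instance, with row \(A_1 = (1, 2)\) and column \(B_1 = (5, 7)\), taken from the example matrices used later in this tutorial, the fold unrolls left to right as:

$$ \begin{align*} \fold (\mac, 0, A_1, B_1) &= \mac \big(\mac (0, 1, 5), 2, 7\big)\\ &= \mac (5, 2, 7) = 5 + 2 * 7 = 19 \end{align*} $$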

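The dot product applies to one-dimensional vectors. Since the arrays \(A\) and \(B\) are two-dimensional, the result \(C\) of a matrix multiplication can be regarded as the product of \(A\) and \(B^T\) with the dot product as operator: the operator is applied to every pair formed by a row \(A_i\) of \(A\) and a row of \(B^T\), that is, a column \(B_j\) of \(B\), giving \(c_{ij} = A_i \dotp B_j\).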

$$ \begin{equation*} C = \product \big(\dotp, A, \transpose (B)\big) \end{equation*} $$
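
For a two-row \(A\) and a two-column \(B\), row \(j\) of \(\transpose (B)\) is the column \(B_j\), so the product expands to:

$$ \begin{equation*} \product \big(\dotp, A, \transpose (B)\big) = \begin{pmatrix} A_1 \dotp B_1 & A_1 \dotp B_2 \\ A_2 \dotp B_1 & A_2 \dotp B_2 \end{pmatrix} \end{equation*} $$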

Putting this together with partial application and composition, dot product \(\dotp\) and matrix multiplication \(\times\) can be defined as:

$$ \begin{align*} \dotp &= \fold (\mac, 0, x, y)\\ \times &= \product \big(\dotp, x, \transpose (y)\big) \end{align*} $$

where \(x\) and \(y\) are free variables.

Implementation

The \(\mac\), \(\dotp\) and \(\times\) functions above can be defined as generic C++ functions with CFL. Here, tuples serve as arrays without loss of generality: the same approach applies to regular arrays, and tuples are in fact the more general case. For demonstration, mac is defined here as a lambda expression; the remainder of the tutorial uses a constexpr version provided by the library.

// multiply-and-accumulate: a0 + a1 * a2
auto mac = [] (auto a0, auto a1, auto a2)
    -> decltype (a0 + a1 * a2)
{
    return a0 + a1 * a2;
};

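A trailing return type should be used instead of automatic return type deduction, so that an invalid expression causes a substitution failure (SFINAE) instead of a hard error. With a deduced return type, the body has to be instantiated to determine the type, and errors inside the body are not substitution failures. Generally, functions used with CFL should allow for substitution failure, as it is used throughout the implementation for compile-time decisions. (There is a proposal to change this.)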
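
The difference shows up as soon as generic code asks whether a call is valid. The following standalone C++17 sketch does not use CFL; it only queries the lambda with the standard std::is_invocable trait, and shows that the trailing return type lets an unsupported argument combination be rejected cleanly:

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// The lambda from the text: the trailing return type puts the expression
// a0 + a1 * a2 into the signature, so an invalid combination of argument
// types is a substitution failure rather than a hard error.
auto mac = [] (auto a0, auto a1, auto a2) -> decltype (a0 + a1 * a2)
{
    return a0 + a1 * a2;
};

// Arithmetic arguments: the call is well-formed.
static_assert (std::is_invocable_v<decltype (mac), int, int, int>);

// There is no operator* for int and std::string, so the call is simply
// reported as not invocable. With a deduced (auto) return type, this query
// would have to instantiate the body and would fail to compile instead.
static_assert (!std::is_invocable_v<decltype (mac), int, int, std::string>);
```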

By convention (and necessity), tuple-specific array functions in CFL are prefixed with tuple_. Hence, fold, product and transpose for tuples are named tuple_fold, tuple_product and tuple_transpose, respectively. Continuing with the implementation:

auto dotp = tuple_fold (mac, 0, 1__, 2__);

dotp is now a binary function accepting two vectors (tuples) as arguments. As tuple_fold is a generic CFL function, it accepts placeholders without an explicit call to bind. Note that using _c placeholders here would be an error: the \(\product\) function into which dotp is inserted expects a function as its first argument, not a value.
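
What a fold with the ternary mac does over two vectors can be illustrated without CFL. The C++17 sketch below uses a hypothetical fold2 helper over two std::tuple arguments; it is not CFL's tuple_fold and makes no claim about how the library implements it. dotp is then obtained by fixing the operator and the zero initial value by hand, mirroring tuple_fold (mac, 0, 1__, 2__):

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>

// Multiply-and-accumulate: mac(x, y, z) = x + y * z.
auto mac = [] (auto a0, auto a1, auto a2) -> decltype (a0 + a1 * a2)
{
    return a0 + a1 * a2;
};

// Hypothetical two-sequence fold (not CFL's API): starting from acc,
// compute acc = op (acc, x_k, y_k) for k = 0, 1, ..., left to right.
template <std::size_t K = 0, typename Op, typename Acc,
          typename... Xs, typename... Ys>
auto fold2 (Op op, Acc acc, std::tuple<Xs...> const & x,
            std::tuple<Ys...> const & y)
{
    static_assert (sizeof... (Xs) == sizeof... (Ys), "length mismatch");
    if constexpr (K == sizeof... (Xs))
        return acc;
    else
        return fold2<K + 1> (op, op (acc, std::get<K> (x), std::get<K> (y)), x, y);
}

// Manual partial application: operator and initial value are fixed,
// the two vectors stay free (the role played by 1__ and 2__ in CFL).
auto dotp = [] (auto const & x, auto const & y)
{
    return fold2 (mac, 0, x, y);
};
```

For example, dotp applied to the row (1, 2) and the column (5, 7) computes 0 + 1 * 5 + 2 * 7 = 19.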

auto f = tuple_product
( 
    dotp, 
    1__, 
    tuple_transpose (2_c)
);

This defines f as a function of two matrices (tuples), with a syntax close to the generic definition above. However, f also accepts arrays of dimension higher than two, in which case it follows the regular reduction scheme for products. This is quite generic, but not always desired. Instead, f can be augmented with a predicate so that it only accepts arguments of dimension two. For convenience, a custom predicate is provided for tuple ranks; otherwise, enable_if can be used.

auto g = enable_if_tuple_rank <2,2> (f);

g is a generic function that also accepts placeholders and containers. All of the above could also have been evaluated at compile time and written as a single expression. Although GPU compilers have had some issues with compile-time evaluation, everything above is GPU compatible.
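
The rank restriction can also be expressed with plain std::enable_if, as mentioned above. The following C++17 sketch assumes a hypothetical tuple_rank trait and a hypothetical enable_if_rank_2 wrapper; neither is CFL's enable_if_tuple_rank, they only show the underlying SFINAE mechanism:

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Hypothetical rank trait (not CFL's): a non-tuple has rank 0, a tuple has
// rank 1 + the rank of its first element. Assumes regular nesting.
template <typename T>
struct tuple_rank : std::integral_constant<std::size_t, 0> {};

template <typename T, typename... Ts>
struct tuple_rank<std::tuple<T, Ts...>>
    : std::integral_constant<std::size_t, 1 + tuple_rank<T>::value> {};

template <typename T>
constexpr std::size_t tuple_rank_v = tuple_rank<std::decay_t<T>>::value;

// Hypothetical wrapper in the spirit of enable_if_tuple_rank <2,2>: the
// returned function only participates in overload resolution when both
// arguments have rank 2; otherwise the call is a substitution failure.
template <typename F>
auto enable_if_rank_2 (F f)
{
    return [f] (auto const & a, auto const & b)
        -> std::enable_if_t<tuple_rank_v<decltype (a)> == 2 &&
                                tuple_rank_v<decltype (b)> == 2,
                            decltype (std::declval<F const &> () (a, b))>
    {
        return f (a, b);
    };
}
```

Because the check lives in the return type, a rank-1 argument makes the wrapped function look non-callable to traits such as std::is_invocable, rather than producing a compile error inside f.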

Finally, applying arguments to g yields the result ((19, 22), (43, 50)).

auto r = g
(
    tuple (tuple (1,2), tuple (3,4)),
    tuple (tuple (5,6), tuple (7,8))
);
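
For readers without CFL at hand, the whole composition can be sketched in standard C++17 with std::tuple. All helper names below (fold2, column, transpose, product, matmul) are hypothetical stand-ins chosen for this illustration, not CFL's tuple_fold, tuple_transpose or tuple_product, and the sketch makes no claim about how CFL implements them:

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>

// Multiply-and-accumulate: mac(x, y, z) = x + y * z.
auto mac = [] (auto a0, auto a1, auto a2) -> decltype (a0 + a1 * a2)
{ return a0 + a1 * a2; };

// Hypothetical fold over two tuples: acc = op (acc, x_k, y_k), left to right.
template <std::size_t K = 0, typename Op, typename Acc,
          typename... Xs, typename... Ys>
auto fold2 (Op op, Acc acc, std::tuple<Xs...> const & x,
            std::tuple<Ys...> const & y)
{
    if constexpr (K == sizeof... (Xs))
        return acc;
    else
        return fold2<K + 1> (op, op (acc, std::get<K> (x), std::get<K> (y)), x, y);
}

auto dotp = [] (auto const & x, auto const & y) { return fold2 (mac, 0, x, y); };

// Hypothetical transpose: column J of m becomes row J of the result.
template <std::size_t J, typename M, std::size_t... I>
auto column (M const & m, std::index_sequence<I...>)
{
    return std::make_tuple (std::get<J> (std::get<I> (m))...);
}

template <typename M, std::size_t... J>
auto transpose_impl (M const & m, std::index_sequence<J...>)
{
    constexpr std::size_t rows = std::tuple_size_v<M>;
    return std::make_tuple (column<J> (m, std::make_index_sequence<rows> {})...);
}

template <typename M>
auto transpose (M const & m)
{
    using row_type = std::tuple_element_t<0, M>;
    return transpose_impl (m, std::make_index_sequence<std::tuple_size_v<row_type>> {});
}

// Hypothetical outer product: element (i, j) is op (a_i, b_j).
template <typename Op, typename Ai, typename B, std::size_t... J>
auto product_row (Op op, Ai const & ai, B const & b, std::index_sequence<J...>)
{
    return std::make_tuple (op (ai, std::get<J> (b))...);
}

template <typename Op, typename A, typename B, std::size_t... I>
auto product_impl (Op op, A const & a, B const & b, std::index_sequence<I...>)
{
    constexpr std::size_t cols = std::tuple_size_v<B>;
    return std::make_tuple (product_row (op, std::get<I> (a), b,
                                         std::make_index_sequence<cols> {})...);
}

template <typename Op, typename A, typename B>
auto product (Op op, A const & a, B const & b)
{
    return product_impl (op, a, b, std::make_index_sequence<std::tuple_size_v<A>> {});
}

// Matrix multiplication as in the definition: product (dotp, x, transpose (y)).
auto matmul = [] (auto const & x, auto const & y)
{
    return product (dotp, x, transpose (y));
};
```

Applied to the two matrices above, matmul yields ((19, 22), (43, 50)), the same result as g. Unlike g, this sketch has no rank check and no placeholder support.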

Performance

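The hand-coded matrix multiplication below is used for a side-by-side comparison of the generated code. Like ref, g is wrapped in a function to prevent inlining. As the listings show, g has a small but nonzero overhead compared to the reference implementation: both are fully unrolled and contain the same twelve imul instructions, but g loads each element through an additional pointer and saves more registers, while ref includes stack-protector checks (the %fs:0x28 accesses).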

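// reference implementation
// (GCC attributes: keep the function out of line, unroll its loops)

#include <cstddef>
using std::size_t;

// plain C array wrapper, so that the result can be returned by value
template <typename T, size_t R, size_t C> 
struct cmatrix
{
    T data [R][C];
};

// (R x N) matrix a times (N x C) matrix b, giving an (R x C) matrix
template <typename T, size_t R, size_t C, size_t N> 
__attribute__ ((noinline)) 
__attribute__ ((optimize("unroll-loops")))
cmatrix <T, R, C> ref (
    T const (& a) [R][N],
    T const (& b) [N][C]
)
{
    cmatrix <T, R, C> res;
    for (size_t r = 0; r < R; ++r)
    {
        for (size_t c = 0; c < C; ++c)
        {
            // res[r][c] = sum over n of a[r][n] * b[n][c]
            res.data [r][c] = 0;
            for (size_t n = 0; n < N; ++n)
            {
                res.data [r][c] += 
                    a [r][n] * 
                    b [n][c];
            }
        }
    }
    return res;
}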

g                                  ref
===============================    ===============================
4008cf: push   %r14                400816: push   %r12
4008d1: push   %r13                400818: push   %rbp
4008d3: push   %r12                400819: push   %rbx
4008d5: push   %rbp                40081a: sub    $0x10,%rsp
4008d6: push   %rbx                40081e: mov    0x8(%rsi),%edx
4008d7: mov    (%rdi),%rax         400821: mov    (%rsi),%ecx
4008da: mov    0x10(%rdi),%r8      400823: mov    0x4(%rdi),%r9d
4008de: mov    0x8(%rsi),%r9       400827: mov    0x10(%rsi),%r11d
4008e2: mov    0x10(%rsi),%rdx     40082b: mov    %fs:0x28,%rax
4008e6: mov    (%rax),%r14d
4008e9: mov    (%rsi),%rax         400834: mov    %rax,0x8(%rsp)
4008ec: mov    (%r9),%ebp          400839: xor    %eax,%eax
4008ef: mov    (%r8),%r10d         40083b: mov    (%rdi),%eax
4008f2: mov    0x18(%rsi),%r9      40083d: mov    %edx,%r12d
4008f6: mov    0x20(%rsi),%r8      400840: mov    %ecx,%ebp
4008fa: mov    0x28(%rsi),%rsi     400842: imul   %r9d,%r12d
4008fe: mov    (%rax),%ecx         400846: mov    0x4(%rsi),%r8d
400900: mov    %r14d,%r13d         40084a: mov    0xc(%rsi),%r10d
400903: mov    0x8(%rdi),%rax      40084e: mov    0x14(%rsi),%ebx
400907: mov    (%rdx),%edx         400851: mov    0x8(%rdi),%esi
400909: mov    (%r9),%ebx          400854: imul   %eax,%ebp
40090c: mov    (%r8),%r8d          400857: imul   %r10d,%r9d
40090f: mov    (%rsi),%r11d        40085b: add    %r12d,%ebp
400912: mov    0x18(%rdi),%rsi     40085e: mov    %r11d,%r12d
400916: mov    (%rax),%eax         400861: imul   %r8d,%eax
400918: imul   %ecx,%r13d          400865: imul   %esi,%r12d
40091c: mov    (%rsi),%r9d         400869: add    %r9d,%eax
40091f: mov    0x20(%rdi),%rsi     40086c: imul   %ebx,%esi
400923: mov    0x28(%rdi),%rdi     40086f: add    %esi,%eax
400927: mov    %eax,%r12d          400871: lea    0x0(%rbp,%r12,1),%esi
40092a: imul   %edx,%r12d          400876: mov    0xc(%rdi),%r12d
40092e: mov    (%rsi),%esi         40087a: mov    0x10(%rdi),%ebp
400930: mov    (%rdi),%edi         40087d: mov    0x14(%rdi),%edi
400932: add    %r13d,%r12d         400880: shl    $0x20,%rax
400935: mov    %r10d,%r13d         400884: or     %rsi,%rax
400938: imul   %esi,%edx           400887: mov    %r11d,%esi
40093b: imul   %r9d,%ecx           40088a: imul   %r12d,%ecx
40093f: imul   %ebx,%esi           40088e: imul   %ebp,%edx
400942: add    %edx,%ecx           400891: add    %edx,%ecx
400944: imul   %ebp,%r9d           400893: mov    %ebp,%edx
400948: imul   %r8d,%r13d          400895: imul   %r12d,%r8d
40094c: add    %r9d,%esi           400899: imul   %r10d,%edx
40094f: imul   %ebp,%r14d          40089d: imul   %edi,%esi
400953: imul   %ebx,%eax           4008a0: add    %r8d,%edx
400956: imul   %edi,%r8d           4008a3: imul   %ebx,%edi
40095a: pop    %rbx                4008a6: add    %esi,%ecx
40095b: add    %r14d,%eax          4008a8: add    %edi,%edx
40095e: imul   %r11d,%edi          4008aa: shl    $0x20,%rdx
400962: add    %r8d,%ecx           4008ae: or     %rcx,%rdx
400965: imul   %r11d,%r10d         4008b1: mov    0x8(%rsp),%rbx
400969: pop    %rbp                4008b6: xor    %fs:0x28,%rbx
40096a: lea    (%rsi,%rdi,1),%edx
40096d: add    %r10d,%eax          4008bf: jne    4008ca <mm_ref+0xb4>
400970: lea    (%r12,%r13,1),%r10d 4008c1: add    $0x10,%rsp
400974: shl    $0x20,%rax          4008c5: pop    %rbx
400978: shl    $0x20,%rdx          4008c6: pop    %rbp
40097c: pop    %r12                4008c7: pop    %r12
40097e: or     %r10,%rax           4008c9: retq   
400981: or     %rcx,%rdx
400984: pop    %r13
400986: pop    %r14
400988: retq