jax cuda 自定义算子

如何通过 JAX 框架实现自定义（CUDA）算子？

gpu_ops.cpp

#include "kernels.h"
#include "pybind11_kernel_helpers.h"

namespace
{
    // Builds the {target_name: capsule} table that the Python side iterates over
    // and hands to xla_client.register_custom_call_target. Each value wraps a C
    // entry point in an "xla._CUSTOM_CALL_TARGET" capsule.
    pybind11::dict MVMRegistrations()
    {
        pybind11::dict registrations;
        registrations["gpu_BF16"] = gpu_ops::EncapsulateFunction(gpu_ops::gpu_BF16);
        return registrations;
    }

    PYBIND11_MODULE(gpu_ops, m)
    {
        // Register the enum before any function that takes it as a parameter, so
        // pybind11 renders the Python name (ElementType) in generated signatures
        // and docstrings instead of the raw C++ type name.
        pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
            .value("BF16", gpu_ops::ElementType::BF16)
            .value("F32", gpu_ops::ElementType::F32);

        m.def("get_mvm_registrations", &MVMRegistrations,
              "Return a dict mapping custom-call target names to XLA capsules.");

        // Packs the problem dimensions into the opaque bytes passed to the kernel
        // as backend_config. Arguments stay positional-compatible; the py::arg
        // names additionally allow keyword calls.
        m.def(
            "create_mvm_descriptor",
            [](int numRows, int numCols, int batchSize, int miniBatchSize, gpu_ops::ElementType f_type)
            {
                return gpu_ops::PackDescriptor(gpu_ops::MFMDescriptor{
                    numRows, numCols, batchSize, miniBatchSize, f_type});
            },
            pybind11::arg("numRows"), pybind11::arg("numCols"), pybind11::arg("batchSize"),
            pybind11::arg("miniBatchSize"), pybind11::arg("f_type"),
            "Serialize an MFMDescriptor into bytes for the custom call's opaque field.");
    }
} // namespace

定义注册函数 以及pybind的接口

kernel_helpers.h

// This header is not specific to our application and you'll probably want
// something like this for any extension you're building. This includes the
// infrastructure needed to serialize descriptors that are used with the
// "opaque" parameter of the GPU custom call. In our example we'll use this
// parameter to pass the size of our problem.

#ifndef _GPU_OPS_KERNEL_HELPERS_H_
#define _GPU_OPS_KERNEL_HELPERS_H_

#include <cstddef>
#include <cstdint>
#include <cstring> // std::memcpy — previously only reachable through transitive includes
#include <stdexcept>
#include <string>
#include <type_traits>

#define JAX_APEX_WARP_SIZE 32

namespace gpu_ops
{

    // C++11 stand-in for C++20 std::bit_cast: reinterprets the object
    // representation of `src` as a `To` of the same size via memcpy, which is
    // the well-defined way to type-pun trivially copyable objects.
    // https://en.cppreference.com/w/cpp/numeric/bit_cast
    template <class To, class From>
    typename std::enable_if<sizeof(To) == sizeof(From) &&
                                std::is_trivially_copyable<From>::value &&
                                std::is_trivially_copyable<To>::value,
                            To>::type
    bit_cast(const From &src) noexcept
    {
        static_assert(std::is_trivially_constructible<To>::value,
                      "This implementation additionally requires destination type to "
                      "be trivially constructible");

        To dst;
        std::memcpy(&dst, &src, sizeof(To));
        return dst;
    }

    // Copies the raw bytes of `descriptor` into a std::string (sizeof(T) bytes).
    // Only meaningful for trivially copyable types, which is enforced here so a
    // descriptor holding pointers/owning members is rejected at compile time.
    template <typename T>
    std::string PackDescriptorAsString(const T &descriptor)
    {
        static_assert(std::is_trivially_copyable<T>::value,
                      "Descriptors must be trivially copyable to be packed as bytes");
        return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
    }

    // Views `opaque` (as produced by PackDescriptorAsString) as a `const T *`.
    // The returned pointer aliases `opaque` and is valid only while that buffer
    // lives. Throws std::runtime_error if the byte count is not exactly sizeof(T).
    template <typename T>
    const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len)
    {
        if (opaque_len != sizeof(T))
        {
            throw std::runtime_error("Invalid opaque object size");
        }
        return bit_cast<const T *>(opaque);
    }

} // namespace gpu_ops

#endif

kernels.h

#ifndef _GPU_OPS_KERNELS_H_
#define _GPU_OPS_KERNELS_H_

#include <cuda_runtime_api.h>
#include <cuda_bf16.h>
#include <cstddef>
#include <cstdint>

namespace gpu_ops
{
    // Element type of the floating-point operand, carried inside the descriptor.
    // The numeric values are spelled out because they travel through pybind11
    // and through the packed descriptor bytes.
    enum ElementType
    {
        F32 = 0,
        BF16 = 1,
    };

    // Problem description copied byte-for-byte into the custom call's opaque
    // string (packed via PackDescriptor, read back via UnpackDescriptor).
    // Field order and types therefore form a wire format: do not reorder.
    struct MFMDescriptor
    {
        int numRows;        // inner dimension (inputDim)
        int numCols;        // output dimension (outPutDim)
        int batchSize;      // leading dimension of the input
        int miniBatchSize;  // rows per batch in the 3-D case; 0 selects the 2-D path
        ElementType f_type; // element type of the floating-point operand
    };

    // XLA GPU custom-call entry point. buffers = {input, bool mask, output};
    // opaque/opaque_len hold a packed MFMDescriptor.
    void gpu_BF16(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len);

} // namespace gpu_ops

#endif

kernel.cu

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <iostream>
#include <chrono>
#include <cmath>
#include <algorithm> // 包含std::fill函数
#include "kernel_helpers.h"
#include "kernels.h"
using namespace std;

#define BLOCK_SIZE 16
namespace gpu_ops
{

    // Converts a CUDA runtime status into a std::runtime_error carrying the
    // CUDA error string; cudaSuccess is a no-op.
    void ThrowIfError(cudaError_t error)
    {
        if (error != cudaSuccess)
        {
            throw std::runtime_error(cudaGetErrorString(error));
        }
    }

    // Masked product for a single (unbatched) problem:
    //   C[row, col] = sum over k of A[row, k] where B[k, col] is true.
    // Layout: A is ARows x ACols, B is ACols x BCols, C is ARows x BCols, all
    // dense row-major.
    // Launch: 2-D grid, x covers output columns, y covers output rows, one thread
    // per output element; threads outside the output exit immediately.
    // The sum is accumulated in float and rounded to T once at the end, so long
    // reductions do not pick up per-step bfloat16 rounding error.
    template <typename T>
    __global__ void MFMKernel(const T *__restrict__ A, const bool *__restrict__ B, T *__restrict__ C, int ARows, int ACols, int BCols)
    {
        const int row = blockIdx.y * blockDim.y + threadIdx.y;
        const int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= ARows || col >= BCols)
        {
            return;
        }

        // size_t offsets: row * ACols etc. can exceed INT_MAX for large inputs.
        const T *aRow = A + static_cast<std::size_t>(row) * ACols;
        const bool *bCol = B + col;
        float sum = 0.0f;
        for (int k = 0; k < ACols; ++k)
        {
            // Consecutive threads (consecutive col) read consecutive mask bytes.
            if (bCol[static_cast<std::size_t>(k) * BCols])
            {
                sum += static_cast<float>(aRow[k]);
            }
        }
        C[static_cast<std::size_t>(row) * BCols + col] = static_cast<T>(sum);
    }

    // Batched version of MFMKernel: for every batch b,
    //   C[b, row, col] = sum over k of A[b, row, k] where B[b, k, col] is true.
    // Layout: A is batchSize x ARows x ACols, B is batchSize x ACols x BCols,
    // C is batchSize x ARows x BCols, all dense row-major.
    // Launch: 3-D grid, x covers output columns, y covers ARows, z indexes the
    // batch (so batchSize is bounded by the grid-z limit).
    template <typename T>
    __global__ void BatchMFMKernel(const T *__restrict__ A, const bool *__restrict__ B, T *__restrict__ C, int batchSize, int ARows, int ACols, int BCols)
    {
        const int batch = blockIdx.z;
        const int row = blockIdx.y * blockDim.y + threadIdx.y;
        const int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (batch >= batchSize || row >= ARows || col >= BCols)
        {
            return;
        }

        const std::size_t b = static_cast<std::size_t>(batch);
        const T *aRow = A + (b * ARows + row) * ACols;
        const bool *bCol = B + b * ACols * BCols + col;
        float sum = 0.0f;
        for (int k = 0; k < ACols; ++k)
        {
            if (bCol[static_cast<std::size_t>(k) * BCols])
            {
                sum += static_cast<float>(aRow[k]);
            }
        }
        C[(b * ARows + row) * BCols + col] = static_cast<T>(sum);
    }

    // XLA custom-call entry point for bfloat16 inputs.
    //   buffers[0]: __nv_bfloat16 input  — 2-D (batchSize, K) or 3-D (batchSize, miniBatchSize, K)
    //   buffers[1]: bool mask            — 2-D (K, N)         or 3-D (batchSize, K, N)
    //   buffers[2]: __nv_bfloat16 output — 2-D (batchSize, N) or 3-D (batchSize, miniBatchSize, N)
    // with K = descriptor.numRows and N = descriptor.numCols; miniBatchSize == 0
    // selects the 2-D path. Work is enqueued on `stream`; launch errors throw
    // std::runtime_error.
    void gpu_BF16(cudaStream_t stream, void **buffers, const char *opaque,
                  std::size_t opaque_len)
    {
        const MFMDescriptor &d = *UnpackDescriptor<MFMDescriptor>(opaque, opaque_len);

        // The kernels below are instantiated for __nv_bfloat16 only; any other
        // element type would be reinterpreted bit-for-bit and yield garbage.
        if (d.f_type != ElementType::BF16)
        {
            throw std::runtime_error("gpu_BF16: only BF16 inputs are supported");
        }

        const int batchSize = d.batchSize;
        const int K = d.numRows; // inputDim
        const int N = d.numCols; // outputDim
        const int miniBatchSize = d.miniBatchSize;

        // Nothing to write; a zero-sized grid would also be an invalid launch.
        if (batchSize <= 0 || N <= 0)
        {
            return;
        }

        const __nv_bfloat16 *x_dev = reinterpret_cast<const __nv_bfloat16 *>(buffers[0]);
        const bool *Weight_dev = reinterpret_cast<const bool *>(buffers[1]);
        __nv_bfloat16 *y_dev = reinterpret_cast<__nv_bfloat16 *>(buffers[2]);

        const dim3 block(BLOCK_SIZE, BLOCK_SIZE);

        if (miniBatchSize != 0) // 3d
        {
            // Rows per batch are miniBatchSize (not batchSize): the grid's y extent
            // must cover them or the tail rows of every batch are never written.
            const dim3 grid((N + block.x - 1) / block.x,
                            (miniBatchSize + block.y - 1) / block.y,
                            batchSize);
            BatchMFMKernel<__nv_bfloat16><<<grid, block, 0, stream>>>(x_dev, Weight_dev, y_dev, batchSize, miniBatchSize, K, N);
        }
        else // 2d
        {
            const dim3 grid((N + block.x - 1) / block.x,
                            (batchSize + block.y - 1) / block.y);
            MFMKernel<__nv_bfloat16><<<grid, block, 0, stream>>>(x_dev, Weight_dev, y_dev, batchSize, K, N);
        }
        ThrowIfError(cudaGetLastError());
    }

}

最核心的cuda代码

pybind11_kernel_helpers.h

// This header extends kernel_helpers.h with the pybind11 specific interface to
// serializing descriptors. It also adds a pybind11 function for wrapping our
// custom calls in a Python capsule. This is separate from kernel_helpers so
// that the CUDA code itself doesn't include pybind11. I don't think that this
// is strictly necessary, but they do it in jaxlib, so let's do it here too.

#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_

#include <pybind11/pybind11.h>

#include "kernel_helpers.h"

namespace gpu_ops
{

    // Returns the raw bytes of `descriptor` as a Python `bytes` object, suitable
    // for the custom call's backend_config / opaque argument.
    template <typename T>
    pybind11::bytes PackDescriptor(const T &descriptor)
    {
        const std::string raw = PackDescriptorAsString(descriptor);
        return pybind11::bytes(raw);
    }

    // Wraps a custom-call function pointer in a PyCapsule named
    // "xla._CUSTOM_CALL_TARGET" so it can be registered with XLA from Python.
    template <typename T>
    pybind11::capsule EncapsulateFunction(T *fn)
    {
        void *const raw_fn = bit_cast<void *>(fn);
        return pybind11::capsule(raw_fn, "xla._CUSTOM_CALL_TARGET");
    }

} // namespace gpu_ops

#endif

测试代码

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
from jax.core import ShapedArray
from jax import random

import time
from functools import partial
from colorama import Fore, Back, Style
import inspect

from ec.ops.build import gpu_ops



import os
# Restrict this process to physical GPU 1.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if no CUDA context / JAX GPU
# backend has been created yet. `jax` and the `gpu_ops` extension are imported above
# this line — confirm nothing has initialized the GPU by now, or move this assignment
# before those imports.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def matrix_mul(a, b):
    """Dense reference product of ``a`` and ``b`` via ``jnp.matmul``."""
    product = jnp.matmul(a, b)
    return product

def test_2d():
    """Time the dense ``jnp.matmul`` reference on a 2-D batch and print samples.

    Uses the module-level ``jit_matrix_mul`` created under ``__main__``. One
    warm-up call runs first so the reported time measures execution only, not
    JAX tracing / XLA compilation of the first call.
    """
    x_host, A_host = init_data(batch_size=32)
    for i in range(2):
        print(x_host[i,:10])
        print(A_host[i,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: the first jitted call traces and compiles; keep it out of the timing.
    jit_matrix_mul(x_host, A_host).block_until_ready()
    start_time = time.perf_counter()
    result = jit_matrix_mul(x_host,A_host).block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    print("The output results of random sampling: ")
    for i in range(2):
        print(result[i,:10])

def test_3d():
    """Time the dense ``jnp.matmul`` reference on a stacked 3-D batch and print samples.

    Builds inputs of shape (2, 32, matrix_rows) and (2, matrix_rows, matrix_cols)
    by stacking two ``init_data`` results. Uses the module-level
    ``jit_matrix_mul``; a warm-up call runs first so the reported time excludes
    tracing / compilation.
    """
    float_vector_batched_b = []
    bool_matrix_b = []
    # NOTE(review): init_data is called with identical arguments each time and uses
    # a fixed PRNG key internally, so both stacked batches are identical — this does
    # not exercise per-batch indexing. Consider varying the seed.
    for i in range(2):
        float_vector_batched, bool_matrix = init_data(batch_size=32)
        float_vector_batched_b.append(float_vector_batched[None, ...])
        bool_matrix_b.append(bool_matrix[None, ...])
    x_host = jnp.concatenate(float_vector_batched_b, axis=0)
    A_host = jnp.concatenate(bool_matrix_b, axis=0)
    for i in range(2):
        print(x_host[i,i+1,:10])
        print(A_host[i,i+1,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: the first jitted call traces and compiles; keep it out of the timing.
    jit_matrix_mul(x_host, A_host).block_until_ready()
    start_time = time.perf_counter()
    result = jit_matrix_mul(x_host,A_host).block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    print("The output results of random sampling: ")
    for i in range(2):
        print(result[i,i+1,:10])

  


def init_data(batch_size, matrix_rows=3, matrix_cols=10, seed=0):
    '''
    Initializes data for the masked matrix-vector multiplication tests.

    Builds a boolean mask matrix and a batch of bfloat16 vectors. Every mask entry
    is drawn independently as True/False with probability 0.5 (the matrix is
    sampled in two halves with separate subkeys, then reshaped). Each vector is
    drawn from a standard normal distribution.

    Parameters:
    - batch_size (int): The number of vectors in the batch (may be 0).
    - matrix_rows (int): The number of rows in the boolean matrix, also the length of each vector.
    - matrix_cols (int): The number of columns in the boolean matrix.
    - seed (int): PRNG seed. The default (0) reproduces the original fixed-key output.

    Returns:
    - A tuple containing:
        - A batch of bfloat16 vectors, shaped as (batch_size, matrix_rows).
        - A boolean matrix, shaped as (matrix_rows, matrix_cols).

    Raises:
    - ValueError: if batch_size is negative.

    Note:
    - The length of the vectors equals the number of rows in the matrix, ensuring compatibility for multiplication.
    '''
    if batch_size < 0:
        raise ValueError("batch_size must be non-negative")

    total_elements = matrix_rows * matrix_cols
    first_len = total_elements // 2
    second_len = total_elements - first_len

    values = jnp.array([0, 1], dtype=jnp.bool_)
    probabilities = jnp.array([0.5, 0.5])
    key = random.PRNGKey(seed)
    # Two independent draws with their own subkeys; the split order is kept exactly
    # as before so results for a given seed are unchanged.
    key, subkey = random.split(key)
    first_half = random.choice(subkey, values, shape=(first_len,), p=probabilities)
    key, subkey = random.split(key)
    second_half = random.choice(subkey, values, shape=(second_len,), p=probabilities)

    bool_matrix = jnp.concatenate((first_half, second_half)).reshape((matrix_rows, matrix_cols))

    if batch_size == 0:
        # jnp.concatenate raises on an empty list; return an empty batch instead.
        return jnp.zeros((0, matrix_rows), dtype=jnp.bfloat16), bool_matrix

    float_vector_batched = []
    for _ in range(batch_size):
        key, subkey = random.split(key)
        float_vector = random.normal(subkey, (matrix_rows,), dtype=jnp.bfloat16)
        float_vector_batched.append(float_vector[None, :])

    return jnp.concatenate(float_vector_batched, axis=0), bool_matrix


def mvm_p_fwd(vector, matrix):
    """Apply the ``mvm_p`` primitive to ``(vector, matrix)``.

    ``mvm_p`` is created under ``__main__`` with ``multiple_results=True``, so
    the result is a sequence whose single element is the product.
    """
    return mvm_p.bind(vector, matrix)

for _name, _value in gpu_ops.get_mvm_registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")


def default_layouts(*shapes):
    """Return one descending dimension order ``range(rank - 1, -1, -1)`` per shape.

    Read as a minor-to-major layout, this is the dense row-major (C-contiguous)
    layout, e.g. a rank-2 shape yields ``range(1, -1, -1)`` -> ``[1, 0]``.
    """
    layouts = []
    for shape in shapes:
        rank = len(shape)
        layouts.append(range(rank - 1, -1, -1))
    return layouts


def element_type_to_descriptor_type_mapping(element_type):
    """Translate an MLIR element type into the extension's ``ElementType`` enum.

    Returns ``gpu_ops.ElementType.BF16`` for bf16, ``gpu_ops.ElementType.F32``
    for f32, and ``None`` for anything else. Must be called with an active MLIR
    context (e.g. from inside a lowering rule).
    """
    if element_type == ir.BF16Type.get():
        return gpu_ops.ElementType.BF16
    if element_type == ir.F32Type.get():
        return gpu_ops.ElementType.F32
    return None


def _mvm_p_fwd_cuda_lowering(ctx, vector,matrix):
    """Lower ``mvm_p`` on GPU to the ``gpu_BF16`` XLA custom call.

    Accepted operand shapes (dense, row-major):
      * 2-D: vector ``(batchSize, inputDim)``, matrix ``(inputDim, outPutDim)``
        -> result ``(batchSize, outPutDim)``
      * 3-D: vector ``(batchSize, miniBatchSize, inputDim)``,
        matrix ``(batchSize, inputDim, outPutDim)``
        -> result ``(batchSize, miniBatchSize, outPutDim)``

    The sizes are sent to the kernel as a packed descriptor in
    ``backend_config``. Its ``miniBatchSize`` field doubles as the 2-D/3-D switch
    (0 selects the 2-D kernel), so a 3-D input with ``miniBatchSize == 0`` cannot
    be represented and is rejected rather than mis-dispatched.

    Args:
        ctx: lowering context (unused).
        vector: MLIR value of the floating-point operand; must be bf16.
        matrix: MLIR value of the boolean (i1) mask operand.

    Returns:
        The custom call's result values (a single tensor).

    Raises:
        NotImplementedError: if the operand element types are not (bf16, i1),
            the only combination the registered ``gpu_BF16`` kernel handles.
        ValueError: on unsupported ranks or mismatched dimensions.
    """
    vector_type = ir.RankedTensorType(vector.type)
    vector_shape = vector_type.shape
    matrix_type = ir.RankedTensorType(matrix.type)
    matrix_shape = matrix_type.shape

    element_type = vector_type.element_type
    if element_type != ir.BF16Type.get():
        raise NotImplementedError(
            f"mvm_p GPU lowering only supports bfloat16 vectors, got {element_type}")
    if matrix_type.element_type != ir.IntegerType.get_signless(1):
        raise NotImplementedError(
            f"mvm_p GPU lowering expects a boolean matrix, got {matrix_type.element_type}")

    if len(vector_shape) == 2 and len(matrix_shape) == 2:
        batchSize, inputDim = vector_shape
        matInputDim, outPutDim = matrix_shape
        miniBatchSize = 0  # 0 selects the 2-D kernel path
        result_shape = (batchSize, outPutDim)
    elif len(vector_shape) == 3 and len(matrix_shape) == 3:
        batchSize, miniBatchSize, inputDim = vector_shape
        matBatchSize, matInputDim, outPutDim = matrix_shape
        if matBatchSize != batchSize:
            raise ValueError(
                f"mvm_p batch mismatch: vector {vector_shape} vs matrix {matrix_shape}")
        if miniBatchSize == 0:
            # The descriptor would be indistinguishable from a 2-D call.
            raise ValueError("mvm_p 3-D lowering does not support miniBatchSize == 0")
        result_shape = (batchSize, miniBatchSize, outPutDim)
    else:
        raise ValueError(
            f"mvm_p expects 2-D or 3-D operands of equal rank, got {vector_shape} and {matrix_shape}")

    if matInputDim != inputDim:
        raise ValueError(
            f"mvm_p inner dimension mismatch: vector {vector_shape} vs matrix {matrix_shape}")

    opaque = gpu_ops.create_mvm_descriptor(
        inputDim,
        outPutDim,
        batchSize,
        miniBatchSize,
        element_type_to_descriptor_type_mapping(element_type),
    )

    return custom_call(
        b"gpu_BF16",
        result_types=[
            ir.RankedTensorType.get(result_shape, element_type),
        ],
        operands=[vector, matrix],
        backend_config=opaque,
        operand_layouts=default_layouts(vector_shape, matrix_shape),
        result_layouts=default_layouts(result_shape),
    ).results

  
def _mvm_p_fwd_abstract(vector,matrix):
    w_dtype = dtypes.canonicalize_dtype(vector.dtype)
    if len(vector.shape) == 2:
        batchSize,inputDim = vector.shape
        inputDim,outPutDim =  matrix.shape
        res_shape = (batchSize,outPutDim)    
    else:    
        batchSize,miniBatchSize,inputDim = vector.shape
        batchSize,inputDim,outPutDim =  matrix.shape
        res_shape = (batchSize,miniBatchSize,outPutDim)
    return (
        ShapedArray(res_shape, w_dtype, named_shape=matrix.named_shape),  # output
    )


def test_2d_ops():
    """Time the custom ``mvm_p`` primitive on a 2-D batch and print samples.

    A warm-up call runs first (it compiles the primitive). The timed call waits
    with ``block_until_ready`` so the measurement covers GPU execution rather
    than just asynchronous dispatch.
    """
    x_host, A_host = init_data(batch_size=32)
    for i in range(2):
        print(x_host[i,:10])
        print(A_host[i,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: compilation of the primitive is excluded from the timing.
    mvm_p_fwd(x_host, A_host)[0].block_until_ready()
    start_time = time.perf_counter()
    result = mvm_p_fwd(x_host,A_host)[0].block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    print(Style.RESET_ALL)

    for i in range(2):
        print(result[i,:10])


def test_3d_ops():
    """Time the custom ``mvm_p`` primitive on a stacked 3-D batch and print samples.

    Inputs are (2, 32, matrix_rows) and (2, matrix_rows, matrix_cols), built by
    stacking two ``init_data`` results. A warm-up call runs first, and the timed
    call waits with ``block_until_ready`` so GPU execution is actually measured.
    """
    float_vector_batched_b = []
    bool_matrix_b = []
    # NOTE(review): init_data is called with identical arguments each time and uses
    # a fixed PRNG key internally, so both stacked batches are identical.
    for i in range(2):
        float_vector_batched, bool_matrix = init_data(batch_size=32)
        float_vector_batched_b.append(float_vector_batched[None, ...])
        bool_matrix_b.append(bool_matrix[None, ...])
    x_host = jnp.concatenate(float_vector_batched_b, axis=0)
    A_host = jnp.concatenate(bool_matrix_b, axis=0)
    for i in range(2):
        print(x_host[i,i+1,:10])
        print(A_host[i,i+1,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: compilation of the primitive is excluded from the timing.
    mvm_p_fwd(x_host, A_host)[0].block_until_ready()
    start_time = time.perf_counter()
    result = mvm_p_fwd(x_host,A_host)[0].block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    print("The output results of random sampling: ")
    for i in range(2):
        print(result[i,i+1,:10])


def print_section_header(func):
    """Run ``func()`` between start/end banner lines that carry its name."""
    rule = '=' * 10
    name = func.__name__
    print(f"\n{rule} {name} start {rule}\n")
    func()
    print(f"\n{rule} {name} end {rule}\n")

    

if __name__ == '__main__':
    # The masked-matmul primitive; bind() returns a list (multiple_results=True),
    # which is why callers index [0].
    mvm_p = core.Primitive("mvm_p_fwd")
    mvm_p.multiple_results = True
    # Eager (non-jit) calls dispatch through xla.apply_primitive.
    mvm_p.def_impl(partial(xla.apply_primitive, mvm_p))
    mvm_p.def_abstract_eval(_mvm_p_fwd_abstract)
    # On GPU, lower to the custom call target registered from gpu_ops.
    mlir.register_lowering(mvm_p, _mvm_p_fwd_cuda_lowering, platform="gpu")

    # Dense reference used by test_2d / test_3d.
    jit_matrix_mul = jax.jit(matrix_mul)

    for _section in (test_2d, test_2d_ops, test_3d, test_3d_ops):
        print_section_header(_section)