JAX + CUDA 自定义算子
如何在 JAX 框架中用 CUDA 实现自定义算子?
gpu_ops.cpp
#include "kernels.h"
#include "pybind11_kernel_helpers.h"
namespace
{
    // Table of XLA custom-call targets exported by this extension.
    // Key: the target name the Python lowering passes to custom_call (b"gpu_BF16").
    // Value: a PyCapsule wrapping the C++ entry point (see EncapsulateFunction).
    pybind11::dict MVMRegistrations()
    {
        pybind11::dict registrations;
        registrations["gpu_BF16"] = gpu_ops::EncapsulateFunction(gpu_ops::gpu_BF16);
        return registrations;
    }

    // Python module `gpu_ops`:
    //   get_mvm_registrations() -> dict of custom-call capsules
    //   create_mvm_descriptor(numRows, numCols, batchSize, miniBatchSize, f_type) -> bytes
    //       (raw MFMDescriptor bytes, passed to the custom call as its opaque/backend_config)
    //   ElementType enum with members BF16 and F32
    PYBIND11_MODULE(gpu_ops, m)
    {
        m.def("get_mvm_registrations", &MVMRegistrations);

        m.def("create_mvm_descriptor",
              [](int numRows, int numCols, int batchSize, int miniBatchSize, gpu_ops::ElementType f_type)
              {
                  const gpu_ops::MFMDescriptor descriptor{numRows, numCols, batchSize, miniBatchSize, f_type};
                  return gpu_ops::PackDescriptor(descriptor);
              });

        pybind11::enum_<gpu_ops::ElementType> elementType(m, "ElementType");
        elementType.value("BF16", gpu_ops::ElementType::BF16);
        elementType.value("F32", gpu_ops::ElementType::F32);
    }
} // namespace
定义 XLA custom call 的注册函数以及 pybind11 接口。
kernel_helpers.h
// This header is not specific to our application and you'll probably want
// something like this for any extension you're building. This includes the
// infrastructure needed to serialize descriptors that are used with the
// "opaque" parameter of the GPU custom call. In our example we'll use this
// parameter to pass the size of our problem.
#ifndef _GPU_OPS_KERNEL_HELPERS_H_
#define _GPU_OPS_KERNEL_HELPERS_H_
#include <cstddef>
#include <cstdint>
#include <cstring> // std::memcpy, used by bit_cast
#include <stdexcept>
#include <string>
#include <type_traits>
#define JAX_APEX_WARP_SIZE 32
namespace gpu_ops
{
    // Pre-C++20 stand-in for std::bit_cast
    // (https://en.cppreference.com/w/cpp/numeric/bit_cast): returns a `To` whose
    // object representation is a byte copy of `src`. Only participates in overload
    // resolution when both types have the same size and are trivially copyable.
    template <class To, class From>
    typename std::enable_if<sizeof(To) == sizeof(From) &&
                                std::is_trivially_copyable<From>::value &&
                                std::is_trivially_copyable<To>::value,
                            To>::type
    bit_cast(const From &src) noexcept
    {
        static_assert(std::is_trivially_constructible<To>::value,
                      "This implementation additionally requires destination type to "
                      "be trivially constructible");
        To dst;
        std::memcpy(&dst, &src, sizeof(To));
        return dst;
    }

    // Serializes `descriptor` as its raw object bytes (exactly sizeof(T) bytes).
    // T must be trivially copyable so that a byte copy is a faithful representation.
    template <typename T>
    std::string PackDescriptorAsString(const T &descriptor)
    {
        static_assert(std::is_trivially_copyable<T>::value,
                      "descriptors are serialized as raw bytes and must be trivially copyable");
        return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
    }

    // Reinterprets the opaque bytes produced by PackDescriptorAsString as a T.
    // Returns a pointer into `opaque` (no copy), so it is only valid while `opaque` is.
    // Throws std::runtime_error when opaque_len != sizeof(T).
    // NOTE(review): this assumes `opaque` is suitably aligned for T; if the caller's
    // buffer can be misaligned, copy into a local T with std::memcpy instead.
    template <typename T>
    const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len)
    {
        if (opaque_len != sizeof(T))
        {
            throw std::runtime_error("Invalid opaque object size");
        }
        return bit_cast<const T *>(opaque);
    }
} // namespace gpu_ops
#endif
kernels.h
#ifndef _GPU_OPS_KERNELS_H_
#define _GPU_OPS_KERNELS_H_
#include <cstddef>
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
namespace gpu_ops
{
    // Element type tag carried inside MFMDescriptor (exposed to Python as gpu_ops.ElementType).
    enum ElementType
    {
        F32 = 0,
        BF16 = 1,
    };

    // Problem description serialized into the custom call's opaque bytes.
    struct MFMDescriptor
    {
        int numRows;        // inner (reduction) dimension K: columns of x / rows of the weight
        int numCols;        // output dimension N: columns of the weight and of the result
        int batchSize;      // leading dimension of x (and of the weight in the 3D case)
        int miniBatchSize;  // 0 selects the 2D path; otherwise rows of each batch slice
        ElementType f_type; // element type of x and of the result
    };

    // XLA custom-call entry point (legacy GPU signature).
    // buffers[0]: x, buffers[1]: bool weight, buffers[2]: output; opaque: packed MFMDescriptor.
    void gpu_BF16(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len);
} // namespace gpu_ops
#endif
kernel.cu
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <iostream>
#include <chrono>
#include <cmath>
#include <algorithm> // 包含std::fill函数
#include "kernel_helpers.h"
#include "kernels.h"
using namespace std;
#define BLOCK_SIZE 16
namespace gpu_ops
{
    // Converts a failed CUDA status into std::runtime_error carrying cudaGetErrorString's text.
    void ThrowIfError(cudaError_t error)
    {
        if (error != cudaSuccess)
        {
            throw std::runtime_error(cudaGetErrorString(error));
        }
    }

    // Product of a dense matrix with a boolean matrix:
    //   C[row, col] = sum over k of A[row, k] for every k where B[k, col] is true.
    // Operands are dense, row-major and contiguous:
    //   A: ARows x ACols, B: ACols x BCols, C: ARows x BCols.
    // Launch: 2D grid/block; x indexes columns of C, y indexes rows of C. Threads outside
    // C return immediately, so any grid that covers C is valid.
    // The sum is accumulated in float and rounded to T once, so bf16 inputs do not lose
    // precision at every addition (and no bf16 arithmetic operators are needed).
    // Offsets are computed in size_t to avoid int overflow on large operands.
    template <typename T>
    __global__ void MFMKernel(const T *__restrict__ A, const bool *__restrict__ B, T *__restrict__ C, int ARows, int ACols, int BCols)
    {
        const int row = blockIdx.y * blockDim.y + threadIdx.y;
        const int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= ARows || col >= BCols)
        {
            return;
        }
        const T *aRow = A + static_cast<std::size_t>(row) * ACols;
        float sum = 0.0f;
        for (int k = 0; k < ACols; ++k)
        {
            if (B[static_cast<std::size_t>(k) * BCols + col])
            {
                sum += static_cast<float>(aRow[k]);
            }
        }
        C[static_cast<std::size_t>(row) * BCols + col] = static_cast<T>(sum);
    }

    // Batched MFMKernel: for every b in [0, batchSize), C[b] = MFM(A[b], B[b]) with
    //   A[b]: ARows x ACols, B[b]: ACols x BCols, C[b]: ARows x BCols,
    // each slice stored contiguously after the previous one.
    // Launch: x covers BCols, y covers ARows, z indexes the batch (gridDim.z >= batchSize).
    template <typename T>
    __global__ void BatchMFMKernel(const T *__restrict__ A, const bool *__restrict__ B, T *__restrict__ C, int batchSize, int ARows, int ACols, int BCols)
    {
        const int batch = blockIdx.z;
        const int row = blockIdx.y * blockDim.y + threadIdx.y;
        const int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (batch >= batchSize || row >= ARows || col >= BCols)
        {
            return;
        }
        const std::size_t b = static_cast<std::size_t>(batch);
        const std::size_t outRow = b * ARows + row; // row index into the flattened (batch*ARows) rows
        const T *aRow = A + outRow * ACols;
        const bool *bSlice = B + b * ACols * BCols;
        float sum = 0.0f;
        for (int k = 0; k < ACols; ++k)
        {
            if (bSlice[static_cast<std::size_t>(k) * BCols + col])
            {
                sum += static_cast<float>(aRow[k]);
            }
        }
        C[outRow * BCols + col] = static_cast<T>(sum);
    }

    // XLA custom-call entry point registered under the name "gpu_BF16".
    //   buffers[0]: x (bf16), buffers[1]: weight (bool), buffers[2]: y (bf16, output)
    //   opaque:     packed MFMDescriptor (K = numRows, N = numCols, M = batchSize)
    // miniBatchSize == 0 -> 2D: x is M x K, weight K x N, y M x N.
    // otherwise          -> 3D: x is M x miniBatchSize x K, weight M x K x N,
    //                           y M x miniBatchSize x N.
    // Kernels are enqueued on `stream`; launch errors (e.g. M above the grid-z limit in
    // the 3D case) are reported via ThrowIfError.
    void gpu_BF16(cudaStream_t stream, void **buffers, const char *opaque,
                  std::size_t opaque_len)
    {
        const MFMDescriptor &d = *UnpackDescriptor<MFMDescriptor>(opaque, opaque_len);
        // Only a bf16 kernel exists; reinterpreting any other dtype as bf16 yields garbage.
        if (d.f_type != ElementType::BF16)
        {
            throw std::runtime_error("gpu_BF16: descriptor element type must be BF16");
        }
        const int M = d.batchSize; // batchSize
        const int K = d.numRows;   // inputDim
        const int N = d.numCols;   // outputDim
        const int miniBatchSize = d.miniBatchSize;
        const bool batched = (miniBatchSize != 0);
        // Rows of each output matrix: miniBatchSize per slice in 3D, batchSize in 2D.
        const int outRows = batched ? miniBatchSize : M;
        // Empty output: nothing to write, and a zero-sized grid would be an invalid launch.
        if (M <= 0 || N <= 0 || outRows <= 0)
        {
            return;
        }

        const __nv_bfloat16 *x_dev = reinterpret_cast<const __nv_bfloat16 *>(buffers[0]);
        const bool *Weight_dev = reinterpret_cast<const bool *>(buffers[1]);
        __nv_bfloat16 *y_dev = reinterpret_cast<__nv_bfloat16 *>(buffers[2]);

        const dim3 block(BLOCK_SIZE, BLOCK_SIZE);
        // grid.x tiles the output columns; grid.y tiles the output rows of one slice.
        const unsigned gridX = (N + BLOCK_SIZE - 1) / BLOCK_SIZE;
        const unsigned gridY = (outRows + BLOCK_SIZE - 1) / BLOCK_SIZE;
        if (batched) // 3d
        {
            const dim3 grid(gridX, gridY, M); // grid.z: one layer per batch slice
            BatchMFMKernel<__nv_bfloat16><<<grid, block, 0, stream>>>(x_dev, Weight_dev, y_dev, M, miniBatchSize, K, N);
        }
        else // 2d
        {
            const dim3 grid(gridX, gridY);
            MFMKernel<__nv_bfloat16><<<grid, block, 0, stream>>>(x_dev, Weight_dev, y_dev, M, K, N);
        }
        ThrowIfError(cudaGetLastError());
    }
}
最核心的 CUDA 代码(kernel.cu)。
pybind11_kernel_helpers.h
// This header extends kernel_helpers.h with the pybind11 specific interface to
// serializing descriptors. It also adds a pybind11 function for wrapping our
// custom calls in a Python capsule. This is separate from kernel_helpers so
// that the CUDA code itself doesn't include pybind11. I don't think that this
// is strictly necessary, but they do it in jaxlib, so let's do it here too.
#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
#include <pybind11/pybind11.h>
#include "kernel_helpers.h"
namespace gpu_ops
{
    // Returns the raw bytes of `descriptor` as a Python `bytes` object, ready to be
    // passed to the custom call as its opaque payload.
    template <typename T>
    pybind11::bytes PackDescriptor(const T &descriptor)
    {
        const std::string raw = PackDescriptorAsString(descriptor);
        return pybind11::bytes(raw);
    }

    // Wraps a custom-call function pointer in a PyCapsule named "xla._CUSTOM_CALL_TARGET",
    // the form accepted when registering custom-call targets from Python.
    template <typename T>
    pybind11::capsule EncapsulateFunction(T *fn)
    {
        void *const target = bit_cast<void *>(fn);
        return pybind11::capsule(target, "xla._CUSTOM_CALL_TARGET");
    }
} // namespace gpu_ops
#endif
测试代码(Python)
import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
from jax.core import ShapedArray
from jax import random
import time
from functools import partial
from colorama import Fore, Back, Style
import inspect
from ec.ops.build import gpu_ops
import os
# Restrict this process to physical GPU 1.
# NOTE(review): this only takes effect if it runs before JAX initializes its GPU backend
# (which presumably happens lazily on first device use); setting it before `import jax`
# would be the safer placement — confirm.
os.environ.update(CUDA_VISIBLE_DEVICES='1')
def matrix_mul(a, b):
    """Reference matrix product computed with ``jnp.matmul``.

    Used (jitted as ``jit_matrix_mul`` in the ``__main__`` block) as the baseline
    the custom CUDA op is compared against.
    """
    product = jnp.matmul(a, b)
    return product
def test_2d():
    """Benchmark the reference ``jnp.matmul`` path on 2D inputs.

    Uses ``init_data(batch_size=32)`` (a (32, 3) bf16 batch and a (3, 10) bool
    matrix with init_data's defaults), prints a few samples, then times
    ``jit_matrix_mul`` (created in the ``__main__`` block).

    A warm-up call runs before the timer starts: the first call of a jitted
    function traces and compiles it, which would otherwise dominate the timing.
    """
    x_host, A_host = init_data(batch_size=32)
    for i in range(2):
        print(x_host[i,:10])
        print(A_host[i,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: exclude tracing/compilation from the measured time.
    jit_matrix_mul(x_host, A_host).block_until_ready()
    start_time = time.perf_counter()
    # block_until_ready(): JAX dispatch is asynchronous; wait for the result before stopping the timer.
    result = jit_matrix_mul(x_host,A_host).block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    print("The output results of random sampling: ")
    for i in range(2):
        print(result[i,:10])
def test_3d():
    """Benchmark the reference ``jnp.matmul`` path on batched (3D) inputs.

    Stacks two ``init_data(batch_size=32)`` results into a (2, 32, 3) bf16 array
    and a (2, 3, 10) bool array (init_data defaults), prints samples, then times
    ``jit_matrix_mul`` (created in the ``__main__`` block).

    NOTE(review): init_data is called with identical arguments each iteration, so
    both batch slices are identical whenever init_data is deterministic.

    A warm-up call runs before the timer starts so that tracing/compilation of
    the jitted function is not included in the reported time.
    """
    float_vector_batched_b = []
    bool_matrix_b = []
    for i in range(2):
        float_vector_batched, bool_matrix = init_data(batch_size=32)
        float_vector_batched_b.append(float_vector_batched[None, ...])
        bool_matrix_b.append(bool_matrix[None, ...])
    x_host = jnp.concatenate(float_vector_batched_b, axis=0)
    A_host = jnp.concatenate(bool_matrix_b, axis=0)
    for i in range(2):
        print(x_host[i,i+1,:10])
        print(A_host[i,i+1,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: exclude tracing/compilation from the measured time.
    jit_matrix_mul(x_host, A_host).block_until_ready()
    start_time = time.perf_counter()
    # block_until_ready(): JAX dispatch is asynchronous; wait for the result before stopping the timer.
    result = jit_matrix_mul(x_host,A_host).block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    print("The output results of random sampling: ")
    for i in range(2):
        print(result[i,i+1,:10])
def init_data(batch_size, matrix_rows=3, matrix_cols=10, seed=0):
    '''
    Initializes data for matrix-vector multiplication tests.

    Creates a random boolean matrix and a batch of bfloat16 vectors, all derived
    deterministically from ``seed``.

    Every matrix entry is drawn independently as True or False with probability
    0.5 each, so roughly (not exactly) half of the entries are True. The entries
    are generated in two chunks (the first ``total // 2`` entries, then the rest),
    each from its own PRNG subkey. Each vector is drawn from a standard normal
    distribution directly in bfloat16, one PRNG subkey per vector.

    Parameters:
    - batch_size (int): number of vectors in the batch; 0 yields an empty batch.
    - matrix_rows (int): rows of the boolean matrix; also the length of each vector.
    - matrix_cols (int): columns of the boolean matrix.
    - seed (int): seed passed to ``jax.random.PRNGKey`` (default 0).

    Returns:
    - A tuple ``(vectors, matrix)``:
      - vectors: bfloat16 array of shape (batch_size, matrix_rows).
      - matrix: bool array of shape (matrix_rows, matrix_cols).

    Note:
    - The vector length equals the number of matrix rows, so ``vectors @ matrix`` is well-formed.
    '''
    total_elements = matrix_rows * matrix_cols
    # Sizes of the two independently-keyed chunks (the names carry no value meaning).
    first_chunk_size = total_elements // 2
    second_chunk_size = total_elements - first_chunk_size
    values = jnp.array([0, 1], dtype=jnp.bool_)
    probabilities = jnp.array([0.5, 0.5])
    key = random.PRNGKey(seed)
    key, subkey = random.split(key)
    first_chunk = random.choice(subkey, values, shape=(first_chunk_size,), p=probabilities)
    key, subkey = random.split(key)
    second_chunk = random.choice(subkey, values, shape=(second_chunk_size,), p=probabilities)
    bool_matrix = jnp.concatenate((first_chunk, second_chunk)).reshape((matrix_rows, matrix_cols))
    if batch_size == 0:
        # jnp.concatenate rejects an empty list, so build the empty batch explicitly.
        return jnp.zeros((0, matrix_rows), dtype=jnp.bfloat16), bool_matrix
    float_vector_batched = []
    for _ in range(batch_size):
        key, subkey = random.split(key)
        float_vector = random.normal(subkey, (matrix_rows,), dtype=jnp.bfloat16)
        float_vector_batched.append(float_vector[None, :])
    return jnp.concatenate(float_vector_batched, axis=0), bool_matrix
def mvm_p_fwd(vector, matrix):
    """Apply the ``mvm_p`` primitive to ``vector`` and ``matrix``.

    ``mvm_p`` is a module-level name created in the ``__main__`` block, so this
    helper only works when the file runs as a script. The primitive is registered
    with ``multiple_results = True``, so the return value is a sequence of
    outputs (callers take element ``[0]``).
    """
    return mvm_p.bind(vector, matrix)
# Register every custom-call target exported by the compiled extension (e.g. "gpu_BF16")
# with XLA for the "gpu" platform, so the lowering can reference it by name.
_registrations = gpu_ops.get_mvm_registrations()
for _target_name, _capsule in _registrations.items():
    xla_client.register_custom_call_target(_target_name, _capsule, platform="gpu")
def default_layouts(*shapes):
    """Return one layout per shape, listing dimensions in descending order.

    For a shape of rank r the layout is ``range(r - 1, -1, -1)``, i.e. the
    minor-to-major encoding of a row-major (C-order) array as used by the
    custom call's ``operand_layouts``/``result_layouts``.
    """
    return [range(len(shape))[::-1] for shape in shapes]
def element_type_to_descriptor_type_mapping(element_type):
    """Translate an MLIR element type into the matching ``gpu_ops.ElementType``.

    Recognizes bf16 and f32; any other element type yields ``None``. Must be
    called while an MLIR context is active (the ``ir.*Type.get()`` calls need one),
    as is the case inside a lowering rule.
    """
    known_types = (
        (ir.BF16Type.get(), gpu_ops.ElementType.BF16),
        (ir.F32Type.get(), gpu_ops.ElementType.F32),
    )
    for mlir_type, descriptor_type in known_types:
        if element_type == mlir_type:
            return descriptor_type
    return None
def _mvm_p_fwd_cuda_lowering(ctx, vector,matrix):
    """Lower ``mvm_p`` to the ``gpu_BF16`` CUDA custom call.

    Supported layouts (row-major, see ``default_layouts``):
      * 2D: vector (batchSize, inputDim) x matrix (inputDim, outPutDim)
            -> (batchSize, outPutDim); miniBatchSize=0 in the descriptor selects
            the 2D kernel.
      * 3D: vector (batchSize, miniBatchSize, inputDim) x matrix
            (batchSize, inputDim, outPutDim) -> (batchSize, miniBatchSize, outPutDim).

    Returns the custom call's result list (the primitive has multiple results).

    Raises:
        NotImplementedError: the element type is not bf16 (only a bf16 kernel is
            registered), or a 3D input has miniBatchSize == 0 (indistinguishable from
            the 2D case in the descriptor, which would make the kernel write out of bounds).
        ValueError: unsupported ranks or mismatched shared dimensions.
    """
    vector_type = ir.RankedTensorType(vector.type)
    vector_shape = tuple(vector_type.shape)
    matrix_shape = tuple(ir.RankedTensorType(matrix.type).shape)
    element_type = vector_type.element_type

    descriptor_type = element_type_to_descriptor_type_mapping(element_type)
    if descriptor_type is None or descriptor_type != gpu_ops.ElementType.BF16:
        raise NotImplementedError(
            f"mvm_p only has a bfloat16 CUDA kernel (gpu_BF16); got element type {element_type}"
        )

    if len(vector_shape) == 2 and len(matrix_shape) == 2:
        batchSize, vectorDim = vector_shape
        inputDim, outPutDim = matrix_shape
        if vectorDim != inputDim:
            raise ValueError(f"mvm_p: incompatible shapes {vector_shape} and {matrix_shape}")
        miniBatchSize = 0  # 0 tells the C++ side to use the 2D kernel
        result_shape = (batchSize, outPutDim)
    elif len(vector_shape) == 3 and len(matrix_shape) == 3:
        vectorBatch, miniBatchSize, vectorDim = vector_shape
        batchSize, inputDim, outPutDim = matrix_shape
        if (vectorBatch, vectorDim) != (batchSize, inputDim):
            raise ValueError(f"mvm_p: incompatible shapes {vector_shape} and {matrix_shape}")
        if miniBatchSize == 0:
            raise NotImplementedError(
                "mvm_p: 3D inputs with an empty middle dimension are not supported"
            )
        result_shape = (batchSize, miniBatchSize, outPutDim)
    else:
        raise ValueError(
            f"mvm_p expects 2D x 2D or 3D x 3D operands, got {vector_shape} and {matrix_shape}"
        )

    opaque = gpu_ops.create_mvm_descriptor(
        inputDim,
        outPutDim,
        batchSize,
        miniBatchSize,
        descriptor_type,
    )
    return custom_call(
        b"gpu_BF16",
        result_types=[
            ir.RankedTensorType.get(result_shape, element_type),
        ],
        operands=[vector, matrix],
        backend_config=opaque,
        operand_layouts=default_layouts(vector_shape, matrix_shape),
        result_layouts=default_layouts(result_shape),
    ).results
def _mvm_p_fwd_abstract(vector,matrix):
w_dtype = dtypes.canonicalize_dtype(vector.dtype)
if len(vector.shape) == 2:
batchSize,inputDim = vector.shape
inputDim,outPutDim = matrix.shape
res_shape = (batchSize,outPutDim)
else:
batchSize,miniBatchSize,inputDim = vector.shape
batchSize,inputDim,outPutDim = matrix.shape
res_shape = (batchSize,miniBatchSize,outPutDim)
return (
ShapedArray(res_shape, w_dtype, named_shape=matrix.named_shape), # output
)
def test_2d_ops():
    """Benchmark and sanity-check the custom ``mvm_p`` op on 2D inputs.

    Uses ``init_data(batch_size=32)`` and runs ``mvm_p_fwd`` (the custom CUDA op).
    A warm-up call runs first (eager ``bind`` compiles on first use), and the
    timed call waits with ``block_until_ready`` because JAX dispatch is
    asynchronous. The result is also compared with ``matrix_mul`` (jnp.matmul),
    and the maximum absolute difference is printed.
    """
    x_host, A_host = init_data(batch_size=32)
    for i in range(2):
        print(x_host[i,:10])
        print(A_host[i,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: exclude compilation from the measured time.
    mvm_p_fwd(x_host, A_host)[0].block_until_ready()
    start_time = time.perf_counter()
    result = mvm_p_fwd(x_host,A_host)[0].block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    # Correctness check against the jnp.matmul reference (compared in float32).
    reference = matrix_mul(x_host, A_host)
    max_abs_diff = jnp.max(jnp.abs(result.astype(jnp.float32) - reference.astype(jnp.float32)))
    print("max |custom - jnp.matmul|: ", max_abs_diff)
    print(Style.RESET_ALL)
    for i in range(2):
        print(result[i,:10])
def test_3d_ops():
    """Benchmark and sanity-check the custom ``mvm_p`` op on batched (3D) inputs.

    Stacks two ``init_data(batch_size=32)`` results into 3D operands and runs
    ``mvm_p_fwd`` (the custom CUDA op). A warm-up call runs first, and the timed
    call waits with ``block_until_ready`` because JAX dispatch is asynchronous.
    The result is compared with ``matrix_mul`` (batched jnp.matmul), and the
    maximum absolute difference is printed.
    """
    float_vector_batched_b = []
    bool_matrix_b = []
    for i in range(2):
        float_vector_batched, bool_matrix = init_data(batch_size=32)
        float_vector_batched_b.append(float_vector_batched[None, ...])
        bool_matrix_b.append(bool_matrix[None, ...])
    x_host = jnp.concatenate(float_vector_batched_b, axis=0)
    A_host = jnp.concatenate(bool_matrix_b, axis=0)
    for i in range(2):
        print(x_host[i,i+1,:10])
        print(A_host[i,i+1,:10])
    print("intput_vector.shape: ",x_host.shape)
    print("intput_matrix.shape: ",A_host.shape)
    # Warm-up: exclude compilation from the measured time.
    mvm_p_fwd(x_host, A_host)[0].block_until_ready()
    start_time = time.perf_counter()
    result = mvm_p_fwd(x_host,A_host)[0].block_until_ready()
    end_time = time.perf_counter()
    execution_time = (end_time - start_time) * 1000
    print("result.shape: ",result.shape)
    print( Fore.RED, inspect.currentframe().f_code.co_name," execution_time: ",execution_time , " ms",Style.RESET_ALL)
    # Correctness check against the jnp.matmul reference (compared in float32).
    reference = matrix_mul(x_host, A_host)
    max_abs_diff = jnp.max(jnp.abs(result.astype(jnp.float32) - reference.astype(jnp.float32)))
    print("max |custom - jnp.matmul|: ", max_abs_diff)
    print("The output results of random sampling: ")
    for i in range(2):
        print(result[i,i+1,:10])
def print_section_header(func):
    """Call ``func()`` between a "start" and an "end" banner naming it.

    Despite the name this is not a decorator: it runs ``func`` immediately.
    """
    bar = '=' * 10
    name = func.__name__
    print(f"\n{bar} {name} start {bar}\n")
    func()
    print(f"\n{bar} {name} end {bar}\n")
if __name__ == '__main__':
    # Define the custom primitive: eager calls go through xla.apply_primitive,
    # shape inference through _mvm_p_fwd_abstract, and GPU compilation through
    # the gpu_BF16 custom-call lowering.
    mvm_p = core.Primitive("mvm_p_fwd")
    mvm_p.multiple_results = True
    mvm_p.def_impl(partial(xla.apply_primitive, mvm_p))
    mvm_p.def_abstract_eval(_mvm_p_fwd_abstract)
    mlir.register_lowering(mvm_p, _mvm_p_fwd_cuda_lowering, platform="gpu")

    # jnp.matmul baseline used by test_2d / test_3d.
    jit_matrix_mul = jax.jit(matrix_mul)

    for section in (test_2d, test_2d_ops, test_3d, test_3d_ops):
        print_section_header(section)