
Commit

Add support for arbitrary group sizes in gemm_grouped_batched! (Jul…
lpawela committed Jul 19, 2024
1 parent 1bdbb86 commit beccab1
Showing 2 changed files with 116 additions and 5 deletions.
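Illustrative usage of the new interface from the caller's side. This is a sketch, not part of the diff: the nested-vector layout and the per-group transA/transB/alpha/beta arguments follow the wrapper added below, while the group count, matrix sizes, and variable names are made up.

using CUDA

# Three groups of sizes 1, 2 and 3; within a group all matrices share their
# dimensions, but different groups may use different dimensions.
group_sizes = [1, 2, 3]
num_groups  = length(group_sizes)

transA = ['N' for _ in 1:num_groups]   # one op per group
transB = ['N' for _ in 1:num_groups]
alpha  = ones(Float32, num_groups)     # one scalar per group
beta   = zeros(Float32, num_groups)

A = [[CUDA.rand(Float32, 4i, 3i) for _ in 1:group_sizes[i]] for i in 1:num_groups]
B = [[CUDA.rand(Float32, 3i, 2i) for _ in 1:group_sizes[i]] for i in 1:num_groups]
C = [[CUDA.zeros(Float32, 4i, 2i) for _ in 1:group_sizes[i]] for i in 1:num_groups]

# In-place: C[i][j] = alpha[i] * A[i][j] * B[i][j] + beta[i] * C[i][j]
CUDA.CUBLAS.gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)

# Allocating variant: alpha defaults to one and beta to zero
C2 = CUDA.CUBLAS.gemm_grouped_batched(transA, transB, A, B)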
83 changes: 78 additions & 5 deletions lib/cublas/wrappers.jl
@@ -1214,6 +1214,63 @@ end
## (GE) general matrix-matrix multiplication grouped batched
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
                                (:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
    @eval begin
        function gemm_grouped_batched!(transA::Vector{Char},
                                       transB::Vector{Char},
                                       alpha::Vector{$elty},
                                       A::Vector{<:Vector{<:StridedCuMatrix{$elty}}},
                                       B::Vector{<:Vector{<:StridedCuMatrix{$elty}}},
                                       beta::Vector{$elty},
                                       C::Vector{<:Vector{<:StridedCuMatrix{$elty}}})

            if length(A) != length(B) || length(A) != length(C)
                throw(DimensionMismatch("A, B and C must contain the same number of groups"))
            end
            group_count = length(A)
            for i = 1:group_count
                if length(A[i]) != length(B[i]) || length(A[i]) != length(C[i])
                    throw(DimensionMismatch("A, B and C must contain the same number of matrices"))
                end
            end
            group_size = length.(A)

            for i = 1:group_count
                m = size(A[i][1], transA[i] == 'N' ? 1 : 2)
                k = size(A[i][1], transA[i] == 'N' ? 2 : 1)
                n = size(B[i][1], transB[i] == 'N' ? 2 : 1)
                if m != size(C[i][1], 1) || n != size(C[i][1], 2) || k != size(B[i][1], transB[i] == 'N' ? 1 : 2)
                    throw(DimensionMismatch(""))
                end
            end

            transa = convert.(cublasOperation_t, transA)
            transb = convert.(cublasOperation_t, transB)
            m = [size(A[i][1], transA[i] == 'N' ? 1 : 2) for i = 1:group_count]
            k = [size(A[i][1], transA[i] == 'N' ? 2 : 1) for i = 1:group_count]
            n = [size(B[i][1], transB[i] == 'N' ? 2 : 1) for i = 1:group_count]
            lda = [max(1, stride(A[i][1], 2)) for i = 1:group_count]
            ldb = [max(1, stride(B[i][1], 2)) for i = 1:group_count]
            ldc = [max(1, stride(C[i][1], 2)) for i = 1:group_count]
            Aptrs = unsafe_batch(reduce(vcat, A))
            Bptrs = unsafe_batch(reduce(vcat, B))
            Cptrs = unsafe_batch(reduce(vcat, C))

            if CUBLAS.version() >= v"12.0"
                $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
                          Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
            else
                $fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
                       Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
            end
            unsafe_free!(Cptrs)
            unsafe_free!(Bptrs)
            unsafe_free!(Aptrs)

            C
        end
    end

    # Group size hardcoded to one
    @eval begin
        function gemm_grouped_batched!(transA::Vector{Char},
                                       transB::Vector{Char},
@@ -1260,24 +1317,40 @@ for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGrouped
            unsafe_free!(Cptrs)
            unsafe_free!(Bptrs)
            unsafe_free!(Aptrs)

            C
        end
    end
end

function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char}, alpha::Vector{T},
                              A::Vector{<:Vector{<:StridedCuMatrix{T}}}, B::Vector{<:Vector{<:StridedCuMatrix{T}}}) where T
    num_groups = length(A)
    group_sizes = length.(A)
    beta = [zero(T) for i = 1:num_groups]
    C = [[similar(B[i][j], (size(A[i][j], transA[i] == 'N' ? 1 : 2), size(B[i][j], transB[i] == 'N' ? 2 : 1))) for j in 1:group_sizes[i]] for i in 1:num_groups]
    gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)
end

function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char},
                              A::Vector{<:Vector{<:StridedCuMatrix{T}}}, B::Vector{<:Vector{<:StridedCuMatrix{T}}}) where T
    alpha = [one(T) for i = 1:length(transA)]
    gemm_grouped_batched(transA, transB, alpha, A, B)
end

# Group size hardcoded to one
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char}, alpha::Vector{T},
                              A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
    beta = [zero(T) for i = 1:length(transA)]
    C = CuMatrix{T}[similar(B[i], (size(A[i], transA[i] == 'N' ? 1 : 2), size(B[i], transB[i] == 'N' ? 2 : 1))) for i in 1:length(A)]
    gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)
end

function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char},
                              A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
    alpha = [one(T) for i = 1:length(transA)]
    gemm_grouped_batched(transA, transB, alpha, A, B)
end

## (GE) general matrix-matrix multiplication batched
for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :Float64),
                                (:cublasSgemmBatched, :cublasSgemmBatched_64, :Float32),
38 changes: 38 additions & 0 deletions test/libraries/cublas.jl
@@ -1711,6 +1711,44 @@ end
end
end

if CUDA.CUBLAS.version() >= v"12.4.2"
    @testset "elty = $elty" for elty in [Float32, Float64]
        num_groups = 10
        group_sizes = collect(1:num_groups)
        transA = ['N' for i in 1:num_groups]
        transB = ['N' for i in 1:num_groups]
        alpha = rand(elty, num_groups)
        beta = rand(elty, num_groups)
        # generate matrices
        bA = [[rand(elty, 3*i, 2*i) for j in 1:group_sizes[i]] for i in 1:num_groups]
        bB = [[rand(elty, 2*i, 5*i) for j in 1:group_sizes[i]] for i in 1:num_groups]
        bC = [[rand(elty, 3*i, 5*i) for j in 1:group_sizes[i]] for i in 1:num_groups]
        # move to device
        bd_A = [[CuArray(bA[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
        bd_B = [[CuArray(bB[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
        bd_C = [[CuArray(bC[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]

        @testset "gemm_grouped_batched!" begin
            # C = (alpha*A)*B + beta*C
            CUBLAS.gemm_grouped_batched!(transA, transB, alpha, bd_A, bd_B, beta, bd_C)
            for i in 1:num_groups, j in 1:group_sizes[i]
                bC[i][j] = alpha[i] * bA[i][j] * bB[i][j] + beta[i] * bC[i][j]
                h_C = Array(bd_C[i][j])
                @test bC[i][j] ≈ h_C
            end
        end

        @testset "gemm_grouped_batched" begin
            bd_C = CUBLAS.gemm_grouped_batched(transA, transB, bd_A, bd_B)
            for i in 1:num_groups, j in 1:group_sizes[i]
                bC[i][j] = bA[i][j] * bB[i][j]
                h_C = Array(bd_C[i][j])
                @test bC[i][j] ≈ h_C
            end
        end
    end
end

# Group size hardcoded to one
if CUDA.CUBLAS.version() >= v"12.4.2"
    @testset "elty = $elty" for elty in [Float32, Float64]

