Skip to content

Commit

Permalink
add POC support for cuBLAS
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasstadler committed Mar 11, 2020
1 parent ff8fbc4 commit f191fe2
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.graalvm.options.OptionKey;

import com.nvidia.grcuda.cublas.CUBLASRegistry;
import com.nvidia.grcuda.cuml.CUMLRegistry;
import com.nvidia.grcuda.functions.BindFunction;
import com.nvidia.grcuda.functions.BindKernelFunction;
Expand Down Expand Up @@ -86,6 +87,11 @@ public GrCUDAContext(Env env) {
namespace.addNamespace(ml);
new CUMLRegistry(this).registerCUMLFunctions(ml);
}
if (this.getOption(GrCUDAOptions.CuBLASEnabled)) {
Namespace blas = new Namespace(CUBLASRegistry.NAMESPACE);
namespace.addNamespace(blas);
new CUBLASRegistry(this).registerCUBLASFunctions(blas);
}
this.rootNamespace = namespace;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.graalvm.options.OptionKey;
import org.graalvm.options.OptionStability;

import com.nvidia.grcuda.cublas.CUBLASRegistry;
import com.nvidia.grcuda.cuml.CUMLRegistry;
import com.oracle.truffle.api.Option;

Expand All @@ -41,6 +42,12 @@ private GrCUDAOptions() {
// no instances
}

@Option(category = OptionCategory.USER, help = "Enable cuBLAS support.", stability = OptionStability.STABLE) //
public static final OptionKey<Boolean> CuBLASEnabled = new OptionKey<>(true);

@Option(category = OptionCategory.USER, help = "Set the location of the cublas library.", stability = OptionStability.STABLE) //
public static final OptionKey<String> CuBLASLibrary = new OptionKey<>(CUBLASRegistry.DEFAULT_LIBRARY);

@Option(category = OptionCategory.USER, help = "Enable cuML support.", stability = OptionStability.STABLE) //
public static final OptionKey<Boolean> CuMLEnabled = new OptionKey<>(true);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.nvidia.grcuda.cublas;

import static com.nvidia.grcuda.functions.Function.INTEROP;
import static com.nvidia.grcuda.functions.Function.expectLong;

import java.util.ArrayList;

import com.nvidia.grcuda.GrCUDAContext;
import com.nvidia.grcuda.GrCUDAException;
import com.nvidia.grcuda.GrCUDAInternalException;
import com.nvidia.grcuda.GrCUDAOptions;
import com.nvidia.grcuda.Namespace;
import com.nvidia.grcuda.functions.ExternalFunctionFactory;
import com.nvidia.grcuda.functions.Function;
import com.nvidia.grcuda.gpu.UnsafeHelper;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.interop.ArityException;
import com.oracle.truffle.api.interop.InteropException;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;

/**
 * Registers wrappers for a set of cuBLAS functions in a GrCUDA {@link Namespace}.
 *
 * <p>
 * The registered wrappers hide the cuBLAS handle from the caller: the handle is created lazily on
 * the first invocation of any wrapper, is prepended to the caller's arguments on every call, and is
 * destroyed by a disposable that is registered with the {@link GrCUDAContext} once the handle
 * exists.
 */
public class CUBLASRegistry {
    public static final String DEFAULT_LIBRARY = "libcublas.so";
    // NOTE(review): the opening parenthesis in this hint is never closed within the string itself;
    // confirm whether the code that consumes the hint appends the closing one.
    public static final String DEFAULT_LIBRARY_HINT = " (CuBLAS library location can be set via the --grcuda.CuBLASLibrary= option. " +
                    "CuBLAS support can be disabled via --grcuda.CuBLASEnabled=false.";
    public static final String NAMESPACE = "BLAS";

    private final GrCUDAContext context;
    private final String libraryPath;

    @CompilationFinal private TruffleObject cublasCreateFunction;
    @CompilationFinal private TruffleObject cublasDestroyFunction;
    @CompilationFinal private TruffleObject cublasCreateFunctionNFI;
    @CompilationFinal private TruffleObject cublasDestroyFunctionNFI;

    /** The cuBLAS handle, or {@code null} while no handle has been created (or after shutdown). */
    private Long cublasHandle = null;

    /**
     * Creates a registry that reads the cuBLAS library path from the
     * {@link GrCUDAOptions#CuBLASLibrary} option of the given context.
     *
     * @param context the GrCUDA context that supplies options, the CUDA runtime and disposal hooks
     */
    public CUBLASRegistry(GrCUDAContext context) {
        this.context = context;
        libraryPath = context.getOption(GrCUDAOptions.CuBLASLibrary);
    }

    /**
     * Creates the cuBLAS handle on first use.
     *
     * <p>
     * While {@link #cublasHandle} is {@code null} this builds the NFI functions and their wrappers
     * for handle creation and destruction, invokes the creation wrapper, stores the returned handle
     * and registers {@link #cuBLASShutdown()} as a disposable on the context. Once a handle is
     * present the method does nothing.
     *
     * @throws GrCUDAInternalException if an interop call fails
     */
    private void ensureInitialized() {
        if (cublasHandle == null) {
            CompilerDirectives.transferToInterpreterAndInvalidate();

            // create NFI function objects for handle creation and destruction

            cublasCreateFunctionNFI = CUBLAS_CUBLASCREATE.makeFunction(context.getCUDARuntime(), libraryPath, DEFAULT_LIBRARY_HINT);
            cublasDestroyFunctionNFI = CUBLAS_CUBLASDESTROY.makeFunction(context.getCUDARuntime(), libraryPath, DEFAULT_LIBRARY_HINT);

            // create wrapper for cublasCreate: the native function takes a pointer to the handle
            // and returns a status code; the wrapper takes no arguments and returns the handle.
            // cublasCreate() -> long
            cublasCreateFunction = new Function(CUBLAS_CUBLASCREATE.getName()) {
                @Override
                @TruffleBoundary
                public Object call(Object[] arguments) throws ArityException {
                    checkArgumentLength(arguments, 0);
                    // the native call writes the new handle into this temporary 64-bit cell
                    try (UnsafeHelper.Integer64Object handle = UnsafeHelper.createInteger64Object()) {
                        Object result = INTEROP.execute(cublasCreateFunctionNFI, handle.getAddress());
                        checkCUBLASReturnCode(result, "cublasCreate");
                        return handle.getValue();
                    } catch (InteropException e) {
                        throw new GrCUDAInternalException(e);
                    }
                }
            };

            // create wrapper for cublasDestroy: the native function takes the handle and returns
            // a status code, which the wrapper checks and then passes through to the caller.
            // cublasDestroy(long handle) -> status code
            cublasDestroyFunction = new Function(CUBLAS_CUBLASDESTROY.getName()) {
                @Override
                @TruffleBoundary
                public Object call(Object[] arguments) throws ArityException, UnsupportedTypeException {
                    checkArgumentLength(arguments, 1);
                    long handle = expectLong(arguments[0]);
                    try {
                        Object result = INTEROP.execute(cublasDestroyFunctionNFI, handle);
                        checkCUBLASReturnCode(result, "cublasDestroy");
                        return result;
                    } catch (InteropException e) {
                        throw new GrCUDAInternalException(e);
                    }
                }
            };

            try {
                Object result = INTEROP.execute(cublasCreateFunction);
                cublasHandle = expectLong(result);

                // only register the shutdown hook once a handle actually exists
                context.addDisposable(this::cuBLASShutdown);
            } catch (InteropException e) {
                throw new GrCUDAInternalException(e);
            }
        }
    }

    /**
     * Destroys the cuBLAS handle if one exists and resets {@link #cublasHandle} to {@code null}.
     *
     * @throws GrCUDAInternalException if the interop call fails
     */
    private void cuBLASShutdown() {
        CompilerAsserts.neverPartOfCompilation();
        if (cublasHandle != null) {
            try {
                Object result = InteropLibrary.getFactory().getUncached().execute(cublasDestroyFunction, cublasHandle);
                checkCUBLASReturnCode(result, CUBLAS_CUBLASDESTROY.getName());
                cublasHandle = null;
            } catch (InteropException e) {
                throw new GrCUDAInternalException(e);
            }
        }
    }

    /**
     * Adds one wrapper function per entry of {@link #functions} to the given namespace.
     *
     * <p>
     * Each wrapper makes sure the cuBLAS handle exists, passes it as the first argument followed by
     * the caller's arguments, executes the lazily created NFI function, calls
     * {@code cudaDeviceSynchronize()} on the CUDA runtime and finally checks the returned status
     * code.
     *
     * @param namespace the namespace that receives the wrapper functions
     */
    public void registerCUBLASFunctions(Namespace namespace) {
        // Create function wrappers (decorators for all functions except handle con- and
        // destruction)
        for (ExternalFunctionFactory factory : functions) {
            final Function wrapperFunction = new Function(factory.getName()) {

                // created on the first call of this wrapper
                private Function nfiFunction;

                @Override
                @TruffleBoundary
                protected Object call(Object[] arguments) {
                    ensureInitialized();

                    // slot 0 carries the handle, the caller's arguments follow unchanged
                    Object[] argsWithHandle = new Object[arguments.length + 1];
                    System.arraycopy(arguments, 0, argsWithHandle, 1, arguments.length);
                    argsWithHandle[0] = cublasHandle;

                    try {
                        if (nfiFunction == null) {
                            CompilerDirectives.transferToInterpreterAndInvalidate();
                            nfiFunction = factory.makeFunction(context.getCUDARuntime(), libraryPath, DEFAULT_LIBRARY_HINT);
                        }
                        Object result = INTEROP.execute(nfiFunction, argsWithHandle);
                        context.getCUDARuntime().cudaDeviceSynchronize();
                        checkCUBLASReturnCode(result, nfiFunction.getName());
                        return result;
                    } catch (InteropException e) {
                        throw new GrCUDAInternalException(e);
                    }
                }
            };
            namespace.addFunction(wrapperFunction);
        }
    }

    /**
     * Verifies that a native cuBLAS call reported success.
     *
     * @param result the value returned by the native call; must be convertible to an {@code int}
     * @param function the name(s) of the invoked function, used in error reporting
     * @throws GrCUDAInternalException if {@code result} cannot be read as an {@code int}
     * @throws GrCUDAException if the status code is not {@code 0}
     */
    private static void checkCUBLASReturnCode(Object result, String... function) {
        CompilerAsserts.neverPartOfCompilation();
        int returnCode;
        try {
            returnCode = InteropLibrary.getFactory().getUncached().asInt(result);
        } catch (UnsupportedMessageException e) {
            // 'function' is a varargs array: join it explicitly, concatenating the array itself
            // would only print its identity string instead of the function name
            throw new GrCUDAInternalException("expected return code as Integer object in " + String.join(", ", function) + ", got " + result.getClass().getName());
        }
        if (returnCode != 0) {
            throw new GrCUDAException(returnCode, cublasReturnCodeToString(returnCode), function);
        }
    }

    /**
     * Maps a cuBLAS status code to its symbolic name.
     *
     * @param returnCode the numeric status code
     * @return the symbolic name, or {@code "unknown error code: <code>"} for unmapped values
     */
    private static String cublasReturnCodeToString(int returnCode) {
        switch (returnCode) {
            case 0:
                return "CUBLAS_STATUS_SUCCESS";
            case 1:
                return "CUBLAS_STATUS_NOT_INITIALIZED";
            case 3:
                return "CUBLAS_STATUS_ALLOC_FAILED";
            case 7:
                return "CUBLAS_STATUS_INVALID_VALUE";
            case 8:
                return "CUBLAS_STATUS_ARCH_MISMATCH";
            case 11:
                return "CUBLAS_STATUS_MAPPING_ERROR";
            case 13:
                return "CUBLAS_STATUS_EXECUTION_FAILED";
            case 14:
                return "CUBLAS_STATUS_INTERNAL_ERROR";
            case 15:
                return "CUBLAS_STATUS_NOT_SUPPORTED";
            case 16:
                return "CUBLAS_STATUS_LICENSE_ERROR";
            default:
                return "unknown error code: " + returnCode;
        }
    }

    // Factories for the handle management functions; these are wrapped separately in
    // ensureInitialized() and are not part of the 'functions' list below.
    private static final ExternalFunctionFactory CUBLAS_CUBLASCREATE = new ExternalFunctionFactory("cublasCreate", "cublasCreate_v2", "(pointer): sint32");
    private static final ExternalFunctionFactory CUBLAS_CUBLASDESTROY = new ExternalFunctionFactory("cublasDestroy", "cublasDestroy_v2", "(sint64): sint32");

    /** Factories for all functions that are exposed through {@link #registerCUBLASFunctions}. */
    private static final ArrayList<ExternalFunctionFactory> functions = new ArrayList<>();

    static {
        // one axpy and one gemm entry per type prefix; the leading sint64 of each signature is
        // the cuBLAS handle that the wrappers prepend to the caller's arguments
        for (char type : new char[]{'S', 'D', 'C', 'Z'}) {
            functions.add(new ExternalFunctionFactory("cublas" + type + "axpy", "cublas" + type + "axpy_v2",
                            "(sint64, sint32, pointer, pointer, sint32, pointer, sint32): sint32"));
            functions.add(new ExternalFunctionFactory("cublas" + type + "gemm", "cublas" + type + "gemm_v2",
                            "(sint64, sint32, sint32, sint32, sint32, sint32, pointer, pointer, sint32, pointer, sint32, pointer, pointer, sint32): sint32"));
        }
    }
}

0 comments on commit f191fe2

Please sign in to comment.