Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert Interop to Truffle Libraries #1

Merged
merged 3 commits into from
Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions projects/com.nvidia.grcuda.parser.antlr/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -30,23 +31,18 @@ QUIETLY$(MX_VERBOSE) = @
SOURCE=${PARSER_PATH}/GrCUDA.g4
STAMP=${PARSER_PATH}/GrCUDA.g4.stamp

GEN_FILES=${OUTPUT_PATH}/GrCUDALexer.java \
${OUTPUT_PATH}/GrCUDAParser.java \
${OUTPUT_PATH}/GrCUDABaseVisitor.java \
${OUTPUT_PATH}/GrCUDABaseListener.java \
${OUTPUT_PATH}/GrCUDAVisitor.java \
${OUTPUT_PATH}/GrCUDAListener.java
GEN_FILES=${OUTPUT_PATH}/GrCUDALexer.java ${OUTPUT_PATH}/GrCUDAParser.java

.PHONY: default clean
default: ${GEN_FILES}

${STAMP}: ${SOURCE}
$(QUIETLY) touch $@
$(QUIETLY) ${JAVA_HOME}/bin/java -cp ${ANTLR_JAR} org.antlr.v4.Tool -visitor -package ${PARSER_PKG} -o ${OUTPUT_PATH} ${SOURCE}
$(QUIETLY) ${JAVA_HOME}/bin/java -cp ${ANTLR_JAR} org.antlr.v4.Tool -no-visitor -no-listener -package ${PARSER_PKG} -o ${OUTPUT_PATH} ${SOURCE}


${OUTPUT_PATH}/%.java: ${STAMP}
$(QUIETLY) python ${POSTPROCESSOR} $@
$(QUIETLY) python3 ${POSTPROCESSOR} ${GEN_FILES}

clean:
rm -f ${TARGETS}
18 changes: 9 additions & 9 deletions projects/com.nvidia.grcuda.parser.antlr/postprocessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -87,17 +88,16 @@ def transform(lines):


def main():
if len(sys.argv) != 2:
print('file argument missing', file=sys.stderr)
if len(sys.argv) < 2:
print('file arguments missing', file=sys.stderr)
sys.exit(1)

file_name = sys.argv[1]
with open(file_name, 'rt') as in_lines:
out_lines = transform(in_lines)

with open(file_name, 'w') as out_file:
out_file.write(out_lines)

file_names = sys.argv[1:]
for file_name in file_names:
with open(file_name, 'rt') as in_lines:
out_lines = transform(in_lines)
with open(file_name, 'w') as out_file:
out_file.write(out_lines)

if __name__ == '__main__':
main()
236 changes: 150 additions & 86 deletions projects/com.nvidia.grcuda/src/com/nvidia/grcuda/DeviceArray.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -29,13 +30,62 @@

import com.nvidia.grcuda.gpu.CUDARuntime;
import com.nvidia.grcuda.gpu.LittleEndianNativeArrayView;
import com.oracle.truffle.api.interop.ForeignAccess;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.interop.UnknownIdentifierException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.profiles.ValueProfile;

@ExportLibrary(InteropLibrary.class)
public final class DeviceArray implements TruffleObject {

private static final MemberSet PUBLIC_MEMBERS = new MemberSet();
private static final MemberSet MEMBERS = new MemberSet("pointer");

@ExportLibrary(InteropLibrary.class)
public static final class MemberSet implements TruffleObject {

@CompilationFinal(dimensions = 1) private final String[] values;

public MemberSet(String... values) {
this.values = values;
}

@ExportMessage
@SuppressWarnings("static-method")
public boolean hasArrayElements() {
return true;
}

@ExportMessage
public long getArraySize() {
return values.length;
}

@ExportMessage
public boolean isArrayElementReadable(long index) {
return index >= 0 && index < values.length;
}

@ExportMessage
public Object readArrayElement(long index) throws InvalidArrayIndexException {
if ((index < 0) || (index >= values.length)) {
CompilerDirectives.transferToInterpreter();
throw InvalidArrayIndexException.create(index);
}
return values[(int) index];
}
}

private final CUDARuntime runtime;

/** Data type of elements stored in the array. */
Expand Down Expand Up @@ -68,24 +118,55 @@ public long getPointer() {
return nativeView.getStartAddress();
}

public boolean isIndexValid(long index) {
return (index >= 0) && (index < numElements);
@Override
public String toString() {
return "DeviceArray(elementType=" + elementType + ", numElements=" + numElements + ", nativeView=" + nativeView + ')';
}

@Override
public ForeignAccess getForeignAccess() {
return DeviceArrayForeign.ACCESS;
protected void finalize() {
runtime.cudaFree(nativeView);
}

// Implementation of InteropLibrary

@ExportMessage
@SuppressWarnings("static-method")
boolean hasArrayElements() {
return true;
}

long getSizeElements() {
@ExportMessage
long getArraySize() {
return numElements;
}

Number readElement(long index) {
@ExportMessage
boolean isArrayElementReadable(long index) {
return index >= 0 && index < numElements;
}

@ExportMessage
boolean isArrayElementModifiable(long index) {
return index >= 0 && index < numElements;
}

@SuppressWarnings("static-method")
@ExportMessage
boolean isArrayElementInsertable(@SuppressWarnings("unused") long index) {
return false;
}

@ExportMessage
Object readArrayElement(long index,
@Shared("elementType") @Cached("createIdentityProfile()") ValueProfile elementTypeProfile) throws InvalidArrayIndexException {

if ((index < 0) || (index >= numElements)) {
throw new ArrayIndexOutOfBoundsException();
CompilerDirectives.transferToInterpreter();
throw InvalidArrayIndexException.create(index);

}
switch (elementType) {
switch (elementTypeProfile.profile(elementType)) {
case BYTE:
case CHAR:
return nativeView.getByte(index);
Expand All @@ -103,99 +184,82 @@ Number readElement(long index) {
return null;
}

void writeElement(long index, Number value) {
if ((index < 0) || (index >= numElements)) {
throw new ArrayIndexOutOfBoundsException();
}
switch (elementType) {
case BYTE:
case CHAR:
nativeView.setByte(index, value.byteValue());
break;
case SHORT:
nativeView.setShort(index, value.shortValue());
break;
case INT:
nativeView.setInt(index, value.intValue());
break;
case LONG:
nativeView.setLong(index, value.longValue());
break;
case FLOAT:
nativeView.setFloat(index, value.floatValue());
break;
case DOUBLE:
nativeView.setDouble(index, value.doubleValue());
break;
}
}

@Override
public String toString() {
return "DeviceArray(elementType=" + elementType +
", numElements=" + numElements + ", nativeView=" + nativeView + ')';
}

@Override
protected void finalize() throws Throwable {
runtime.cudaFree(nativeView);
super.finalize();
}
@ExportMessage
void writeArrayElement(long index, Object value,
@CachedLibrary(limit = "3") InteropLibrary valueLibrary,
@Shared("elementType") @Cached("createIdentityProfile()") ValueProfile elementTypeProfile) throws UnsupportedTypeException, InvalidArrayIndexException {

public static final class ReadElementNode extends Node {
private final ValueProfile profile = ValueProfile.createIdentityProfile();

public Number readElement(DeviceArray deviceArray, long index) {
if ((index < 0) || (index >= deviceArray.numElements)) {
throw new ArrayIndexOutOfBoundsException();
}
switch (profile.profile(deviceArray.elementType)) {
case BYTE:
case CHAR:
return deviceArray.nativeView.getByte(index);
case SHORT:
return deviceArray.nativeView.getShort(index);
case INT:
return deviceArray.nativeView.getInt(index);
case LONG:
return deviceArray.nativeView.getLong(index);
case FLOAT:
return deviceArray.nativeView.getFloat(index);
case DOUBLE:
return deviceArray.nativeView.getDouble(index);
}
return null;
if ((index < 0) || (index >= numElements)) {
CompilerDirectives.transferToInterpreter();
throw InvalidArrayIndexException.create(index);
}
}
try {
switch (elementTypeProfile.profile(elementType)) {

public static final class WriteElementNode extends Node {
private final ValueProfile profile = ValueProfile.createIdentityProfile();

public void writeElement(DeviceArray deviceArray, long index, Number value) {
if ((index < 0) || (index >= deviceArray.numElements)) {
throw new ArrayIndexOutOfBoundsException();
}
switch (profile.profile(deviceArray.elementType)) {
case BYTE:
case CHAR:
deviceArray.nativeView.setByte(index, value.byteValue());
nativeView.setByte(index, valueLibrary.asByte(value));
break;
case SHORT:
deviceArray.nativeView.setShort(index, value.shortValue());
nativeView.setShort(index, valueLibrary.asShort(value));
break;
case INT:
deviceArray.nativeView.setInt(index, value.intValue());
nativeView.setInt(index, valueLibrary.asInt(value));
break;
case LONG:
deviceArray.nativeView.setLong(index, value.longValue());
nativeView.setLong(index, valueLibrary.asLong(value));
break;
case FLOAT:
deviceArray.nativeView.setFloat(index, value.floatValue());
// going via "double" to allow floats to be initialized with doubles
nativeView.setFloat(index, (float) valueLibrary.asDouble(value));
break;
case DOUBLE:
deviceArray.nativeView.setDouble(index, value.doubleValue());
nativeView.setDouble(index, valueLibrary.asDouble(value));
break;
}
} catch (UnsupportedMessageException e) {
CompilerDirectives.transferToInterpreter();
throw UnsupportedTypeException.create(new Object[]{value}, "value cannot be coerced to " + elementType);
}
}

@ExportMessage
@SuppressWarnings("static-method")
boolean hasMembers() {
return true;
}

@ExportMessage
@SuppressWarnings("static-method")
Object getMembers(boolean includeInternal) {
return includeInternal ? MEMBERS : PUBLIC_MEMBERS;
}

@ExportMessage
@SuppressWarnings("static-method")
boolean isMemberReadable(String member,
@Shared("member") @Cached("createIdentityProfile()") ValueProfile memberProfile) {
return "pointer".equals(memberProfile.profile(member));
}

@ExportMessage
Object readMember(String member,
@Shared("member") @Cached("createIdentityProfile()") ValueProfile memberProfile) throws UnknownIdentifierException {
if (!isMemberReadable(member, memberProfile)) {
CompilerDirectives.transferToInterpreter();
throw UnknownIdentifierException.create(member);
}
return getPointer();
}

@ExportMessage
@SuppressWarnings("static-method")
boolean isPointer() {
return true;
}

@ExportMessage
long asPointer() {
return getPointer();
}
}
Loading