Skip to content

Commit

Permalink
Deal with offsets within the KernelParamsWrapper
Browse files (browse the repository at this point in the history)
  • Loading branch information
Chris Bamford committed Apr 29, 2015
1 parent ac4aa41 commit 6302166
Showing 1 changed file with 36 additions and 58 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.nd4j.linalg.jcublas.util.PreparedKernelParams;
import org.nd4j.linalg.jcublas.util.KernelParamsWrapper;
import org.nd4j.linalg.util.ArrayUtil;


Expand Down Expand Up @@ -263,41 +263,31 @@ private JCudaBuffer toArgs(Object[] extraArgs, String dataType) {


private void invoke(Accumulation op) {
JCudaBuffer xBuffer = (JCudaBuffer) op.x().data();
//Pointer xPointer = xBuffer.getHostPointer().withByteOffset(xBuffer.getElementSize() * op.x().offset());

JCudaBuffer result = null;

int resultLength = 32;
if (op.x().data().dataType() == DataBuffer.DOUBLE) {
// double[] resultBuffer = new double[resultLength];
// for (int i = 0; i < resultBuffer.length; i++)
// resultBuffer[i] = op.zero().doubleValue();
result = new CudaDoubleDataBuffer(resultLength);


} else {
result = new CudaFloatDataBuffer(resultLength);
}

if (op.y() != null) {
JCudaBuffer yBuffer = (JCudaBuffer) op.y().data();
//Pointer yPointer = yBuffer.getHostPointer().withByteOffset(op.y().offset() * yBuffer.getElementSize());

//int n,int xOffset,int yOffset, double *dx, double *dy,int incx,int incy,double *result
Object[] kernelParams = new Object[] {
new int[]{op.n()},
new int[]{op.x().offset()},
new int[]{op.y().offset()},
xBuffer,
yBuffer,
new int[]{op.x().majorStride()},
new int[]{op.y().majorStride()},
op.n(),
op.x().offset(),
op.y().offset(),
op.x(),
op.y(),
op.x().majorStride(),
op.y().majorStride(),
toArgs(op.extraArgs(), getType(op)),
result
};

try(PreparedKernelParams kParams = new PreparedKernelParams(kernelParams)) {
try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams)) {

invokeFunction(op, kParams.getKernelParameters());
setResultForOp(op, result.getDevicePointer());
Expand All @@ -313,13 +303,13 @@ private void invoke(Accumulation op) {
Object[] kernelParams = new Object[] {
op.n(),
op.x().offset(),
xBuffer,
op.x(),
op.x().majorStride(),
toArgs(op.extraArgs(), getType(op)),
result
};

try(PreparedKernelParams kParams = new PreparedKernelParams(kernelParams)) {
try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams)) {

invokeFunction(op, kParams.getKernelParameters());
setResultForOp(op, result.getDevicePointer());
Expand Down Expand Up @@ -361,28 +351,23 @@ private void setResultForOp(Accumulation acc, Pointer devicePointer) {


private void invoke(ScalarOp op) {
JCudaBuffer xBuffer = (JCudaBuffer) op.x().data();
//Pointer xPointer = xBuffer.getHostPointer().withByteOffset(op.x().offset() * xBuffer.getElementSize());

JCudaBuffer zBuffer = (JCudaBuffer) op.z().data();
//Pointer zPointer = zBuffer.getHostPointer().withByteOffset(zBuffer.getElementSize() * op.z().offset());

if (op.y() != null) {
JCudaBuffer yBuffer = (JCudaBuffer) op.y().data();
Pointer yPointer = yBuffer.getHostPointer().withByteOffset(yBuffer.getElementSize() * op.y().offset());
Object[] kernelParams = new Object[]{
new int[]{op.n()},
new int[]{op.x().offset()},
new int[]{op.y().offset()},
xBuffer,
yPointer,
new int[]{op.x().majorStride()},
new int[]{op.y().majorStride()},
op.n(),
op.x().offset(),
op.y().offset(),
op.x(),
op.y(),
op.x().majorStride(),
op.y().majorStride(),
toArgs(op.extraArgs(), getType(op)),
zBuffer
op.z()
};

try(PreparedKernelParams kParams = new PreparedKernelParams(kernelParams)) {
try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams)) {

invokeFunction(op, kParams.getKernelParameters());
} catch(Exception e) {
Expand All @@ -395,16 +380,16 @@ private void invoke(ScalarOp op) {
//int n, int idx,double dx,double *dy,int incy,double *result

Object[] kernelParams = new Object[]{
new int[]{op.n()},
new int[]{op.x().offset()},
op.n(),
op.x().offset(),
PointerUtil.getPointer(op),
xBuffer,
new int[]{op.x().majorStride()},
op.x(),
op.x().majorStride(),
toArgs(op.extraArgs(), getType(op)),
zBuffer
op.z()
};

try(PreparedKernelParams kParams = new PreparedKernelParams(kernelParams)) {
try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams)) {

invokeFunction(op, kParams.getKernelParameters());
} catch(Exception e) {
Expand All @@ -423,15 +408,8 @@ private String getType(Op op) {


private void invoke(TransformOp op) {
JCudaBuffer xBuffer = (JCudaBuffer) op.x().data();
//Pointer xPointer = xBuffer.getHostPointer().withByteOffset(xBuffer.getElementSize() * op.x().offset());

JCudaBuffer zBuffer = (JCudaBuffer) op.z().data();
//Pointer zPointer = zBuffer.getHostPointer().withByteOffset(zBuffer.getElementSize() * op.z().offset());


if (op.y() != null) {
JCudaBuffer yBuffer = (JCudaBuffer) op.y().data();
//Pointer yPointer = yBuffer.getHostPointer().withByteOffset(op.y().offset() * yBuffer.getElementSize());

/**
* Construct pointer arguments in the following order:
Expand All @@ -447,17 +425,17 @@ private void invoke(TransformOp op) {
op.n(),
op.x().offset(),
op.y().offset(),
xBuffer,
yBuffer,
op.x(),
op.y(),
op.x().majorStride(),
op.y().majorStride(),
toArgs(op.extraArgs(), getType(op)),
zBuffer
op.z()
};

try(PreparedKernelParams kParams = new PreparedKernelParams(kernelParams)) {
try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams)) {
invokeFunction(op, kParams.getKernelParameters());
zBuffer.copyToHost();
((JCudaBuffer)op.z().data()).copyToHost();
} catch(Exception e) {
throw new RuntimeException("Could not execute kernel", e);
}
Expand All @@ -468,15 +446,15 @@ private void invoke(TransformOp op) {
Object[] kernelParams = new Object[]{
op.n(),
op.x().offset(),
xBuffer,
op.x(),
op.x().majorStride(),
toArgs(op.extraArgs(), getType(op)),
zBuffer
op.z()
};

try(PreparedKernelParams kParams = new PreparedKernelParams(kernelParams)) {
try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams)) {
invokeFunction(op, kParams.getKernelParameters());
zBuffer.copyToHost();
((JCudaBuffer)op.z().data()).copyToHost();
} catch(Exception e) {
throw new RuntimeException("Could not execute kernel", e);
}
Expand Down

0 comments on commit 6302166

Please sign in to comment.