Skip to content

Commit

Permalink
[Feature] Compatibility with DLPack 0.6 in tensoradapter (dmlc#3803)
Browse files Browse the repository at this point in the history
* compatibility to DLPack 0.6 in tensoradapter

* fix

* oops

Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
  • Loading branch information
BarclayII and VoVAllen committed Mar 10, 2022
1 parent 27d3af0 commit 520cef8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
8 changes: 1 addition & 7 deletions tensoradapter/include/tensoradapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@
#include <dlpack/dlpack.h>
#include <vector>

#if defined(WIN32) || defined(_WIN32)
#define TA_EXPORTS __declspec(dllexport)
#else
#define TA_EXPORTS
#endif

namespace tensoradapter {

extern "C" {
Expand All @@ -31,7 +25,7 @@ extern "C" {
* \param ctx The device
* \return The allocated tensor
*/
TA_EXPORTS DLManagedTensor* TAempty(
DLManagedTensor* TAempty(
std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);

}
Expand Down
16 changes: 16 additions & 0 deletions tensoradapter/include/tensoradapter_exports.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*!
 * Copyright (c) 2020 by Contributors
 * \file tensoradapter_exports.h
 * \brief Header file for functions exposed by the adapter library.
 *
 * Defines the TA_EXPORTS macro used to annotate functions that the
 * tensor adapter shared library exports across the DLL boundary.
 */

#ifndef TENSORADAPTER_EXPORTS_H_
#define TENSORADAPTER_EXPORTS_H_

// On Windows, exported functions must be marked __declspec(dllexport);
// on other platforms symbols are visible by default, so the macro
// expands to nothing.
// NOTE(review): only the export side is handled — a consumer wanting
// __declspec(dllimport) gets no annotation here; confirm the
// import-side attribute is intentionally omitted.
#if defined(WIN32) || defined(_WIN32)
#define TA_EXPORTS __declspec(dllexport)
#else
#define TA_EXPORTS
#endif

#endif  // TENSORADAPTER_EXPORTS_H_
10 changes: 8 additions & 2 deletions tensoradapter/pytorch/torch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
* \brief Implementation of PyTorch adapter library.
*/

#include <tensoradapter.h>
#include <tensoradapter_exports.h>
#include <torch/torch.h>
#include <ATen/DLConvertor.h>
#include <vector>
#include <iostream>

// Newer DLPack headers renamed kDLGPU to kDLCUDA and DLContext to
// DLDevice; alias the old names so the code below still compiles
// against them.  Note that in the C preprocessor 040 is an octal
// literal (32 decimal) — presumably matching how dlpack.h spells its
// DLPACK_VERSION macro; verify against the bundled dlpack header.
#if DLPACK_VERSION > 040
// Compatibility across DLPack - note that this assumes that the ABI stays the same.
#define kDLGPU kDLCUDA
#define DLContext DLDevice
#endif

namespace tensoradapter {

static at::Device get_device(DLContext ctx) {
Expand All @@ -29,7 +35,7 @@ static at::Device get_device(DLContext ctx) {

extern "C" {

DLManagedTensor* TAempty(
TA_EXPORTS DLManagedTensor* TAempty(
std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
Expand Down

0 comments on commit 520cef8

Please sign in to comment.