Skip to content

Commit

Permalink
Byte swap values read from the protobuf raw data stream on big endian
Browse files Browse the repository at this point in the history
  • Loading branch information
ellert authored and guitargeek committed Apr 9, 2024
1 parent a561a9f commit 7a39d4e
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "Byteswap.h"
#include "TMVA/RModelParser_ONNX.hxx"
#include "onnx_proto3.pb.h"

Expand Down Expand Up @@ -374,9 +375,14 @@ RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
case ETensorType::FLOAT: {
std::shared_ptr<void> data(malloc(fLength * sizeof(float)), free);

if (tensorproto->raw_data().empty() == false) {
auto raw_data_ptr = reinterpret_cast<float *>(const_cast<char *>(tensorproto->raw_data().c_str()));
std::memcpy(data.get(), raw_data_ptr, fLength * sizeof(float));
if (!tensorproto->raw_data().empty()) {
#ifdef R__BYTESWAP
std::memcpy(data.get(), tensorproto->raw_data().c_str(), fLength * sizeof(float));
#else
for (std::size_t k = 0; k < fLength; ++k)
(reinterpret_cast<uint32_t *>(data.get()))[k] =
Rbswap_32((reinterpret_cast<const uint32_t *>(tensorproto->raw_data().c_str()))[k]);
#endif
} else {
tensorproto->mutable_float_data()->ExtractSubrange(0, tensorproto->float_data_size(),
static_cast<float *>(data.get()));
Expand All @@ -390,9 +396,14 @@ RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
case ETensorType::INT64: {
std::shared_ptr<void> data(malloc(fLength * sizeof(int64_t)), free);

if (tensorproto->raw_data().empty() == false) {
auto raw_data_ptr = reinterpret_cast<int64_t *>(const_cast<char *>(tensorproto->raw_data().c_str()));
std::memcpy(data.get(), raw_data_ptr, fLength * sizeof(int64_t));
if (!tensorproto->raw_data().empty()) {
#ifdef R__BYTESWAP
std::memcpy(data.get(), tensorproto->raw_data().c_str(), fLength * sizeof(int64_t));
#else
for (std::size_t k = 0; k < fLength; ++k)
(reinterpret_cast<uint64_t *>(data.get()))[k] =
Rbswap_64((reinterpret_cast<const uint64_t *>(tensorproto->raw_data().c_str()))[k]);
#endif
} else {
tensorproto->mutable_int64_data()->ExtractSubrange(0, tensorproto->int64_data_size(),
static_cast<int64_t *>(data.get()));
Expand Down

0 comments on commit 7a39d4e

Please sign in to comment.