Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions gtwrap/matlab_wrapper/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class CheckMixin:
)
# Ignore the namespace for these datatypes
ignore_namespace: Tuple = ('Matrix', 'Vector', 'Point2', 'Point3')
# Matrix-like view types that can alias MATLAB double matrix storage.
matrix_view_types: Tuple = ('ConstMatrixView', )
# Methods that should be ignored
ignore_methods: Tuple = ('pickle', )
# Methods that should not be wrapped directly
Expand All @@ -42,6 +44,7 @@ def can_be_pointer(self, arg_type: parser.Type):
"""
return (arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and not self.is_matrix_view(arg_type)
and arg_type.typename.name != 'string')

def is_shared_ptr(self, arg_type: parser.Type):
Expand All @@ -67,6 +70,10 @@ def is_ref(self, arg_type: parser.Type):
arg_type.typename.name not in self.not_ptr_type and \
arg_type.is_ref

def is_matrix_view(self, arg_type: parser.Type):
"""Check if `arg_type` should be unwrapped as a matrix view."""
return arg_type.typename.name in self.matrix_view_types

def is_class_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if arg_type is an enum in the class `class_`."""
if class_:
Expand Down
7 changes: 7 additions & 0 deletions gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self,
'unsigned char': 'unsigned char',
'Vector': 'double',
'Matrix': 'double',
'ConstMatrixView': 'double',
'int': 'numeric',
'size_t': 'numeric',
'Key': 'numeric',
Expand All @@ -69,6 +70,7 @@ def __init__(self,
'Point3': 'double',
'Vector': 'double',
'Matrix': 'double',
'ConstMatrixView': 'double',
'Key': 'numeric',
'bool': 'bool'
}
Expand Down Expand Up @@ -354,6 +356,11 @@ def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None):
arg_type = f"{enum_type}"
unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);'

elif self.is_matrix_view(arg.ctype):
arg_type = self._format_type_name(arg.ctype.typename)
unwrap = 'unwrapMatrixView< {ctype} >(in[{id}]);'.format(
ctype=arg_type, id=arg_id)

elif self.is_ref(arg.ctype): # and not constructor:
arg_type = "{ctype}&".format(ctype=ctype_sep)
unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format(
Expand Down
36 changes: 31 additions & 5 deletions gtwrap/pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ class PybindWrapper:
Class to generate binding code for Pybind11 specifically.
"""

ARG_POLICY_SUPPORT = """
#include <type_traits>

namespace gtwrap {
namespace internal {

template <typename T>
struct PyArgPolicy {
static pybind11::arg make(const char* name) { return pybind11::arg(name); }
};

template <typename T>
pybind11::arg py_arg(const char* name) {
return PyArgPolicy<typename std::decay<T>::type>::make(name);
}

} // namespace internal
} // namespace gtwrap
"""

def __init__(self,
module_name,
top_module_namespaces='',
Expand Down Expand Up @@ -70,8 +90,11 @@ def _py_args_names(self, args):
default = ' = {arg.default}'.format(arg=arg)
else:
default = ''
argument = 'py::arg("{name}"){default}'.format(
name=arg.name, default='{0}'.format(default))
argument = (
'gtwrap::internal::py_arg<{ctype}>("{name}"){default}'
).format(ctype=arg.ctype.to_cpp(),
name=arg.name,
default='{0}'.format(default))
py_args.append(argument)
return ", " + ", ".join(py_args)
else:
Expand Down Expand Up @@ -251,12 +274,13 @@ def _wrap_method(self,
method,
(parser.StaticMethod, instantiator.InstantiatedStaticMethod))
return_void = method.return_type.is_void()
return_ref = getattr(
getattr(method.return_type, 'type1', None), 'is_ref', False)
return_type = getattr(method.return_type, 'type1', None)
return_ref = getattr(return_type, 'is_ref', False)
return_const = getattr(return_type, 'is_const', False)

# For methods returning const T&, use reference_internal policy
# to avoid unnecessary copies and keep the returned reference alive.
if return_ref and is_method:
if return_ref and return_const and is_method:
lambda_ret = ' -> const auto&'
ref_policy = ', py::return_value_policy::reference_internal'
else:
Expand Down Expand Up @@ -752,6 +776,8 @@ def wrap_file(self, content, module_name=None, submodules=None):
module_def = "void {0}(py::module_ &m_)".format(module_name)
submodules = []

includes += self.ARG_POLICY_SUPPORT

return self.module_template.format(
module_def=module_def,
module_name=module_name,
Expand Down
25 changes: 24 additions & 1 deletion matlab.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ extern "C" {
#include <mex.h>
}

#include <limits>
#include <list>
#include <set>
#include <sstream>
Expand Down Expand Up @@ -429,6 +430,29 @@ gtsam::Matrix unwrap< gtsam::Matrix >(const mxArray* array) {
return A;
}

// unwrap a MATLAB double matrix as a const Eigen matrix view without copying
template <typename MatrixView>
MatrixView unwrapMatrixView(const mxArray* array) {
if (mxIsDouble(array)==false || mxIsComplex(array) || mxIsSparse(array))
error("unwrapMatrixView: not a full real double matrix");
const mwSize rows = mxGetM(array), cols = mxGetN(array);
if (rows > static_cast<mwSize>(std::numeric_limits<Eigen::Index>::max()) ||
cols > static_cast<mwSize>(std::numeric_limits<Eigen::Index>::max())) {
error("unwrapMatrixView: matrix dimensions exceed Eigen::Index");
}
const Eigen::Index m = static_cast<Eigen::Index>(rows);
const Eigen::Index n = static_cast<Eigen::Index>(cols);
#ifdef DEBUG_WRAP
mexPrintf("unwrapMatrixView called with %lldx%lld argument\n",
static_cast<long long>(m), static_cast<long long>(n));
#endif
using Stride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using ConstMatrixMap = Eigen::Map<const gtsam::Matrix, 0, Stride>;
const double* data = static_cast<const double*>(mxGetData(array));
ConstMatrixMap map(data, m, n, Stride(m, 1));
return MatrixView(map);
}

/*
[create_object] creates a MATLAB proxy class object with a mexhandle
in the self property. Matlab does not allow the creation of matlab
Expand Down Expand Up @@ -547,4 +571,3 @@ Class* unwrap_ptr(const mxArray* obj, const string& propertyName) {
// static_assert(unwrap_shared_ptr_Matrix_attempted, "Matrix cannot be unwrapped as a shared pointer");
// return Matrix();
//}

Loading
Loading