Skip to content

Commit fa69eb2

Browse files
committed
eigen compatibility for plot and loglog
1 parent f1f39a5 commit fa69eb2

File tree

1 file changed

+107
-37
lines changed

1 file changed

+107
-37
lines changed

matplotlibcpp.h

Lines changed: 107 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
//
2+
// Changes:
3+
// * Make work for Eigen Vectors and Matrices
4+
// * Implement a better way for named_plot, maybe just as additional
5+
// method with extra keyword
6+
// * add location keyword for legend
7+
// * add submodule for our own functions such as spy
8+
//
9+
110
#pragma once
211

312
#include <algorithm>
@@ -57,6 +66,8 @@ struct _interpreter {
5766
PyObject *s_python_function_ylabel;
5867
PyObject *s_python_function_xticks;
5968
PyObject *s_python_function_yticks;
69+
PyObject *s_python_function_xscale;
70+
PyObject *s_python_function_yscale;
6071
PyObject *s_python_function_grid;
6172
PyObject *s_python_function_clf;
6273
PyObject *s_python_function_errorbar;
@@ -181,6 +192,8 @@ struct _interpreter {
181192
s_python_function_ylabel = PyObject_GetAttrString(pymod, "ylabel");
182193
s_python_function_xticks = PyObject_GetAttrString(pymod, "xticks");
183194
s_python_function_yticks = PyObject_GetAttrString(pymod, "yticks");
195+
s_python_function_xscale = PyObject_GetAttrString(pymod, "xscale");
196+
s_python_function_yscale = PyObject_GetAttrString(pymod, "yscale");
184197
s_python_function_grid = PyObject_GetAttrString(pymod, "grid");
185198
s_python_function_xlim = PyObject_GetAttrString(pymod, "xlim");
186199
s_python_function_ion = PyObject_GetAttrString(pymod, "ion");
@@ -209,6 +222,8 @@ struct _interpreter {
209222
!s_python_function_legend || !s_python_function_ylim ||
210223
!s_python_function_title || !s_python_function_axis ||
211224
!s_python_function_xlabel || !s_python_function_ylabel ||
225+
!s_python_function_xticks || !s_python_function_yticks ||
226+
!s_python_function_xscale || !s_python_function_yscale ||
212227
!s_python_function_grid || !s_python_function_xlim ||
213228
!s_python_function_ion || !s_python_function_ginput ||
214229
!s_python_function_save || !s_python_function_clf ||
@@ -241,6 +256,10 @@ struct _interpreter {
241256
!PyFunction_Check(s_python_function_axis) ||
242257
!PyFunction_Check(s_python_function_xlabel) ||
243258
!PyFunction_Check(s_python_function_ylabel) ||
259+
!PyFunction_Check(s_python_function_xticks) ||
260+
!PyFunction_Check(s_python_function_yticks) ||
261+
!PyFunction_Check(s_python_function_xscale) ||
262+
!PyFunction_Check(s_python_function_yscale) ||
244263
!PyFunction_Check(s_python_function_grid) ||
245264
!PyFunction_Check(s_python_function_xlim) ||
246265
!PyFunction_Check(s_python_function_ion) ||
@@ -334,6 +353,8 @@ template <> struct select_npy_type<uint64_t> {
334353
const static NPY_TYPES type = NPY_UINT64;
335354
};
336355

356+
// TODO change to Vector template so useable for Eigen vectors,
357+
// should be enough since it also provides the end and begin methods
337358
template <typename Numeric> PyObject *get_array(const std::vector<Numeric> &v) {
338359
detail::_interpreter::get(); // interpreter needs to be initialized for the
339360
// numpy commands to work
@@ -353,6 +374,29 @@ template <typename Numeric> PyObject *get_array(const std::vector<Numeric> &v) {
353374
return varray;
354375
}
355376

377+
template <typename Vector> PyObject *get_array(const Vector &v) {
378+
detail::_interpreter::get(); // interpreter needs to be initialized for the
379+
// numpy commands to work
380+
// both Eigen::Matrix<..> and std::vector<..> have the member value_type
381+
NPY_TYPES type = select_npy_type<typename Vector::value_type>::type;
382+
if (type == NPY_NOTYPE) {
383+
std::vector<double> vd(v.size());
384+
npy_intp vsize = v.size();
385+
std::copy(v.begin(), v.end(), vd.begin());
386+
PyObject *varray =
387+
PyArray_SimpleNewFromData(1, &vsize, NPY_DOUBLE, (void *)(vd.data()));
388+
return varray;
389+
}
390+
391+
npy_intp vsize = v.size();
392+
PyObject *varray =
393+
PyArray_SimpleNewFromData(1, &vsize, type, (void *)(v.data()));
394+
return varray;
395+
}
396+
397+
// TODO maybe we have to add a function for Eigen matrices, not sure
398+
// if the v[0] is valid for matrices and also the ::std::vector &v_row
399+
// probably doesn't work
356400
template <typename Numeric>
357401
PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
358402
detail::_interpreter::get(); // interpreter needs to be initialized for the
@@ -380,9 +424,16 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
380424

381425
#else // fallback if we don't have numpy: copy every element of the given vector
382426

427+
// TODO to Vector templat
383428
template <typename Numeric> PyObject *get_array(const std::vector<Numeric> &v) {
384-
detail::_interpreter::get();
429+
PyObject *list = PyList_New(v.size());
430+
for (size_t i = 0; i < v.size(); ++i) {
431+
PyList_SetItem(list, i, PyFloat_FromDouble(v.at(i)));
432+
}
433+
return list;
434+
}
385435

436+
template <typename Vector> PyObject *get_array(const Vector &v) {
386437
PyObject *list = PyList_New(v.size());
387438
for (size_t i = 0; i < v.size(); ++i) {
388439
PyList_SetItem(list, i, PyFloat_FromDouble(v.at(i)));
@@ -392,8 +443,8 @@ template <typename Numeric> PyObject *get_array(const std::vector<Numeric> &v) {
392443

393444
#endif // WITHOUT_NUMPY
394445

395-
template <typename Numeric>
396-
bool plot(const std::vector<Numeric> &x, const std::vector<Numeric> &y,
446+
template <typename VectorX, typename VectorY>
447+
bool plot(const VectorX &x, const VectorY &y,
397448
const std::map<std::string, std::string> &keywords) {
398449
assert(x.size() == y.size());
399450

@@ -767,9 +818,8 @@ bool named_hist(std::string label, const std::vector<Numeric> &y,
767818
return res;
768819
}
769820

770-
template <typename NumericX, typename NumericY>
771-
bool plot(const std::vector<NumericX> &x, const std::vector<NumericY> &y,
772-
const std::string &s = "") {
821+
template <typename VectorX, typename VectorY>
822+
bool plot(const VectorX &x, const VectorY &y, const std::string &s = "") {
773823
assert(x.size() == y.size());
774824

775825
PyObject *xarray = get_array(x);
@@ -904,29 +954,54 @@ bool semilogy(const std::vector<NumericX> &x, const std::vector<NumericY> &y,
904954
return res;
905955
}
906956

907-
template <typename NumericX, typename NumericY>
908-
bool loglog(const std::vector<NumericX> &x, const std::vector<NumericY> &y,
909-
const std::string &s = "") {
910-
assert(x.size() == y.size());
957+
template <typename... Args> bool loglog_call(Args... args) {
958+
// argument for xscale/yscale is only the string "log"
959+
PyObject *log_arg = PyTuple_New(1);
960+
PyObject *pystring = PyString_FromString("log");
961+
PyTuple_SetItem(log_arg, 0, pystring);
911962

912-
PyObject *xarray = get_array(x);
913-
PyObject *yarray = get_array(y);
963+
// call xscale("log") and yscale("log"), no kwargs needed hence pass NULL,
964+
// as explained in https://docs.python.org/3/c-api/object.html
965+
PyObject *res_x = PyObject_Call(
966+
detail::_interpreter::get().s_python_function_xscale, log_arg, NULL);
967+
PyObject *res_y = PyObject_Call(
968+
detail::_interpreter::get().s_python_function_yscale, log_arg, NULL);
914969

915-
PyObject *pystring = PyString_FromString(s.c_str());
970+
// clean up
971+
Py_DECREF(log_arg);
916972

917-
PyObject *plot_args = PyTuple_New(3);
918-
PyTuple_SetItem(plot_args, 0, xarray);
919-
PyTuple_SetItem(plot_args, 1, yarray);
920-
PyTuple_SetItem(plot_args, 2, pystring);
973+
if (!res_x)
974+
throw std::runtime_error("Call to xscale() failed");
975+
Py_DECREF(res_x);
921976

922-
PyObject *res = PyObject_CallObject(
923-
detail::_interpreter::get().s_python_function_loglog, plot_args);
977+
if (!res_y)
978+
throw std::runtime_error("Call to yscale() failed");
979+
Py_DECREF(res_y);
924980

925-
Py_DECREF(plot_args);
926-
if (res)
927-
Py_DECREF(res);
981+
// call plot, which gets now plotted in doubly logarithmic scale
982+
return plot(args...);
983+
}
928984

929-
return res;
985+
template <typename VectorY>
986+
bool loglog(const VectorY &y, const std::string &s = "") {
987+
return loglog_call(y, s);
988+
}
989+
990+
template <typename VectorX, typename VectorY>
991+
bool loglog(const VectorX &x, const VectorY &y, const std::string &s = "") {
992+
return loglog_call(x, y, s);
993+
}
994+
995+
template <typename VectorY>
996+
bool loglog(const VectorY &y,
997+
const std::map<std::string, std::string> &kwargs) {
998+
return loglog_call(y, kwargs);
999+
}
1000+
1001+
template <typename VectorX, typename VectorY>
1002+
bool loglog(const VectorX &x, const VectorY &y,
1003+
const std::map<std::string, std::string> &kwargs) {
1004+
return loglog_call(x, y, kwargs);
9301005
}
9311006

9321007
template <typename NumericX, typename NumericY>
@@ -1107,9 +1182,12 @@ bool named_loglog(const std::string &name, const std::vector<Numeric> &x,
11071182
return res;
11081183
}
11091184

1110-
template <typename Numeric>
1111-
bool plot(const std::vector<Numeric> &y, const std::string &format = "") {
1112-
std::vector<Numeric> x(y.size());
1185+
template <typename Vector>
1186+
bool plot(const Vector &y, const std::string &format = "") {
1187+
// TODO can this be <size_t> or do we need <typename Vector::value_type>?
1188+
// before the conversion of this function from vector<Numeric> to Vector
1189+
// the created vector x was of the same type as y
1190+
std::vector<size_t> x(y.size());
11131191
for (size_t i = 0; i < x.size(); ++i)
11141192
x.at(i) = i;
11151193
return plot(x, y, format);
@@ -1125,8 +1203,6 @@ bool stem(const std::vector<Numeric> &y, const std::string &format = "") {
11251203

11261204
template <typename Numeric>
11271205
void text(Numeric x, Numeric y, const std::string &s = "") {
1128-
detail::_interpreter::get();
1129-
11301206
PyObject *args = PyTuple_New(3);
11311207
PyTuple_SetItem(args, 0, PyFloat_FromDouble(x));
11321208
PyTuple_SetItem(args, 1, PyFloat_FromDouble(y));
@@ -1399,9 +1475,6 @@ inline void yticks(const std::vector<Numeric> &ticks,
13991475
}
14001476

14011477
inline void subplot(long nrows, long ncols, long plot_number) {
1402-
1403-
detail::_interpreter::get();
1404-
14051478
// construct positional args
14061479
PyObject *args = PyTuple_New(3);
14071480
PyTuple_SetItem(args, 0, PyFloat_FromDouble(nrows));
@@ -1441,7 +1514,6 @@ inline void title(const std::string &titlestr,
14411514

14421515
inline void suptitle(const std::string &suptitlestr,
14431516
const std::map<std::string, std::string> &keywords = {}) {
1444-
detail::_interpreter::get();
14451517
PyObject *pysuptitlestr = PyString_FromString(suptitlestr.c_str());
14461518
PyObject *args = PyTuple_New(1);
14471519
PyTuple_SetItem(args, 0, pysuptitlestr);
@@ -1741,8 +1813,6 @@ template <> struct plot_impl<std::false_type> {
17411813
template <typename IterableX, typename IterableY>
17421814
bool operator()(const IterableX &x, const IterableY &y,
17431815
const std::string &format) {
1744-
detail::_interpreter::get();
1745-
17461816
// 2-phase lookup for distance, begin, end
17471817
using std::begin;
17481818
using std::distance;
@@ -1812,16 +1882,16 @@ bool plot(const A &a, const B &b, const std::string &format, Args... args) {
18121882
*/
18131883
inline bool plot(const std::vector<double> &x, const std::vector<double> &y,
18141884
const std::string &format = "") {
1815-
return plot<double, double>(x, y, format);
1885+
return plot<std::vector<double>, std::vector<double>>(x, y, format);
18161886
}
18171887

18181888
inline bool plot(const std::vector<double> &y, const std::string &format = "") {
1819-
return plot<double>(y, format);
1889+
return plot<std::vector<double>>(y, format);
18201890
}
18211891

18221892
inline bool plot(const std::vector<double> &x, const std::vector<double> &y,
18231893
const std::map<std::string, std::string> &keywords) {
1824-
return plot<double>(x, y, keywords);
1894+
return plot<std::vector<double>, std::vector<double>>(x, y, keywords);
18251895
}
18261896

18271897
/*

0 commit comments

Comments
 (0)