Loading...
Searching...
No Matches
py_cast.h
Go to the documentation of this file.
1#ifndef _BBM_PY_CAST_H_
2#define _BBM_PY_CAST_H_
3
4#include "pybind11/numpy.h"
5
6#include "util/typestring.h"
8
9/************************************************************************/
10/*! \file py_cast.h
11
12 \brief Wrapper around py::cast for more robust casting to attributes and to
13 aggregate types such as vec2d, vec3d, spectrum, etc...
14*************************************************************************/
15
16namespace bbm {
17 namespace python {
18
19 /********************************************************************/
20 /*! \brief Default casting
21 ********************************************************************/
22 template<typename TARGET>
23 struct py_cast
24 {
25 template<typename T>
26 static inline TARGET cast(T&& t)
27 {
28 try
29 {
30 if constexpr (requires(T&& t) {{py::cast<TARGET>(std::forward<T>(t))};}) return py::cast<TARGET>(std::forward<T>(t));
31 else throw std::runtime_error("BBM: casting error");
32 }
33 catch(...)
34 {
35 throw std::runtime_error(std::string("BBM: do not know how to cast to ") + std::string(bbm::typestring<TARGET>) + ".");
36 }
37 }
38 };
39
40 /********************************************************************/
41 /*! \brief specialization for bbm::attribute
42 ********************************************************************/
43 template<typename TARGET> requires concepts::attribute<TARGET>
44 struct py_cast<TARGET>
45 {
46 template<typename T>
47 static inline TARGET cast(T&& t)
48 {
49 return py_cast<typename std::decay_t<TARGET>::type>::cast(std::forward<T>(t));
50 }
51 };
52
53 /********************************************************************/
54 /*! \brief specialization for aggregate types like vec2d, vec3d, spectrum, ...
55 ********************************************************************/
56 template<typename TARGET> requires (!std::same_as<value_t<TARGET>, std::decay_t<TARGET>> && !concepts::attribute<TARGET>)
57 struct py_cast<TARGET>
58 {
59 private:
60
61 //! \brief copy constant values (val) in 'result'
62 template<typename T>
63 static void from_const(T& result, const scalar_t<TARGET>& val)
64 {
65 if constexpr (std::same_as<value_t<T>, std::decay_t<T>>) result = val;
66 else for(auto& r : result) from_const(r, val);
67 }
68
69 //! \brief recusively copy values from the sub-array at partial indices [idx...] to result
70 template<typename T, typename... Idx>
71 static void from_py_array(T& result, const py::array_t<scalar_t<TARGET>>& parr, Idx... idx)
72 {
73 // base case: T is not a aggregate type
74 if constexpr (std::same_as<value_t<T>, std::decay_t<T>>)
75 {
76 // special case: end of array reached
77 if((sizeof...(Idx) == parr.ndim())) return from_const(result, parr.at(idx...));
78
79 // check size (must be last dimension with only 1 element)
80 size_t size = parr.shape()[sizeof...(Idx)];
81 if(sizeof...(Idx)+1 == parr.ndim() && size == 1) return from_const(result, parr.at(idx..., 0));
82
83 // error
84 throw bbm_size_error;
85 }
86
87 // recursion
88 else
89 {
90 // check dimensions
91 if(parr.ndim() <= ssize_t(sizeof...(Idx))) throw bbm_size_error;
92
93 // check size
94 size_t size = parr.shape()[sizeof...(Idx)];
95 if(size != bbm::size(result)) throw bbm_size_error;
96
97 // copy multiple data entries
98 for(size_t current_idx = 0; current_idx != bbm::size(result); ++current_idx)
99 {
100 // special case: end of array reached
101 if(parr.ndim() == sizeof...(Idx)+1) from_const( value(*std::next(bbm::begin(result), current_idx)), parr.at(idx..., current_idx));
102
103 // recurse
104 else from_py_array(value(*std::next(bbm::begin(result), current_idx)), parr, idx..., current_idx);
105 }
106 }
107
108 // Done.
109 }
110
111 public:
112 template<typename T>
113 static inline TARGET cast(T&& t)
114 {
115 // try casting to TARGET first
116 try
117 {
118 return py::cast<TARGET>(std::forward<T>(t));
119 } catch(...) {}
120
121 // try casting to scalar_t instead of TARGET
122 try {
123 if constexpr (std::constructible_from<TARGET, scalar_t<TARGET>>)
124 return TARGET( py_cast<scalar_t<TARGET>>::cast(std::forward<T>(t)) );
125 } catch(...) {}
126
127 // final attempt; try casting to py::array
128 auto parr = py::cast<py::array_t<scalar_t<TARGET>>>(std::forward<T>(t));
129 TARGET result;
130 from_py_array(result, parr);
131 return result;
132 }
133 };
134
135 } // end python namespace
136} // end bbm namespace
137
138#endif /* _BBM_PY_CAST_H_ */
Helper methods for extracting the value of an attribute (according to concepts::attribute).
attribute concept
Definition: attribute.h:39
#define bbm_size_error
Definition: error.h:45
Definition: aggregatebsdf.h:29
size_t size(T &&t)
Definition: iterator_util.h:22
auto begin(T &&t)
Definition: iterator_util.h:29
decltype(auto) value(T &&t)
return the value of an attribute, or if not an attribute the object
Definition: attribute_value.h:20
static void from_const(T &result, const scalar_t< TARGET > &val)
copy constant values (val) in 'result'
Definition: py_cast.h:63
static TARGET cast(T &&t)
Definition: py_cast.h:47
static void from_py_array(T &result, const py::array_t< scalar_t< TARGET > > &parr, Idx... idx)
recursively copy values from the sub-array at partial indices [idx...] to result
Definition: py_cast.h:71
Default casting.
Definition: py_cast.h:24
static TARGET cast(T &&t)
Definition: py_cast.h:26
produce stringview of type name of a type. Avoids using typeid for GCC, MSVC, and CLANG....