Loading...
Searching...
No Matches
control.h
Go to the documentation of this file.
1#ifndef _BBM_DRJIT_CONTROL_H_
2#define _BBM_DRJIT_CONTROL_H_
3
4#include "core/error.h"
6#include "backbone/type_traits.h"
7#include "drjit/util.h"
8
9/************************************************************************/
10/*! \file control.h
11
12 \brief Data and flow control
13 + select
14 + lookup
15 + cast
16 + set
17 + binary_search
18
19*************************************************************************/
20
21namespace backbone {
22
23 /**********************************************************************/
24 /*! \brief cast
25 **********************************************************************/
26 template<typename NEWTYPE, typename OLDTYPE>
27 NEWTYPE cast(OLDTYPE&& val)
28 {
29 if constexpr (is_diff_v<OLDTYPE> && !is_diff_v<NEWTYPE>) return cast<NEWTYPE>(drjit::detach(val));
30 else if constexpr (is_LLVMArray_v<OLDTYPE> && !is_LLVMArray_v<NEWTYPE>) return cast<NEWTYPE>(val[0]);
31 else if constexpr (drjit::is_static_array_v<OLDTYPE> && drjit::is_static_array_v<NEWTYPE>)
32 {
33 auto helper = [&]<size_t... IDX>(std::index_sequence<IDX...>)
34 {
35 return NEWTYPE( cast<value_t<NEWTYPE>>( val[IDX] )... );
36 };
37
38 return helper(std::make_index_sequence<std::decay_t<OLDTYPE>::Size>{});
39 }
40 else return NEWTYPE(val);
41 }
42
43 /**********************************************************************/
44 /*! \brief Extension of drjit::select to diff/non-diff masks
45 *********************************************************************/
46 template<typename MASK, typename A, typename B>
47 inline auto select(MASK&& mask, A&& a, B&& b)
48 {
49 if constexpr (is_diff_v<MASK> && !is_diff_v<A> && !is_diff_v<B>)
50 return drjit::select(cast<remove_diff_t<MASK>>(mask), std::forward<A>(a), std::forward<B>(b));
51 else if constexpr (is_LLVMArray_v<MASK> && !is_LLVMArray_v<A> && !is_LLVMArray_v<B>)
52 return drjit::select(mask[0], std::forward<A>(a), std::forward<B>(b));
53 else
54 return drjit::select(std::forward<MASK>(mask), std::forward<A>(a), std::forward<B>(b));
55 }
56
57 /**********************************************************************/
58 /*! \brief Non-packet look up
59 **********************************************************************/
60 template<typename RET, typename C, typename Index> requires (!is_packet_v<Index>) && std::ranges::range<C> && std::convertible_to<bbm::iterable_value_t<C>, RET> && is_index_v<Index>
61 inline constexpr RET lookup(C&& container, const Index& idx, const index_mask_t<Index>& mask=true)
62 {
63 // quick bailout
64 if(none(mask)) return RET();
65
66 // lookup
67 size_t i = cast<size_t>(idx);
68 if(i >= bbm::size(container)) throw bbm_out_of_range;
69 return RET( *(std::next(bbm::begin(container), i)));
70 }
71
72
73 /**********************************************************************/
74 /*! \brief Non-packet set
75 **********************************************************************/
76 template<typename VAL, typename C, typename Index> requires (!is_packet_v<Index>) && std::ranges::range<C> && is_index_v<Index> && std::convertible_to<VAL, bbm::iterable_value_t<C>>
77 inline constexpr void set(C&& container, const Index& idx, VAL&& value, const index_mask_t<Index>& mask=true)
78 {
79 // quick bailout
80 if(none(mask)) return;
81
82 // set
83 size_t i = cast<size_t>(idx);
84 if(i >= bbm::size(container)) throw bbm_out_of_range;
85 *(std::next(bbm::begin(std::forward<C>(container)), i)) = std::forward<VAL>(value);
86 }
87
88 /**********************************************************************/
89 /*! \brief binary search
90 **********************************************************************/
91 template<typename C, typename PRED> requires std::ranges::range<C> && std::is_invocable_r_v<mask_t<bbm::iterable_value_t<C>>, PRED, bbm::iterable_value_t<C>>
92 inline constexpr index_t<bbm::iterable_value_t<C>> binary_search(C&& container, PRED&& predicate, const index_mask_t<bbm::iterable_value_t<C>>& mask=true)
93 {
94 using value_type = bbm::iterable_value_t<C>;
95 using index_type = index_t<value_type>;
96
97 // quick exit
98 if(none(mask)) return index_type(bbm::size(container));
99
100 // create a wrapper for the predicate to meet drjit's expectations
101 auto pred_wrapper = [&](const index_type& index)
102 {
103 auto result = predicate( lookup<value_type>(container, index, mask) );
104 return cast<index_mask_t<index_type>>(result);
105 };
106
107 // pass control to drjit
108 index_type idx = drjit::binary_search<index_type>(0, bbm::size(container)-1, pred_wrapper);
109 return select(mask && !pred_wrapper(idx), idx, bbm::size(container));
110 }
111
112} // end backbone namespace
113
114
115#endif /* _BBM_DRJIT_CONTROL_H_ */
Predefined exceptions for common errors.
#define bbm_out_of_range
Definition: error.h:44
Extensions for STL iterators/ranges.
Random number generator; built on top of Drjit.
Definition: backbone.h:53
constexpr void set(C &&container, const Index &idx, VAL &&value, const index_mask_t< Index > &mask=true)
Non-packet set.
Definition: control.h:77
typename detail::index_impl< std::decay_t< T > >::mask index_mask_t
Return the mask of an index of a type.
Definition: type_traits.h:215
typename backbone::detail::value< std::decay_t< T > >::type value_t
Value trait.
Definition: type_traits.h:35
bool none(const T &t)
Definition: horizontal.h:56
constexpr RET lookup(C &&container, const Index &idx, const index_mask_t< Index > &mask=true)
Non-packet look up.
Definition: control.h:61
NEWTYPE cast(OLDTYPE &&val)
cast
Definition: control.h:27
typename detail::index_impl< std::decay_t< T > >::type index_t
Return the index type.
Definition: type_traits.h:203
constexpr index_t< bbm::iterable_value_t< C > > binary_search(C &&container, PRED &&predicate, const index_mask_t< bbm::iterable_value_t< C > > &mask=true)
binary search
Definition: control.h:92
auto select(MASK &&mask, A &&a, B &&b)
Extension of drjit::select to diff/non-diff masks.
Definition: control.h:47
typename detail::remove_diff< std::decay_t< T > >::type remove_diff_t
Strip autodiff from type T.
Definition: type_traits.h:95
size_t size(T &&t)
Definition: iterator_util.h:22
std::decay_t< decltype(*bbm::begin(std::declval< T >()))> iterable_value_t
Definition: iterator_util.h:61
auto begin(T &&t)
Definition: iterator_util.h:29