Loading...
Searching...
No Matches
gradient.h
Go to the documentation of this file.
1#ifndef _BBM_GRADIENT_H_
2#define _BBM_GRADIENT_H_
3
#include <iterator>
#include <vector>

5#include "util/vector_util.h"
6
7/***********************************************************************/
8/*! \file gradient.h
9 \brief Gradient related operations
10
11 Enable gradients, detach, forward, backward, and get gradients for types that
12 fulfill concepts::diff_parameter
13************************************************************************/
14
15namespace bbm {
16
17 /*********************************************************************/
18 /*! \brief Enable gradients for a parameter set
19 *********************************************************************/
20 template<typename PARAM> requires concepts::diff_parameter<PARAM>
21 inline void track_gradients(PARAM&& param, bool toggle=true)
22 {
23 for(auto& itr : param)
24 track_gradient(itr, toggle);
25 }
26
27
28 /*********************************************************************/
29 /*! @{ \brief Check if gradients are tracked
30 *********************************************************************/
31 template<typename PARAM> requires concepts::diff_parameter<PARAM>
32 inline bool all_gradients_tracked(PARAM&& param)
33 {
34 for(auto& itr : param)
35 if(!is_gradient_tracked(itr)) return false;
36 return true;
37 }
38
39 template<typename PARAM> requires concepts::diff_parameter<PARAM>
40 inline bool any_gradients_tracked(PARAM&& param)
41 {
42 for(auto& itr : param)
43 if(is_gradient_tracked(itr)) return true;
44 return false;
45 }
46 //! @}
47
48
49 /*********************************************************************/
50 /*! \brief Get the gradient from a parameter set
51 *********************************************************************/
52 template<typename PARAM> requires concepts::diff_parameter<PARAM>
53 inline auto get_gradients(PARAM&& param)
54 {
55 using Value = decltype( gradient(*std::begin(param)) );
56
58 for(auto& itr : param)
59 result.push_back( gradient(itr) );
60
61 return result;
62 }
63
64 /*********************************************************************/
65 /*! \brief Get the detached values from a parameter set
66 *********************************************************************/
67 template<typename PARAM> requires concepts::diff_parameter<PARAM>
68 inline auto detach_gradients(PARAM&& param)
69 {
70 using Value = decltype( detach_gradient(*std::begin(param)) );
71
73 for(auto& itr : param)
74 result.push_back( detach_gradient(itr) );
75
76 return result;
77 }
78
79
80 /*********************************************************************/
81 /*! \brief Forward computation of gradients on a parameter set
82 *********************************************************************/
83 template<typename PARAM> requires concepts::diff_parameter<PARAM>
84 inline void forward_gradients(PARAM&& param)
85 {
86 for(auto& itr : param)
87 forward(itr);
88 }
89
/**********************************************************************/
/*! \brief Backward gradient computation (passthrough to the backbone)

  Perfectly forwards \p t to backward() and discards any value that
  call might return.

  \param t  object handed to backward(), forwarded with its original
            value category
**********************************************************************/
template<typename T>
inline void backward_gradients(T&& t)
{
  backward( std::forward<T>(t) );
}
98
99} // end bbm namespace
100
101#endif /* _BBM_GRADIENT_H_ */
Definition: vector_util.h:27
Definition: aggregatebsdf.h:29
void track_gradient(T &t, bool toggle=true)
Enable/disable tracking of gradients for a variable.
Definition: backbone.h:451
auto detach_gradients(PARAM &&param)
Get the detached values from a parameter set.
Definition: gradient.h:68
bool any_gradients_tracked(PARAM &&param)
Definition: gradient.h:40
void track_gradients(PARAM &&param, bool toggle=true)
Enable gradients for a parameter set.
Definition: gradient.h:21
bool is_gradient_tracked(const T &t)
Checks if gradients are enabled for a variable.
Definition: backbone.h:436
void forward_gradients(PARAM &&param)
Forward computation of gradients on a parameter set.
Definition: gradient.h:84
auto gradient(T &t)
Return the gradient.
Definition: backbone.h:423
auto get_gradients(PARAM &&param)
Get the gradient from a parameter set.
Definition: gradient.h:53
bool all_gradients_tracked(PARAM &&param)
Check if gradients are tracked.
Definition: gradient.h:32
void backward_gradients(T &&t)
backward computations => passthrough to backbone
Definition: gradient.h:94
auto detach_gradient(T &t)
Detach the value from the gradient computations.
Definition: backbone.h:408
Concepts related to BSDF model parameters.
Extensions for the STL vector class.