Loading...
Searching...
No Matches
batch.h
Go to the documentation of this file.
1#ifndef _BBM_BATCH_H_
2#define _BBM_BATCH_H_
3
5#include "bbm/config.h"
6#include "util/vector_util.h"
7
8/************************************************************************/
9/*! \file batch.h
10
11 \brief Batch loss wrapper for sampled loss functions.
12
13*************************************************************************/
14
15namespace bbm {
16
17
18 /*********************************************************************/
19 /* \brief batch; a specialized sampled loss function
20
21 Computes the loss over a sampled loss for a given batch size of
22 randomly picked samples.
23
24 Satisfies: concepts::sampledlossfunction
25 **********************************************************************/
26 template<typename SAMPLEDLOSSFUNC> requires concepts::sampledlossfunction<SAMPLEDLOSSFUNC>
27 struct batch
28 {
29 BBM_IMPORT_CONFIG( SAMPLEDLOSSFUNC );
30
31 /********************************************************************/
32 /*! \brief Constructor
33
34 Wrapper around sampled loss functions that only evaluates the loss
35 for a subset of the samples. Each update re-randomizes the subset.
36
37 \param batchsize = size of the batch (same batch size for all packets)
38 \param sampedlosfunc = any sampled loss function (to be subsampled).
39 *********************************************************************/
40 inline batch(size_t batchsize, const SAMPLEDLOSSFUNC& sampledlossfunc, seed_t seed=default_seed) : _sampledlossfunc(sampledlossfunc), _index(batchsize), _rng(seed, 0, sampledlossfunc.samples())
41 {
42 update();
43 }
44
45 /*******************************************************************/
46 /*! \brief Init the list of random indices.
47 *******************************************************************/
48 inline void update(void)
49 {
50 _sampledlossfunc.update();
51
52 for(size_t i=0; i < _index.size(); ++i)
53 _index[i] = _rng();
54 }
55
56 /*******************************************************************/
57 /*! \brief Returns the number of samples
58 *******************************************************************/
59 inline Size_t samples(void) const { return _index.size(); }
60
61 /*******************************************************************/
62 /*! \brief Compute the loss over the I-th sample
63 *******************************************************************/
64 inline Value operator()(Size_t idx, Mask mask=true) const
65 {
66 mask &= (idx < samples());
67 if(bbm::none(mask)) return 0;
68
69 return _sampledlossfunc(_index[idx], mask);
70 }
71
72 /*******************************************************************/
73 /*! \brief Compute loss over all samples
74 *******************************************************************/
75 inline Value operator()(Mask mask=true) const
76 {
77 Value err(0);
78 for(size_t i; i < _index.size(); ++i)
79 err += operator()(i, mask);
80
81 return err / samples();
82 }
83
84
85 private:
86 /////////////////////
87 // Class Attributes
88 /////////////////////
91 SAMPLEDLOSSFUNC _sampledlossfunc;
92 };
93
95
96} // end bbm namespace
97
98#endif /* _BBM_BATCH_H_ */
All BBM methods are defined to operate on a variety of value types and spectrum types....
Definition: vector_util.h:27
sampled loss function concept
Definition: sampledlossfunction.h:27
sampled loss function contract
#define BBM_CHECK_CONCEPT(CONCEPTNAME, CLASSNAME,...)
Check a class for a concept with bbm::concepts::archetypes in the namespace.
Definition: macro.h:35
Definition: aggregatebsdf.h:29
Random generator wrapper around Drjit's PCG32.
Definition: random.h:34
Definition: batch.h:28
Value operator()(Size_t idx, Mask mask=true) const
Compute the loss over the I-th sample.
Definition: batch.h:64
void update(void)
Init the list of random indices.
Definition: batch.h:48
Size_t samples(void) const
Returns the number of samples.
Definition: batch.h:59
bbm::rng< Size_t > _rng
Definition: batch.h:89
SAMPLEDLOSSFUNC _sampledlossfunc
Definition: batch.h:91
batch(size_t batchsize, const SAMPLEDLOSSFUNC &sampledlossfunc, seed_t seed=default_seed)
Constructor.
Definition: batch.h:40
BBM_IMPORT_CONFIG(SAMPLEDLOSSFUNC)
Value operator()(Mask mask=true) const
Compute loss over all samples.
Definition: batch.h:75
bbm::vector< Size_t > _index
Definition: batch.h:90
Definition: sampledlossfunction.h:46
Extensions for the STL vector class.