Point Cloud Library (PCL) 1.12.1
regression_variance_stats_estimator.h
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40#include <pcl/common/common.h>
41#include <pcl/ml/branch_estimator.h>
42#include <pcl/ml/stats_estimator.h>
43
44#include <istream>
45#include <ostream>
46
47namespace pcl {
48
49/** Node for a regression trees which optimizes variance. */
50template <class FeatureType, class LabelType>
52public:
53 /** Constructor. */
54 RegressionVarianceNode() : value(0), variance(0), threshold(0), sub_nodes() {}
55
56 /** Destructor. */
58
59 /** Serializes the node to the specified stream.
60 *
61 * \param[out] stream the destination for the serialization
62 */
63 inline void
64 serialize(std::ostream& stream) const
65 {
66 feature.serialize(stream);
67
68 stream.write(reinterpret_cast<const char*>(&threshold), sizeof(threshold));
69
70 stream.write(reinterpret_cast<const char*>(&value), sizeof(value));
71 stream.write(reinterpret_cast<const char*>(&variance), sizeof(variance));
72
73 const int num_of_sub_nodes = static_cast<int>(sub_nodes.size());
74 stream.write(reinterpret_cast<const char*>(&num_of_sub_nodes),
75 sizeof(num_of_sub_nodes));
76 for (int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
77 sub_nodes[sub_node_index].serialize(stream);
78 }
79 }
80
81 /** Deserializes a node from the specified stream.
82 *
83 * \param[in] stream the source for the deserialization
84 */
85 inline void
86 deserialize(std::istream& stream)
87 {
88 feature.deserialize(stream);
89
90 stream.read(reinterpret_cast<char*>(&threshold), sizeof(threshold));
91
92 stream.read(reinterpret_cast<char*>(&value), sizeof(value));
93 stream.read(reinterpret_cast<char*>(&variance), sizeof(variance));
94
95 int num_of_sub_nodes;
96 stream.read(reinterpret_cast<char*>(&num_of_sub_nodes), sizeof(num_of_sub_nodes));
97 sub_nodes.resize(num_of_sub_nodes);
98
99 if (num_of_sub_nodes > 0) {
100 for (int sub_node_index = 0; sub_node_index < num_of_sub_nodes;
101 ++sub_node_index) {
102 sub_nodes[sub_node_index].deserialize(stream);
103 }
104 }
105 }
106
107public:
108 /** The feature associated with the node. */
109 FeatureType feature;
110
111 /** The threshold applied on the feature response. */
113
114 /** The label value of this node. */
115 LabelType value;
116
117 /** The variance of the labels that ended up at this node during training. */
118 LabelType variance;
119
120 /** The child nodes. */
121 std::vector<RegressionVarianceNode> sub_nodes;
122};
123
124/** Statistics estimator for regression trees which optimizes variance. */
125template <class LabelDataType, class NodeType, class DataSet, class ExampleIndex>
127: public pcl::StatsEstimator<LabelDataType, NodeType, DataSet, ExampleIndex> {
128public:
129 /** Constructor. */
131 : branch_estimator_(branch_estimator)
132 {}
133
134 /** Destructor. */
136
137 /** Returns the number of branches the corresponding tree has. */
138 inline std::size_t
140 {
141 // return 2;
142 return branch_estimator_->getNumOfBranches();
143 }
144
145 /** Returns the label of the specified node.
146 *
147 * \param[in] node the node which label is returned
148 */
149 inline LabelDataType
150 getLabelOfNode(NodeType& node) const
151 {
152 return node.value;
153 }
154
155 /** Computes the information gain obtained by the specified threshold.
156 *
157 * \param[in] data_set the data set corresponding to the supplied result data
158 * \param[in] examples the examples used for extracting the supplied result data
159 * \param[in] label_data the label data corresponding to the specified examples
160 * \param[in] results the results computed using the specified examples
161 * \param[in] flags the flags corresponding to the results
162 * \param[in] threshold the threshold for which the information gain is computed
163 */
164 float
165 computeInformationGain(DataSet& data_set,
166 std::vector<ExampleIndex>& examples,
167 std::vector<LabelDataType>& label_data,
168 std::vector<float>& results,
169 std::vector<unsigned char>& flags,
170 const float threshold) const
171 {
172 const std::size_t num_of_examples = examples.size();
173 const std::size_t num_of_branches = getNumOfBranches();
174
175 // compute variance
176 std::vector<LabelDataType> sums(num_of_branches + 1, 0);
177 std::vector<LabelDataType> sqr_sums(num_of_branches + 1, 0);
178 std::vector<std::size_t> branch_element_count(num_of_branches + 1, 0);
179
180 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
181 branch_element_count[branch_index] = 1;
182 ++branch_element_count[num_of_branches];
183 }
184
185 for (std::size_t example_index = 0; example_index < num_of_examples;
186 ++example_index) {
187 unsigned char branch_index;
188 computeBranchIndex(
189 results[example_index], flags[example_index], threshold, branch_index);
190
191 LabelDataType label = label_data[example_index];
192
193 sums[branch_index] += label;
194 sums[num_of_branches] += label;
195
196 sqr_sums[branch_index] += label * label;
197 sqr_sums[num_of_branches] += label * label;
198
199 ++branch_element_count[branch_index];
200 ++branch_element_count[num_of_branches];
201 }
202
203 std::vector<float> variances(num_of_branches + 1, 0);
204 for (std::size_t branch_index = 0; branch_index < num_of_branches + 1;
205 ++branch_index) {
206 const float mean_sum =
207 static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
208 const float mean_sqr_sum = static_cast<float>(sqr_sums[branch_index]) /
209 branch_element_count[branch_index];
210 variances[branch_index] = mean_sqr_sum - mean_sum * mean_sum;
211 }
212
213 float information_gain = variances[num_of_branches];
214 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
215 // const float weight = static_cast<float>(sums[branchIndex]) /
216 // sums[numOfBranches];
217 const float weight = static_cast<float>(branch_element_count[branch_index]) /
218 static_cast<float>(branch_element_count[num_of_branches]);
219 information_gain -= weight * variances[branch_index];
220 }
221
222 return information_gain;
223 }
224
225 /** Computes the branch indices for all supplied results.
226 *
227 * \param[in] results the results the branch indices will be computed for
228 * \param[in] flags the flags corresponding to the specified results
229 * \param[in] threshold the threshold used to compute the branch indices
230 * \param[out] branch_indices the destination for the computed branch indices
231 */
232 void
233 computeBranchIndices(std::vector<float>& results,
234 std::vector<unsigned char>& flags,
235 const float threshold,
236 std::vector<unsigned char>& branch_indices) const
237 {
238 const std::size_t num_of_results = results.size();
239 const std::size_t num_of_branches = getNumOfBranches();
240
241 branch_indices.resize(num_of_results);
242 for (std::size_t result_index = 0; result_index < num_of_results; ++result_index) {
243 unsigned char branch_index;
244 computeBranchIndex(
245 results[result_index], flags[result_index], threshold, branch_index);
246 branch_indices[result_index] = branch_index;
247 }
248 }
249
250 /** Computes the branch index for the specified result.
251 *
252 * \param[in] result the result the branch index will be computed for
253 * \param[in] flag the flag corresponding to the specified result
254 * \param[in] threshold the threshold used to compute the branch index
255 * \param[out] branch_index the destination for the computed branch index
256 */
257 inline void
258 computeBranchIndex(const float result,
259 const unsigned char flag,
260 const float threshold,
261 unsigned char& branch_index) const
262 {
263 branch_estimator_->computeBranchIndex(result, flag, threshold, branch_index);
264 // branch_index = (result > threshold) ? 1 : 0;
265 }
266
267 /** Computes and sets the statistics for a node.
268 *
269 * \param[in] data_set the data set which is evaluated
270 * \param[in] examples the examples which define which parts of the data set are use
271 * for evaluation
272 * \param[in] label_data the label_data corresponding to the examples
273 * \param[out] node the destination node for the statistics
274 */
275 void
276 computeAndSetNodeStats(DataSet& data_set,
277 std::vector<ExampleIndex>& examples,
278 std::vector<LabelDataType>& label_data,
279 NodeType& node) const
280 {
281 const std::size_t num_of_examples = examples.size();
282
283 LabelDataType sum = 0.0f;
284 LabelDataType sqr_sum = 0.0f;
285 for (std::size_t example_index = 0; example_index < num_of_examples;
286 ++example_index) {
287 const LabelDataType label = label_data[example_index];
288
289 sum += label;
290 sqr_sum += label * label;
291 }
292
293 sum /= num_of_examples;
294 sqr_sum /= num_of_examples;
295
296 const float variance = sqr_sum - sum * sum;
297
298 node.value = sum;
299 node.variance = variance;
300 }
301
302 /** Generates code for branch index computation.
303 *
304 * \param[in] node the node for which code is generated
305 * \param[out] stream the destination for the generated code
306 */
307 void
308 generateCodeForBranchIndexComputation(NodeType& node, std::ostream& stream) const
309 {
310 stream << "ERROR: RegressionVarianceStatsEstimator does not implement "
311 "generateCodeForBranchIndex(...)";
312 }
313
314 /** Generates code for label output.
315 *
316 * \param[in] node the node for which code is generated
317 * \param[out] stream the destination for the generated code
318 */
319 void
320 generateCodeForOutput(NodeType& node, std::ostream& stream) const
321 {
322 stream << "ERROR: RegressionVarianceStatsEstimator does not implement "
323 "generateCodeForBranchIndex(...)";
324 }
325
326private:
327 /// The branch estimator
328 pcl::BranchEstimator* branch_estimator_;
329};
330
331} // namespace pcl
Interface for branch estimators.
Node for regression trees which optimizes variance.
void serialize(std::ostream &stream) const
Serializes the node to the specified stream.
LabelType variance
The variance of the labels that ended up at this node during training.
void deserialize(std::istream &stream)
Deserializes a node from the specified stream.
float threshold
The threshold applied on the feature response.
FeatureType feature
The feature associated with the node.
LabelType value
The label value of this node.
std::vector< RegressionVarianceNode > sub_nodes
The child nodes.
Statistics estimator for regression trees which optimizes variance.
void generateCodeForOutput(NodeType &node, std::ostream &stream) const
Generates code for label output.
void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const
Computes and sets the statistics for a node.
void computeBranchIndex(const float result, const unsigned char flag, const float threshold, unsigned char &branch_index) const
Computes the branch index for the specified result.
LabelDataType getLabelOfNode(NodeType &node) const
Returns the label of the specified node.
void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const
Computes the branch indices for all supplied results.
RegressionVarianceStatsEstimator(BranchEstimator *branch_estimator)
Constructor.
float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const
Computes the information gain obtained by the specified threshold.
std::size_t getNumOfBranches() const
Returns the number of branches the corresponding tree has.
void generateCodeForBranchIndexComputation(NodeType &node, std::ostream &stream) const
Generates code for branch index computation.
Class interface for gathering statistics for decision tree learning.
Define standard C methods and C++ classes that are common to all methods.
#define PCL_EXPORTS
Definition: pcl_macros.h:323