adc986a0033cef38d6de3a64022853b9397ab52b
[blender.git] / intern / libmv / libmv / autotrack / predict_tracks.cc
1 // Copyright (c) 2014 libmv authors.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to
5 // deal in the Software without restriction, including without limitation the
6 // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
7 // sell copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
19 // IN THE SOFTWARE.
20 //
21 // Author: mierle@gmail.com (Keir Mierle)
22
23 #include "libmv/autotrack/marker.h"
24 #include "libmv/autotrack/predict_tracks.h"
25 #include "libmv/autotrack/tracks.h"
26 #include "libmv/base/vector.h"
27 #include "libmv/logging/logging.h"
28 #include "libmv/tracking/kalman_filter.h"
29
30 namespace mv {
31
32 namespace {
33
34 using libmv::vector;
35 using libmv::Vec2;
36
37 // Implied time delta between steps. Set empirically by tweaking and seeing
38 // what numbers did best at prediction.
39 const double dt = 3.8;
40
41 // State transition matrix.
42
43 // The states for predicting a track are as follows:
44 //
45 //   0 - X position
46 //   1 - X velocity
47 //   2 - X acceleration
48 //   3 - Y position
49 //   4 - Y velocity
50 //   5 - Y acceleration
51 //
52 // Note that in the velocity-only state transition matrix, the acceleration
53 // component is ignored; so technically the system could be modelled with only
54 // 4 states instead of 6. For ease of implementation, this keeps order 6.
55
56 // Choose one or the other model from below (velocity or acceleration).
57
58 // For a typical system having constant velocity. This gives smooth-appearing
59 // predictions, but they are not always as accurate.
60 const double velocity_state_transition_data[] = {
61   1, dt,       0,  0,  0,        0,
62   0,  1,       0,  0,  0,        0,
63   0,  0,       1,  0,  0,        0,
64   0,  0,       0,  1, dt,        0,
65   0,  0,       0,  0,  1,        0,
66   0,  0,       0,  0,  0,        1
67 };
68
69 // This 3rd-order system also models acceleration. This makes for "jerky"
70 // predictions, but that tend to be more accurate.
71 const double acceleration_state_transition_data[] = {
72   1, dt, dt*dt/2,  0,  0,        0,
73   0,  1,      dt,  0,  0,        0,
74   0,  0,       1,  0,  0,        0,
75   0,  0,       0,  1, dt,  dt*dt/2,
76   0,  0,       0,  0,  1,       dt,
77   0,  0,       0,  0,  0,        1
78 };
79
80 // This system (attempts) to add an angular velocity component. However, it's
81 // total junk.
82 const double angular_state_transition_data[] = {
83   1, dt,     -dt,  0,  0,        0,   // Position x
84   0,  1,       0,  0,  0,        0,   // Velocity x
85   0,  0,       1,  0,  0,        0,   // Angular momentum
86   0,  0,      dt,  1, dt,        0,   // Position y
87   0,  0,       0,  0,  1,        0,   // Velocity y
88   0,  0,       0,  0,  0,        1    // Ignored
89 };
90
91 const double* state_transition_data = velocity_state_transition_data;
92
93 // Observation matrix.
94 const double observation_data[] = {
95   1., 0., 0., 0., 0., 0.,
96   0., 0., 0., 1., 0., 0.
97 };
98
99 // Process covariance.
100 const double process_covariance_data[] = {
101  35,  0,  0,  0,  0,  0,
102   0,  5,  0,  0,  0,  0,
103   0,  0,  5,  0,  0,  0,
104   0,  0,  0, 35,  0,  0,
105   0,  0,  0,  0,  5,  0,
106   0,  0,  0,  0,  0,  5
107 };
108
109 // Process covariance.
110 const double measurement_covariance_data[] = {
111   0.01,  0.00,
112   0.00,  0.01,
113 };
114
115 // Initial covariance.
116 const double initial_covariance_data[] = {
117  10,  0,  0,  0,  0,  0,
118   0,  1,  0,  0,  0,  0,
119   0,  0,  1,  0,  0,  0,
120   0,  0,  0, 10,  0,  0,
121   0,  0,  0,  0,  1,  0,
122   0,  0,  0,  0,  0,  1
123 };
124
125 typedef mv::KalmanFilter<double, 6, 2> TrackerKalman;
126
127 TrackerKalman filter(state_transition_data,
128                      observation_data,
129                      process_covariance_data,
130                      measurement_covariance_data);
131
132 bool OrderByFrameLessThan(const Marker* a, const Marker* b) {
133   if (a->frame == b->frame) {
134     if (a->clip == b->clip) {
135       return a->track < b->track;
136     }
137     return a->clip < b->clip;
138   }
139   return a->frame < b-> frame;
140 }
141
142 // Predicted must be after the previous markers (in the frame numbering sense).
143 void RunPrediction(const vector<Marker*> previous_markers,
144                    Marker* predicted_marker) {
145   TrackerKalman::State state;
146   state.mean << previous_markers[0]->center.x(), 0, 0,
147                 previous_markers[0]->center.y(), 0, 0;
148   state.covariance = Eigen::Matrix<double, 6, 6, Eigen::RowMajor>(
149       initial_covariance_data);
150
151   int current_frame = previous_markers[0]->frame;
152   int target_frame = predicted_marker->frame;
153
154   bool predict_forward = current_frame < target_frame;
155   int frame_delta = predict_forward ? 1 : -1;
156
157   for (int i = 1; i < previous_markers.size(); ++i) {
158     // Step forward predicting the state until it is on the current marker.
159     int predictions = 0;
160     for (;
161          current_frame != previous_markers[i]->frame;
162          current_frame += frame_delta) {
163       filter.Step(&state);
164       predictions++;
165       LG << "Predicted point (frame " << current_frame << "): "
166          << state.mean(0) << ", " << state.mean(3);
167     }
168     // Log the error -- not actually used, but interesting.
169     Vec2 error = previous_markers[i]->center.cast<double>() -
170                  Vec2(state.mean(0), state.mean(3));
171     LG << "Prediction error for " << predictions << " steps: ("
172        << error.x() << ", " << error.y() << "); norm: " << error.norm();
173     // Now that the state is predicted in the current frame, update the state
174     // based on the measurement from the current frame.
175     filter.Update(previous_markers[i]->center.cast<double>(),
176                   Eigen::Matrix<double, 2, 2, Eigen::RowMajor>(
177                       measurement_covariance_data),
178                   &state);
179     LG << "Updated point: " << state.mean(0) << ", " << state.mean(3);
180   }
181   // At this point as all the prediction that's possible is done. Finally
182   // predict until the target frame.
183   for (; current_frame != target_frame; current_frame += frame_delta) {
184     filter.Step(&state);
185     LG << "Final predicted point (frame " << current_frame << "): "
186        << state.mean(0) << ", " << state.mean(3);
187   }
188
189   // The x and y positions are at 0 and 3; ignore acceleration and velocity.
190   predicted_marker->center.x() = state.mean(0);
191   predicted_marker->center.y() = state.mean(3);
192
193   // Take the patch from the last marker then shift it to match the prediction.
194   const Marker& last_marker = *previous_markers[previous_markers.size() - 1];
195   predicted_marker->patch = last_marker.patch;
196   Vec2f delta = predicted_marker->center - last_marker.center;
197   for (int i = 0; i < 4; ++i) {
198     predicted_marker->patch.coordinates.row(i) += delta;
199   }
200
201   // Alter the search area as well so it always corresponds to the center.
202   predicted_marker->search_region = last_marker.search_region;
203   predicted_marker->search_region.Offset(delta);
204 }
205
206 }  // namespace
207
208 bool PredictMarkerPosition(const Tracks& tracks, Marker* marker) {
209   // Get all markers for this clip and track.
210   vector<Marker> markers;
211   tracks.GetMarkersForTrackInClip(marker->clip, marker->track, &markers);
212
213   if (markers.empty()) {
214     LG << "No markers to predict from for " << *marker;
215     return false;
216   }
217
218   // Order the markers by frame within the clip.
219   vector<Marker*> boxed_markers(markers.size());
220   for (int i = 0; i < markers.size(); ++i) {
221     boxed_markers[i] = &markers[i];
222   }
223   std::sort(boxed_markers.begin(), boxed_markers.end(), OrderByFrameLessThan);
224
225   // Find the insertion point for this marker among the returned ones.
226   int insert_at = -1;      // If we find the exact frame
227   int insert_before = -1;  // Otherwise...
228   for (int i = 0; i < boxed_markers.size(); ++i) {
229     if (boxed_markers[i]->frame == marker->frame) {
230       insert_at = i;
231       break;
232     }
233     if (boxed_markers[i]->frame > marker->frame) {
234       insert_before = i;
235       break;
236     }
237   }
238
239   // Forward starts at the marker or insertion point, and goes forward.
240   int forward_scan_begin, forward_scan_end;
241
242   // Backward scan starts at the marker or insertion point, and goes backward.
243   int backward_scan_begin, backward_scan_end;
244
245   // Determine the scanning ranges.
246   if (insert_at == -1 && insert_before == -1) {
247     // Didn't find an insertion point except the end.
248     forward_scan_begin = forward_scan_end = 0;
249     backward_scan_begin = markers.size() - 1;
250     backward_scan_end = 0;
251   } else if (insert_at != -1) {
252     // Found existing marker; scan before and after it.
253     forward_scan_begin = insert_at + 1;
254     forward_scan_end = markers.size() - 1;;
255     backward_scan_begin = insert_at - 1;
256     backward_scan_end = 0;
257   } else {
258     // Didn't find existing marker but found an insertion point.
259     forward_scan_begin = insert_before;
260     forward_scan_end = markers.size() - 1;;
261     backward_scan_begin = insert_before - 1;
262     backward_scan_end = 0;
263   }
264
265   const int num_consecutive_needed = 2;
266
267   if (forward_scan_begin <= forward_scan_end &&
268       forward_scan_end - forward_scan_begin > num_consecutive_needed) {
269     // TODO(keir): Finish this.
270   }
271
272   bool predict_forward = false;
273   if (backward_scan_end <= backward_scan_begin) {
274     // TODO(keir): Add smarter handling and detecting of consecutive frames!
275     predict_forward = true;
276   }
277
278   const int max_frames_to_predict_from = 20;
279   if (predict_forward) {
280     if (backward_scan_begin - backward_scan_end < num_consecutive_needed) {
281       // Not enough information to do a prediction.
282       LG << "Predicting forward impossible, not enough information";
283       return false;
284     }
285     LG << "Predicting forward";
286     int predict_begin =
287         std::max(backward_scan_begin - max_frames_to_predict_from, 0);
288     int predict_end = backward_scan_begin;
289     vector<Marker*> previous_markers;
290     for (int i = predict_begin; i <= predict_end; ++i) {
291       previous_markers.push_back(boxed_markers[i]);
292     }
293     RunPrediction(previous_markers, marker);
294     return true;
295   } else {
296     if (forward_scan_end - forward_scan_begin < num_consecutive_needed) {
297       // Not enough information to do a prediction.
298       LG << "Predicting backward impossible, not enough information";
299       return false;
300     }
301     LG << "Predicting backward";
302     int predict_begin =
303         std::min(forward_scan_begin + max_frames_to_predict_from,
304                  forward_scan_end);
305     int predict_end = forward_scan_begin;
306     vector<Marker*> previous_markers;
307     for (int i = predict_begin; i >= predict_end; --i) {
308       previous_markers.push_back(boxed_markers[i]);
309     }
310     RunPrediction(previous_markers, marker);
311     return false;
312   }
313
314 }
315
316 }  // namespace mv