PolyFEM
Loading...
Searching...
No Matches
ALSolver.cpp
Go to the documentation of this file.
1#include "ALSolver.hpp"
2
4
5namespace polyfem::solver
6{
8 const std::vector<std::shared_ptr<AugmentedLagrangianForm>> &alagr_form,
9 const double initial_al_weight,
10 const double scaling,
11 const double max_al_weight,
12 const double eta_tol,
13 const std::function<void(const Eigen::VectorXd &)> &update_barrier_stiffness)
14 : alagr_forms{alagr_form},
15 initial_al_weight(initial_al_weight),
16 scaling(scaling),
17 max_al_weight(max_al_weight),
18 eta_tol(eta_tol),
19 update_barrier_stiffness(update_barrier_stiffness)
20 {
21 }
22
23 void ALSolver::solve_al(std::shared_ptr<NLSolver> nl_solver, NLProblem &nl_problem, Eigen::MatrixXd &sol)
24 {
25 assert(sol.size() == nl_problem.full_size());
26
27 const Eigen::VectorXd initial_sol = sol;
28 Eigen::VectorXd tmp_sol = nl_problem.full_to_reduced(sol);
29 assert(tmp_sol.size() == nl_problem.reduced_size());
30
31 // --------------------------------------------------------------------
32
33 double al_weight = initial_al_weight;
34 int al_steps = 0;
35 const int iters = nl_solver->stop_criteria().iterations;
36
37 double initial_error = 0;
38 for (const auto &f : alagr_forms)
39 initial_error += f->compute_error(sol);
40
41 nl_problem.line_search_begin(sol, tmp_sol);
42
43 for (auto &f : alagr_forms)
44 f->set_initial_weight(al_weight);
45
46 while (!std::isfinite(nl_problem.value(tmp_sol))
47 || !nl_problem.is_step_valid(sol, tmp_sol)
48 || !nl_problem.is_step_collision_free(sol, tmp_sol))
49 {
50 nl_problem.line_search_end();
51
52 set_al_weight(nl_problem, sol, al_weight);
53 logger().debug("Solving AL Problem with weight {}", al_weight);
54
55 nl_problem.init(sol);
57 tmp_sol = sol;
58
59 try
60 {
61 nl_solver->minimize(nl_problem, tmp_sol);
62 nl_problem.finish();
63 }
64 catch (const std::runtime_error &e)
65 {
66 std::string err_msg = e.what();
67 // if the nonlinear solve fails due to invalid energy at the current solution, changing the weights would not help
68 if (err_msg.find("f(x) is nan or inf; stopping") != std::string::npos)
69 log_and_throw_error("Failed to solve with AL; f(x) is nan or inf");
70 }
71
72 sol = tmp_sol;
73 set_al_weight(nl_problem, sol, -1);
74
75 double current_error = 0;
76 for (const auto &f : alagr_forms)
77 f->compute_error(sol);
78 const double eta = 1 - sqrt(current_error / initial_error);
79
80 logger().debug("Current eta = {}", eta);
81
82 if (eta < 0)
83 {
84 logger().debug("Higher error than initial, increase weight and revert to previous solution");
85 sol = initial_sol;
86 }
87
88 tmp_sol = nl_problem.full_to_reduced(sol);
89 nl_problem.line_search_begin(sol, tmp_sol);
90
91 if (eta < eta_tol && al_weight < max_al_weight)
92 al_weight *= scaling;
93 else
94 {
95 for (auto &f : alagr_forms)
96 f->update_lagrangian(sol, al_weight);
97 }
98
99 post_subsolve(al_weight);
100 ++al_steps;
101 }
102 nl_problem.line_search_end();
103 nl_solver->stop_criteria().iterations = iters;
104 }
105
106 void ALSolver::solve_reduced(std::shared_ptr<NLSolver> nl_solver, NLProblem &nl_problem, Eigen::MatrixXd &sol)
107 {
108 assert(sol.size() == nl_problem.full_size());
109
110 Eigen::VectorXd tmp_sol = nl_problem.full_to_reduced(sol);
111 nl_problem.line_search_begin(sol, tmp_sol);
112
113 if (!std::isfinite(nl_problem.value(tmp_sol))
114 || !nl_problem.is_step_valid(sol, tmp_sol)
115 || !nl_problem.is_step_collision_free(sol, tmp_sol))
116 log_and_throw_error("Failed to apply constraints conditions; solve with augmented lagrangian first!");
117
118 // --------------------------------------------------------------------
119 // Perform one final solve with the DBC projected out
120
121 logger().debug("Successfully applied constraints conditions; solving in reduced space");
122
123 nl_problem.init(sol);
125 try
126 {
127 nl_solver->minimize(nl_problem, tmp_sol);
128 nl_problem.finish();
129 }
130 catch (const std::runtime_error &e)
131 {
132 sol = nl_problem.reduced_to_full(tmp_sol);
133 throw e;
134 }
135 sol = nl_problem.reduced_to_full(tmp_sol);
136
137 post_subsolve(0);
138 }
139
140 void ALSolver::set_al_weight(NLProblem &nl_problem, const Eigen::VectorXd &x, const double weight)
141 {
142 if (alagr_forms.empty())
143 return;
144 if (weight > 0)
145 {
146 for (auto &f : alagr_forms)
147 f->enable();
148
149 nl_problem.use_full_size();
150 }
151 else
152 {
153 for (auto &f : alagr_forms)
154 f->disable();
155
156 nl_problem.use_reduced_size();
157 }
158 }
159
160} // namespace polyfem::solver
int x
const double max_al_weight
Definition ALSolver.hpp:40
std::vector< std::shared_ptr< AugmentedLagrangianForm > > alagr_forms
Definition ALSolver.hpp:37
void set_al_weight(NLProblem &nl_problem, const Eigen::VectorXd &x, const double weight)
Definition ALSolver.cpp:140
std::function< void(const Eigen::VectorXd &)> update_barrier_stiffness
Definition ALSolver.hpp:44
ALSolver(const std::vector< std::shared_ptr< AugmentedLagrangianForm > > &alagr_form, const double initial_al_weight, const double scaling, const double max_al_weight, const double eta_tol, const std::function< void(const Eigen::VectorXd &)> &update_barrier_stiffness)
Definition ALSolver.cpp:7
std::function< void(const double)> post_subsolve
Definition ALSolver.hpp:32
const double initial_al_weight
Definition ALSolver.hpp:38
void solve_reduced(std::shared_ptr< NLSolver > nl_solver, NLProblem &nl_problem, Eigen::MatrixXd &sol)
Definition ALSolver.cpp:106
void solve_al(std::shared_ptr< NLSolver > nl_solver, NLProblem &nl_problem, Eigen::MatrixXd &sol)
Definition ALSolver.cpp:23
virtual void line_search_end() override
virtual void init(const TVector &x0) override
void line_search_begin(const TVector &x0, const TVector &x1) override
Definition NLProblem.cpp:88
virtual bool is_step_valid(const TVector &x0, const TVector &x1) override
Definition NLProblem.cpp:98
virtual TVector full_to_reduced(const TVector &full) const
virtual bool is_step_collision_free(const TVector &x0, const TVector &x1) override
virtual TVector reduced_to_full(const TVector &reduced) const
virtual double value(const TVector &x) override
spdlog::logger & logger()
Retrieves the current logger.
Definition Logger.cpp:42
void log_and_throw_error(const std::string &msg)
Definition Logger.cpp:71