6from sympy.printing
import ccode
11x, y, z = symbols(
'x,y,z')
18 coords = symbols(
'x,y,z')[:nsd]
20 coords = [Symbol(
"x_%d" % d)
for d
in range(nsd)]
32 for d
in range(0, nsd):
41 raise RuntimeError(
"Bernstein only implemented in 1D, 2D, and 3D")
47 b1, b2, b3 = x, y, 1 - x - y
48 for o1
in range(0, order + 1):
49 for o2
in range(0, order + 1):
50 for o3
in range(0, order + 1):
51 if o1 + o2 + o3 == order:
52 aij = Symbol(
"a_%d_%d_%d" % (o1, o2, o3))
53 fac = factorial(order) / (factorial(o1) *
54 factorial(o2) * factorial(o3))
55 sum += aij * fac * pow(b1, o1) * \
56 pow(b2, o2) * pow(b3, o3)
57 basis.append(fac * pow(b1, o1) *
58 pow(b2, o2) * pow(b3, o3))
62 b1, b2, b3, b4 = x, y, z, 1 - x - y - z
63 for o1
in range(0, order + 1):
64 for o2
in range(0, order + 1):
65 for o3
in range(0, order + 1):
66 for o4
in range(0, order + 1):
67 if o1 + o2 + o3 + o4 == order:
68 aij = Symbol(
"a_%d_%d_%d_%d" % (o1, o2, o3, o4))
70 order) / (factorial(o1) * factorial(o2) * factorial(o3) * factorial(o4))
72 pow(b1, o1) * pow(b2, o2) * \
73 pow(b3, o3) * pow(b4, o4)
74 basis.append(fac * pow(b1, o1) * pow(b2, o2) *
75 pow(b3, o3) * pow(b4, o4))
78 return sum, coeff, basis
82 h = Rational(1, order)
86 for i
in range(0, order + 1):
88 for j
in range(0, order + 1):
94 for i
in range(0, order + 1):
96 for j
in range(0, order + 1):
98 for k
in range(0, order + 1):
101 set.append((x, y, z))
107 A = zeros(len(equations))
110 for j
in range(0, len(coeffs)):
112 for i
in range(0, len(equations)):
114 d, _ = reduced(e, [c])
139 ex = pol.subs(x, p[0])
141 ex = ex.subs(y, p[1])
143 ex = ex.subs(z, p[2])
146 b = eye(len(equations))
158 for i
in range(0, len(equations)):
160 for j
in range(0, len(coeffs)):
161 Ni = Ni.subs(coeffs[j], xx[j, i])
168 parser = argparse.ArgumentParser(
170 formatter_class=argparse.RawDescriptionHelpFormatter)
171 parser.add_argument(
"output", type=str, help=
"path to the output folder")
172 parser.add_argument(
"--bernstein", default=
False, action=
'store_true',
173 help=
"use Bernstein basis instead of Lagrange basis")
174 return parser.parse_args()
177if __name__ ==
"__main__":
182 orders = [0, 1, 2, 3, 4]
185 bletter =
"b" if args.bernstein
else "p"
187 cpp = f
"#include \"auto_{bletter}_bases.hpp\""
188 if not args.bernstein:
189 cpp = cpp +
"\n#include \"auto_b_bases.hpp\""
190 cpp = cpp +
"\n#include \"p_n_bases.hpp\""
191 cpp = cpp +
"\n\n\n" \
192 "namespace polyfem {\nnamespace autogen " +
"{\nnamespace " +
"{\n"
194 hpp =
"#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n"
196 hpp = hpp +
"\nnamespace polyfem {\nnamespace autogen " +
"{\n"
199 print(str(dim) +
"D " + bletter)
200 suffix =
"2d" if dim == 2
else "3d"
202 unique_nodes = f
"void {bletter}_nodes_{suffix}" + \
203 f
"(const int {bletter}, Eigen::MatrixXd &val)"
206 unique_fun = f
"void {bletter}_basis_value_{suffix}" + \
207 f
"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
208 dunique_fun = f
"void {bletter}_grad_basis_value_{suffix}" + \
209 f
"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
211 unique_fun = f
"void {bletter}_basis_value_{suffix}" + \
212 f
"(const bool bernstein, const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
213 dunique_fun = f
"void {bletter}_grad_basis_value_{suffix}" + \
214 f
"(const bool bernstein, const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
216 if not args.bernstein:
217 hpp = hpp + unique_nodes +
";\n\n"
219 hpp = hpp + unique_fun +
";\n\n"
220 hpp = hpp + dunique_fun +
";\n\n"
222 unique_nodes = unique_nodes + f
"{{\nswitch({bletter})" +
"{\n"
224 unique_fun = unique_fun +
"{\n"
225 dunique_fun = dunique_fun +
"{\n"
227 if not args.bernstein:
228 unique_fun = unique_fun + \
229 f
"if(bernstein) {{ b_basis_value_{suffix}(p, local_index, uv, val); return; }}\n\n"
230 dunique_fun = dunique_fun + \
231 f
"if(bernstein) {{ b_grad_basis_value_{suffix}(p, local_index, uv, val); return; }}\n\n"
233 unique_fun = unique_fun + f
"\nswitch({bletter})" +
"{\n"
234 dunique_fun = dunique_fun + f
"\nswitch({bletter})" +
"{\n"
237 vertices = [[0, 0], [1, 0], [0, 1]]
239 vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
242 print(
"\t-processing " + str(order))
245 def fe():
return None
251 fe.points = [[1./3., 1./3.]]
253 fe.points = [[1./3., 1./3., 1./3.]]
257 current_indices = list(range(0, len(fe.points)))
261 for i
in range(0, dim + 1):
263 for ii
in current_indices:
265 for dd
in range(0, dim):
266 norm = norm + (vv[dd] - fe.points[ii][dd]) ** 2
270 current_indices.remove(ii)
274 for i
in range(0, order - 1):
275 for ii
in current_indices:
276 if fe.points[ii][1] != 0
or (dim == 3
and fe.points[ii][2] != 0):
279 if abs(fe.points[ii][0] - (i + 1) / order) < 1e-10:
281 current_indices.remove(ii)
285 for i
in range(0, order - 1):
286 for ii
in current_indices:
287 if fe.points[ii][0] + fe.points[ii][1] != 1
or (dim == 3
and fe.points[ii][2] != 0):
290 if abs(fe.points[ii][1] - (i + 1) / order) < 1e-10:
292 current_indices.remove(ii)
296 for i
in range(0, order - 1):
297 for ii
in current_indices:
298 if fe.points[ii][0] != 0
or (dim == 3
and fe.points[ii][2] != 0):
301 if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
303 current_indices.remove(ii)
308 for i
in range(0, order - 1):
309 for ii
in current_indices:
310 if fe.points[ii][0] != 0
or fe.points[ii][1] != 0:
313 if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
315 current_indices.remove(ii)
319 for i
in range(0, order - 1):
320 for ii
in current_indices:
321 if fe.points[ii][0] + fe.points[ii][2] != 1
or fe.points[ii][1] != 0:
324 if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
326 current_indices.remove(ii)
330 for i
in range(0, order - 1):
331 for ii
in current_indices:
332 if fe.points[ii][1] + fe.points[ii][2] != 1
or fe.points[ii][0] != 0:
335 if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
337 current_indices.remove(ii)
341 nn = max(0, order - 2)
342 npts = int(nn * (nn + 1) / 2)
345 for i
in range(0, npts):
346 for ii
in current_indices:
347 if abs(fe.points[ii][2]) > 1e-10:
351 current_indices.remove(ii)
355 for i
in range(0, npts):
356 for ii
in current_indices:
357 if abs(fe.points[ii][1]) > 1e-10:
361 current_indices.remove(ii)
366 for i
in range(0, npts):
367 for ii
in current_indices:
368 if (abs(fe.points[ii][0]) < 1e-10) | (abs(fe.points[ii][1]) < 1e-10) | (abs(fe.points[ii][2]) < 1e-10):
371 if abs((fe.points[ii][0] + fe.points[ii][1] + fe.points[ii][2]) - 1) > 1e-10:
375 current_indices.remove(ii)
377 for i
in range(0, len(tmp)):
378 indices.append(tmp[(i + 2) % len(tmp)])
382 for i
in range(0, npts):
383 for ii
in current_indices:
384 if abs(fe.points[ii][0]) > 1e-10:
388 current_indices.remove(ii)
390 tmp.sort(reverse=
True)
394 for ii
in current_indices:
398 nodes = f
"void {bletter}_{order}_nodes_{suffix}(Eigen::MatrixXd &res) {{\n res.resize(" + str(
399 len(indices)) +
", " + str(dim) +
"); res << \n"
400 unique_nodes = unique_nodes + f
"\tcase {order}: " + \
401 f
"{bletter}_{order}_nodes_{suffix}(val); break;\n"
404 nodes = nodes + ccode(fe.points[ii][0]) +
", " + ccode(fe.points[ii][1]) + (
405 (
", " + ccode(fe.points[ii][2]))
if dim == 3
else "") +
",\n"
407 nodes = nodes +
";\n}"
410 func = f
"void {bletter}_{order}_basis_value_{suffix}" + \
411 "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &result_0)"
412 dfunc = f
"void {bletter}_{order}_basis_grad_value_{suffix}" + \
413 "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
415 unique_fun = unique_fun + \
416 f
"\tcase {order}: {bletter}_{order}_basis_value_{suffix}(local_index, uv, val); break;\n"
417 dunique_fun = dunique_fun + \
418 f
"\tcase {order}: {bletter}_{order}_basis_grad_value_{suffix}(local_index, uv, val); break;\n"
423 if not args.bernstein:
424 default_base =
"p_n_basis_value_3d(p, local_index, uv, val);" if dim == 3
else "p_n_basis_value_2d(p, local_index, uv, val);"
425 default_dbase =
"p_n_basis_grad_value_3d(p, local_index, uv, val);" if dim == 3
else "p_n_basis_grad_value_2d(p, local_index, uv, val);"
426 default_nodes =
"p_n_nodes_3d(p, val);" if dim == 3
else "p_n_nodes_2d(p, val);"
428 base =
"auto x=uv.col(0).array();\nauto y=uv.col(1).array();"
430 base = base +
"\nauto z=uv.col(2).array();"
435 base = base +
"result_0.resize(x.size(),1);\n"
437 base = base +
"switch(local_index){\n"
439 "val.resize(uv.rows(), uv.cols());\n Eigen::ArrayXd result_0(uv.rows());\n" + \
440 "switch(local_index){\n"
442 for i
in range(0, fe.nbf()):
443 real_index = indices[i]
447 simplify(fe.N[real_index])).replace(
" = 1;",
".setOnes();") +
"} break;\n"
448 dbase = dbase +
"\tcase " + str(i) +
": {" + \
449 "{" +
pretty_print.C99_print(simplify(diff(fe.N[real_index], x))).replace(
" = 0;",
".setZero();").replace(
" = 1;",
".setOnes();").replace(
" = -1;",
".setConstant(-1);") +
"val.col(0) = result_0; }" \
450 "{" +
pretty_print.C99_print(simplify(diff(fe.N[real_index], y))).replace(
" = 0;",
".setZero();").replace(
451 " = 1;",
".setOnes();").replace(
" = -1;",
".setConstant(-1);") +
"val.col(1) = result_0; }"
454 dbase = dbase +
"{" +
pretty_print.C99_print(simplify(diff(fe.N[real_index], z))).replace(
" = 0;",
".setZero();").replace(
455 " = 1;",
".setOnes();").replace(
" = -1;",
".setConstant(-1);") +
"val.col(2) = result_0; }"
457 dbase = dbase +
"} break;\n"
459 base = base +
"\tdefault: assert(false);\n}"
460 dbase = dbase +
"\tdefault: assert(false);\n}"
462 cpp = cpp + func +
"{\n\n"
463 cpp = cpp + base +
"}\n"
465 cpp = cpp + dfunc +
"{\n\n"
466 cpp = cpp + dbase +
"}\n\n\n"
468 if not args.bernstein:
469 cpp = cpp + nodes +
"\n\n\n"
474 unique_nodes = unique_nodes +
"\tdefault: "+default_nodes+
"\n}}"
477 unique_fun = unique_fun +
"\tdefault: assert(false); \n}}"
478 dunique_fun = dunique_fun +
"\tdefault: assert(false); \n}}"
480 unique_fun = unique_fun +
"\tdefault: "+default_base+
"\n}}"
481 dunique_fun = dunique_fun +
"\tdefault: "+default_dbase+
"\n}}"
483 cpp = cpp +
"}\n\n" + unique_nodes +
"\n" + unique_fun + \
484 "\n\n" + dunique_fun +
"\n" +
"\nnamespace " +
"{\n"
488 f
"\nstatic const int MAX_{bletter.capitalize()}_BASES = {max(orders)};\n"
490 cpp = cpp +
"\n}}}\n"
493 path = os.path.abspath(args.output)
496 with open(os.path.join(path, f
"auto_{bletter}_bases.cpp"),
"w")
as file:
499 with open(os.path.join(path, f
"auto_{bletter}_bases.hpp"),
"w")
as file:
__init__(self, nsd, order, bernstein)
create_matrix(equations, coeffs)
create_point_set(order, nsd)
bernstein_space(order, nsd)