6from sympy.printing
import ccode
10x, y, z = symbols(
'x,y,z')
13 return [a[i]
for i
in b]
20 b1, b2, b3 = x / (1 - z), y / (1 - z), 1 - z
21 for k
in range(order + 1):
22 for i
in range(k + 1):
23 for j
in range(k + 1):
24 aijk = Symbol(
"a_%d_%d_%d" % (i, j, k))
25 sum += aijk * b1**i * b2**j * b3**(k)
26 basis.append(b1**i * b2**j * b3**(k))
28 return sum, coeff, basis
31 h = Rational(1, order)
34 for i
in range(order + 1):
36 for j
in range(order + 1):
43 for i
in range(1, order):
45 for base_v
in [(0, 0), (1, 0), (1, 1), (0, 1)]:
46 set.append((base_v[0] * (1 - z), base_v[1] * (1 - z), z))
49 for face
in [[(0, 0, 1), (0, 0, 0), (1, 0, 0)], [(0, 0, 1), (1, 0, 0), (1, 1, 0)], [(0, 0, 1), (1, 1, 0), (0, 1, 0)], [(0, 0, 1), (0, 1, 0), (0, 0, 0)]]:
50 f_a, f_b, f_c = face[0], face[1], face[2]
51 for i
in range(1, order):
53 for j
in range(1, order):
55 gamma = 1 - alpha - beta
56 if alpha > 0
and beta > 0
and gamma > 0:
57 x = alpha * f_a[0] + beta * f_b[0] + gamma * f_c[0]
58 y = alpha * f_a[1] + beta * f_b[1] + gamma * f_c[1]
59 z = alpha * f_a[2] + beta * f_b[2] + gamma * f_c[2]
63 h_i = Rational(1, order - 1)
64 for k
in range(1, order-1):
67 set.append((0.5 * (1 - z), 0.5 * (1 - z), z))
69 h_k = Rational(1, k + 1)
70 for i
in range(1, k + 1):
72 for j
in range(1, k + 1):
# Sanity check on the number of generated points. `set` is the local list the
# points were appended to above (it shadows the builtin of the same name).
# The failure message now reports the same expected count that the condition
# checks; it previously printed (order + 1) * (order + 2) * (order + 3) // 6,
# which disagrees with the condition and made a failing run misleading.
# NOTE(review): the formula equals 1^2 + 2^2 + ... + (order + 1)^2, presumably
# the size of the point lattice built above — confirm against the generator.
assert len(set) == (order + 1) * (order + 2) * (2 * order + 3) // 6, \
    f"Expected {(order + 1) * (order + 2) * (2 * order + 3) // 6} points, but got {len(set)}"
81 A = zeros(len(equations))
84 for j
in range(0, len(coeffs)):
86 for i
in range(0, len(equations)):
88 d, _ = reduced(e, [c])
109 ex = sum.subs(x, p[0])
110 ex = ex.subs(y, p[1])
111 ex = ex.subs(z, p[2])
114 b = eye(len(equations))
119 for i
in range(0, len(equations)):
121 for j
in range(0, len(coeff)):
122 Ni = Ni.subs(coeff[j], xx[j, i])
129 parser = argparse.ArgumentParser(
131 formatter_class=argparse.RawDescriptionHelpFormatter)
132 parser.add_argument(
"output", type=str, help=
"path to the output folder")
133 return parser.parse_args()
136if __name__ ==
"__main__":
141 orders = [0, 1, 2, 3, 4]
145 cpp = f
"#include \"auto_{bletter}_bases.hpp\""
146 cpp = cpp +
"\n#include \"auto_b_bases.hpp\""
147 cpp = cpp +
"\n#include \"p_n_bases.hpp\""
148 cpp = cpp +
"\n\n\n" \
149 "namespace polyfem {\nnamespace autogen " +
"{\nnamespace " +
"{\n"
151 hpp =
"#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n"
153 hpp = hpp +
"\nnamespace polyfem {\nnamespace autogen " +
"{\n"
156 assert dim == 3,
"Only 3D pyramid is supported"
157 print(str(dim) +
"D " + bletter)
160 unique_nodes = f
"void {bletter}_nodes_{suffix}" + \
161 f
"(const int {bletter}, Eigen::MatrixXd &val)"
163 unique_fun = f
"void {bletter}_basis_value_{suffix}" + \
164 f
"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
165 dunique_fun = f
"void {bletter}_grad_basis_value_{suffix}" + \
166 f
"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
168 hpp = hpp + unique_nodes +
";\n\n"
170 hpp = hpp + unique_fun +
";\n\n"
171 hpp = hpp + dunique_fun +
";\n\n"
173 unique_nodes = unique_nodes + f
"{{\nswitch({bletter})" +
"{\n"
175 unique_fun = unique_fun +
"{\n"
176 dunique_fun = dunique_fun +
"{\n"
178 unique_fun = unique_fun + f
"\nswitch({bletter})" +
"{\n"
179 dunique_fun = dunique_fun + f
"\nswitch({bletter})" +
"{\n"
181 vertices = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1]]
184 print(
"\t-processing " + str(order))
187 def fe():
return None
192 fe.points = [[2./5., 2./5., 1./5.]]
196 current_indices = list(range(0, len(fe.points)))
200 for i
in range(0, 5):
202 for ii
in current_indices:
204 for dd
in range(0, dim):
205 norm = norm + (vv[dd] - fe.points[ii][dd]) ** 2
209 current_indices.remove(ii)
213 for i
in range(0, order - 1):
214 for ii
in current_indices:
215 if fe.points[ii][1] != 0
or (dim == 3
and fe.points[ii][2] != 0):
218 if abs(fe.points[ii][0] - (i + 1) / order) < 1e-10:
220 current_indices.remove(ii)
224 for i
in range(0, order - 1):
225 for ii
in current_indices:
226 if fe.points[ii][0] != 1
or (dim == 3
and fe.points[ii][2] != 0):
229 if abs(fe.points[ii][1] - (i + 1) / order) < 1e-10:
231 current_indices.remove(ii)
235 for i
in range(0, order - 1):
236 for ii
in current_indices:
237 if fe.points[ii][1] != 1
or (dim == 3
and fe.points[ii][2] != 0):
240 if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
242 current_indices.remove(ii)
246 for i
in range(0, order - 1):
247 for ii
in current_indices:
248 if fe.points[ii][0] != 0
or (dim == 3
and fe.points[ii][2] != 0):
251 if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
253 current_indices.remove(ii)
258 for i
in range(0, order - 1):
259 for ii
in current_indices:
260 if fe.points[ii][0] != 0
or fe.points[ii][1] != 0:
263 if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
265 current_indices.remove(ii)
269 for i
in range(0, order - 1):
270 for ii
in current_indices:
271 if fe.points[ii][0] + fe.points[ii][2] != 1
or fe.points[ii][1] != 0:
274 if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
276 current_indices.remove(ii)
280 for i
in range(0, order - 1):
281 for ii
in current_indices:
282 if fe.points[ii][0] + fe.points[ii][2] != 1
or fe.points[ii][1] + fe.points[ii][2] != 1:
286 if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
288 current_indices.remove(ii)
292 for i
in range(0, order - 1):
293 for ii
in current_indices:
294 if fe.points[ii][1] + fe.points[ii][2] != 1
or fe.points[ii][0] != 0:
297 if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
299 current_indices.remove(ii)
302 nn = max(0, order - 2)
304 npts = int(nn * (nn + 1) / 2)
308 for i
in range(0, npts):
309 for ii
in current_indices:
310 if abs(fe.points[ii][1]) > 1e-10:
312 tmp.append(ii); current_indices.remove(ii);
break
313 for i
in range(len(tmp)):
314 indices.append(tmp[(i + 1) % len(tmp)])
318 for i
in range(0, npts):
319 for ii
in current_indices:
320 if abs(fe.points[ii][0]) < 1e-10
or abs(fe.points[ii][1]) < 1e-10
or abs(fe.points[ii][2]) < 1e-10:
322 if abs((fe.points[ii][0] + fe.points[ii][2]) - 1) > 1e-10:
324 tmp.append(ii); current_indices.remove(ii);
break
325 for i
in range(len(tmp)):
326 indices.append(tmp[(i + 1) % len(tmp)])
330 for i
in range(0, npts):
331 for ii
in current_indices:
332 if abs(fe.points[ii][0]) < 1e-10
or abs(fe.points[ii][1]) < 1e-10
or abs(fe.points[ii][2]) < 1e-10:
334 if abs((fe.points[ii][1] + fe.points[ii][2]) - 1) > 1e-10:
336 tmp.append(ii); current_indices.remove(ii);
break
337 for i
in range(len(tmp)):
338 indices.append(tmp[(i + 1) % len(tmp)])
342 for i
in range(0, npts):
343 for ii
in current_indices:
344 if abs(fe.points[ii][0]) > 1e-10:
346 tmp.append(ii); current_indices.remove(ii);
break
347 for i
in range(len(tmp)):
348 indices.append(tmp[(i + 1) % len(tmp)])
351 for i
in range(0, npts_b):
352 for ii
in current_indices:
353 if abs(fe.points[ii][2]) > 1e-10:
355 indices.append(ii); current_indices.remove(ii);
break
358 for ii
in current_indices:
361 for i
in range(0, fe.nbf()):
362 print(i, indices[i], fe.points[indices[i]])
365 nodes = f
"void {bletter}_{order}_nodes_{suffix}(Eigen::MatrixXd &res) {{\n res.resize(" + str(
366 len(indices)) +
", " + str(dim) +
"); res << \n"
367 unique_nodes = unique_nodes + f
"\tcase {order}: " + \
368 f
"{bletter}_{order}_nodes_{suffix}(val); break;\n"
371 nodes = nodes + ccode(fe.points[ii][0]) +
", " + ccode(fe.points[ii][1]) + (
372 (
", " + ccode(fe.points[ii][2]))
if dim == 3
else "") +
",\n"
374 nodes = nodes +
";\n}"
377 func = f
"void {bletter}_{order}_basis_value_{suffix}" + \
378 "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &result_0)"
379 dfunc = f
"void {bletter}_{order}_basis_grad_value_{suffix}" + \
380 "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
382 unique_fun = unique_fun + \
383 f
"\tcase {order}: {bletter}_{order}_basis_value_{suffix}(local_index, uv, val); break;\n"
384 dunique_fun = dunique_fun + \
385 f
"\tcase {order}: {bletter}_{order}_basis_grad_value_{suffix}(local_index, uv, val); break;\n"
390 base =
"auto x=uv.col(0).array();\nauto y=uv.col(1).array();"
391 base = base +
"\nauto z=uv.col(2).array();"
396 base = base +
"result_0.resize(x.size(),1);\n"
398 base = base +
"switch(local_index){\n"
400 "val.resize(uv.rows(), uv.cols());\n Eigen::ArrayXd result_0(uv.rows());\n" + \
401 "switch(local_index){\n"
403 for i
in range(0, fe.nbf()):
404 real_index = indices[i]
408 simplify(fe.N[real_index])).replace(
" = 1;",
".setOnes();") +
"} break;\n"
409 dbase = dbase +
"\tcase " + str(i) +
": {" + \
410 "{" +
pretty_print.C99_print(simplify(diff(fe.N[real_index], x))).replace(
" = 0;",
".setZero();").replace(
" = 1;",
".setOnes();").replace(
" = -1;",
".setConstant(-1);") +
"val.col(0) = result_0; }" \
411 "{" +
pretty_print.C99_print(simplify(diff(fe.N[real_index], y))).replace(
" = 0;",
".setZero();").replace(
412 " = 1;",
".setOnes();").replace(
" = -1;",
".setConstant(-1);") +
"val.col(1) = result_0; }"
415 dbase = dbase +
"{" +
pretty_print.C99_print(simplify(diff(fe.N[real_index], z))).replace(
" = 0;",
".setZero();").replace(
416 " = 1;",
".setOnes();").replace(
" = -1;",
".setConstant(-1);") +
"val.col(2) = result_0; }"
418 dbase = dbase +
"} break;\n"
420 base = base +
"\tdefault: assert(false);\n}"
421 dbase = dbase +
"\tdefault: assert(false);\n}"
423 cpp = cpp + func +
"{\n\n"
424 cpp = cpp + base +
"}\n"
426 cpp = cpp + dfunc +
"{\n\n"
427 cpp = cpp + dbase +
"}\n\n\n"
429 cpp = cpp + nodes +
"\n\n\n"
431 unique_nodes = unique_nodes +
"\tdefault: assert(false);\n}}"
433 unique_fun = unique_fun +
"\tdefault: assert(false); \n}}"
434 dunique_fun = dunique_fun +
"\tdefault: assert(false); \n}}"
436 cpp = cpp +
"}\n\n" + unique_nodes +
"\n" + unique_fun + \
437 "\n\n" + dunique_fun +
"\n" +
"\nnamespace " +
"{\n"
441 f
"\nstatic const int MAX_{bletter.capitalize()}_BASES = {max(orders)};\n"
443 cpp = cpp +
"\n}}}\n"
446 path = os.path.abspath(args.output)
449 with open(os.path.join(path, f
"auto_{bletter}_bases.cpp"),
"w")
as file:
452 with open(os.path.join(path, f
"auto_{bletter}_bases.hpp"),
"w")
as file:
create_matrix(equations, coeffs)