PolyFEM
Loading...
Searching...
No Matches
p_bases.py
Go to the documentation of this file.
1# https://raw.githubusercontent.com/sympy/sympy/master/examples/advanced/fem.py
2from sympy import *
3import os
4import numpy as np
5import argparse
6from sympy.printing import ccode
7
8
9import pretty_print
10
# Module-level coordinate symbols, used by bernstein_space and
# Lagrange.compute_basis as the Cartesian coordinates of the reference simplex.
x, y, z = symbols('x y z')
12
13
15 def __init__(self, nsd):
16 self.nsd = nsd
17 if nsd <= 3:
18 coords = symbols('x,y,z')[:nsd]
19 else:
20 coords = [Symbol("x_%d" % d) for d in range(nsd)]
21 self.coords = coords
22
23 def integrate(self, f):
24 coords = self.coords
25 nsd = self.nsd
26
27 limit = 1
28 for p in coords:
29 limit -= p
30
31 intf = f
32 for d in range(0, nsd):
33 p = coords[d]
34 limit += p
35 intf = integrate(intf, (p, 0, limit))
36 return intf
37
38
def bernstein_space(order, nsd):
    """Build the generic Bernstein polynomial of degree ``order`` on a simplex.

    The barycentric coordinates are ``(x, 1 - x)`` in 1D, ``(x, y, 1 - x - y)``
    in 2D and ``(x, y, z, 1 - x - y - z)`` in 3D, using the module-level
    symbols. Every multi-index ``(o_1, ..., o_{nsd+1})`` whose entries sum to
    ``order`` is enumerated lexicographically, with the last index varying
    fastest. Each one contributes the Bernstein polynomial

        order! / (o_1! * ... * o_{nsd+1}!) * b_1**o_1 * ... * b_{nsd+1}**o_{nsd+1}

    paired with a fresh coefficient symbol ``a_<o_1>_..._<o_{nsd+1}>``.

    Args:
        order: polynomial degree.
        nsd: number of spatial dimensions (1, 2 or 3).

    Returns:
        ``(poly, coeffs, basis)``: ``basis`` lists the Bernstein polynomials,
        ``coeffs`` the matching coefficient symbols, and
        ``poly == sum(coeffs[k] * basis[k])``.

    Raises:
        RuntimeError: if ``nsd`` is not 1, 2 or 3.
    """
    from itertools import product

    if nsd not in (1, 2, 3):
        raise RuntimeError("Bernstein only implemented in 1D, 2D, and 3D")

    coords = (x, y, z)[:nsd]
    # Last barycentric coordinate, built the same way as the literal
    # `1 - x - y` so the resulting sympy expression is identical.
    remainder = 1
    for c in coords:
        remainder -= c
    bary = coords + (remainder,)

    poly = 0
    coeffs = []
    basis = []
    # product() yields tuples in the same order as nested `for` loops with the
    # first exponent outermost, so coefficient/basis ordering is preserved.
    for exps in product(range(order + 1), repeat=nsd + 1):
        if sum(exps) != order:
            continue
        coefficient = Symbol("a_" + "_".join(str(o) for o in exps))

        denom = 1
        for o in exps:
            denom *= factorial(o)
        term = factorial(order) / denom  # multinomial coefficient
        for b, o in zip(bary, exps):
            term = term * pow(b, o)

        poly += coefficient * term
        basis.append(term)
        coeffs.append(coefficient)

    return poly, coeffs, basis
79
80
def create_point_set(order, nsd):
    """Return the equispaced lattice points of degree ``order`` on a simplex.

    The result contains one tuple of sympy Rationals
    ``(i_1/order, ..., i_nsd/order)`` for every non-negative integer
    multi-index with ``i_1 + ... + i_nsd <= order``. That is every lattice
    point of the reference simplex ``{p >= 0, sum(p) <= 1}``. The tuples are
    listed lexicographically, with the last coordinate varying fastest.

    For ``order == 0`` (where ``1/order`` is undefined) the single node is the
    centroid ``(1/(nsd+1), ..., 1/(nsd+1))``. An empty list is returned when
    ``nsd < 1`` or ``order < 0``.
    """
    from itertools import product

    if nsd < 1 or order < 0:
        return []
    if order == 0:
        return [tuple(Rational(1, nsd + 1) for _ in range(nsd))]

    h = Rational(1, order)
    # sum(i_k * h) <= 1  <=>  sum(i_k) <= order; test on exact integers.
    return [
        tuple(i * h for i in idx)
        for idx in product(range(order + 1), repeat=nsd)
        if sum(idx) <= order
    ]
104
105
def create_matrix(equations, coeffs):
    """Build the coefficient matrix of a system of linear expressions.

    Args:
        equations: sympy expressions, assumed linear in the symbols ``coeffs``
            (in this file they are the Bernstein polynomial evaluated at the
            lattice points).
        coeffs: the unknown symbols.

    Returns:
        A sympy matrix of shape ``(len(equations), len(coeffs))`` where
        ``A[i, j]`` is the coefficient of ``coeffs[j]`` in ``equations[i]``.
        Sizing it explicitly avoids an IndexError, or silent zero columns,
        when the system is not square.
    """
    A = zeros(len(equations), len(coeffs))
    for j, c in enumerate(coeffs):
        for i, e in enumerate(equations):
            # Dividing an expression linear in `c` by `c` leaves the
            # coefficient of `c` as the quotient (the rest is the remainder).
            quotient, _ = reduced(e, [c])
            A[i, j] = quotient[0]
    return A
117
118
120 def __init__(self, nsd, order, bernstein):
121 self.nsd = nsd
122 self.bernstein = bernstein
123 self.order = order
124 self.points = []
125 self.compute_basis()
126
127 def nbf(self):
128 return len(self.N)
129
130 def compute_basis(self):
131 order = self.order
132 nsd = self.nsd
133 N = []
134 pol, coeffs, basis = bernstein_space(order, nsd)
135 self.points = create_point_set(order, nsd)
136
137 equations = []
138 for p in self.points:
139 ex = pol.subs(x, p[0])
140 if nsd > 1:
141 ex = ex.subs(y, p[1])
142 if nsd > 2:
143 ex = ex.subs(z, p[2])
144 equations.append(ex)
145
146 b = eye(len(equations))
147 if self.bernstein:
148 xx = b
149 else:
150 A = create_matrix(equations, coeffs)
151
152 # if A.shape[0] > 25:
153 # A = A.evalf()
154
155 Ainv = A.inv()
156 xx = Ainv * b
157
158 for i in range(0, len(equations)):
159 Ni = pol
160 for j in range(0, len(coeffs)):
161 Ni = Ni.subs(coeffs[j], xx[j, i])
162 N.append(Ni)
163
164 self.N = N
165
166
def parse_args():
    """Parse the generator's command line.

    Returns:
        argparse.Namespace with ``output`` (str: folder that receives the
        generated ``auto_{p|b}_bases.cpp/.hpp``) and ``bernstein`` (bool).
    """
    parser = argparse.ArgumentParser(
        # The module starts with a `#` comment, not a docstring, so __doc__
        # is None; fall back to an explicit description for --help.
        description=__doc__ or (
            "Generate C++ sources for P (Lagrange) or Bernstein bases on "
            "triangles and tetrahedra."),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("output", type=str, help="path to the output folder")
    parser.add_argument("--bernstein", default=False, action='store_true',
                        help="use Bernstein basis instead of Lagrange basis")
    return parser.parse_args()
175
176
if __name__ == "__main__":
    # Script entry point. Writes auto_<bletter>_bases.cpp and
    # auto_<bletter>_bases.hpp into the folder given on the command line. They
    # contain per-order basis value/gradient kernels (and, for the Lagrange "p"
    # family, node coordinates) on the reference triangle (2D) and
    # tetrahedron (3D), plus per-dimension dispatch functions that switch on
    # the order.
    args = parse_args()

    dims = [2, 3]

    # Orders that get dedicated kernels; any other order hits the `default:`
    # branch of the generated dispatchers (see the end of the dim loop).
    orders = [0, 1, 2, 3, 4]
    # orders = [4]

    # "b" -> Bernstein family, "p" -> Lagrange family. Used as the prefix of
    # every generated C++ symbol and of the output file names.
    bletter = "b" if args.bernstein else "p"

    cpp = f"#include \"auto_{bletter}_bases.hpp\""
    if not args.bernstein:
        # The Lagrange dispatchers forward to the Bernstein ones when their
        # `bernstein` flag is set.
        cpp = cpp + "\n#include \"auto_b_bases.hpp\""
    # Declares the p_n_* functions used as the Lagrange `default:` fallback.
    cpp = cpp + "\n#include \"p_n_bases.hpp\""
    # Open polyfem::autogen and an anonymous namespace for the per-order
    # kernels. The anonymous namespace is closed before each dimension's
    # public dispatchers and reopened right after them.
    cpp = cpp + "\n\n\n" \
        "namespace polyfem {\nnamespace autogen " + "{\nnamespace " + "{\n"

    hpp = "#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n"

    hpp = hpp + "\nnamespace polyfem {\nnamespace autogen " + "{\n"

    for dim in dims:
        # NOTE(review): `assert` is stripped under `python -O`.
        assert dim in (2, 3), "P simplex autogen supports only triangles and tetrahedra"
        print(str(dim) + "D " + bletter)
        suffix = "2d" if dim == 2 else "3d"

        # Public dispatcher signatures for this dimension. They are declared in
        # the header and defined in the .cpp. The C++ parameter that selects
        # the order is named after `bletter` (`p` or `b`).
        unique_nodes = f"void {bletter}_nodes_{suffix}" + \
            f"(const int {bletter}, Eigen::MatrixXd &val)"

        if args.bernstein:
            unique_fun = f"void {bletter}_basis_value_{suffix}" + \
                f"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
            dunique_fun = f"void {bletter}_grad_basis_value_{suffix}" + \
                f"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
        else:
            # The Lagrange variants take an extra `bernstein` flag.
            unique_fun = f"void {bletter}_basis_value_{suffix}" + \
                f"(const bool bernstein, const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
            dunique_fun = f"void {bletter}_grad_basis_value_{suffix}" + \
                f"(const bool bernstein, const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"

        # Node coordinates are only generated for the Lagrange family.
        if not args.bernstein:
            hpp = hpp + unique_nodes + ";\n\n"

        hpp = hpp + unique_fun + ";\n\n"
        hpp = hpp + dunique_fun + ";\n\n"

        # Start the function bodies: open `{` and `switch(<order>){`. The
        # `case`s are appended per order below.
        unique_nodes = unique_nodes + f"{{\nswitch({bletter})" + "{\n"

        unique_fun = unique_fun + "{\n"
        dunique_fun = dunique_fun + "{\n"

        if not args.bernstein:
            # Delegate to the Bernstein dispatchers when `bernstein` is true.
            unique_fun = unique_fun + \
                f"if(bernstein) {{ b_basis_value_{suffix}(p, local_index, uv, val); return; }}\n\n"
            dunique_fun = dunique_fun + \
                f"if(bernstein) {{ b_grad_basis_value_{suffix}(p, local_index, uv, val); return; }}\n\n"

        unique_fun = unique_fun + f"\nswitch({bletter})" + "{\n"
        dunique_fun = dunique_fun + f"\nswitch({bletter})" + "{\n"

        # Reference simplex vertices v0, v1, v2 (, v3).
        if dim == 2:
            vertices = [[0, 0], [1, 0], [0, 1]]
        elif dim == 3:
            vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]

        for order in orders:
            print("\t-processing " + str(order))

            if order == 0:
                # Constant element built by hand: a single basis function
                # equal to 1, with its node at the centroid. The function
                # object is only used as an attribute holder that mimics the
                # Lagrange interface (nbf(), N, points).
                def fe(): return None
                fe.nbf = lambda: 1

                fe.N = [1]

                if dim == 2:
                    fe.points = [[1./3., 1./3.]]
                else:
                    fe.points = [[1./3., 1./3., 1./3.]]
            else:
                fe = Lagrange(dim, order, args.bernstein)

            # Renumber the element nodes. `indices[k]` is the position in
            # fe.points of the node emitted as local index k. Order: vertices,
            # edge-interior nodes edge by edge, (3D) face-interior nodes face
            # by face, then everything left. Presumably this matches PolyFEM's
            # local node numbering; confirm against the mesh code. Each match
            # removes the node from `current_indices`; the immediate `break`
            # keeps that removal safe while iterating.
            current_indices = list(range(0, len(fe.points)))
            indices = []

            # vertex coordinate: the node closest to each vertex (squared
            # distance below 1e-10).
            for i in range(0, dim + 1):
                vv = vertices[i]
                for ii in current_indices:
                    norm = 0
                    for dd in range(0, dim):
                        norm = norm + (vv[dd] - fe.points[ii][dd]) ** 2

                    if norm < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # edge 1 coordinate: y == 0 (and z == 0), from v0 towards v1,
            # x = (i + 1) / order.
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][1] != 0 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][0] - (i + 1) / order) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # edge 2 coordinate: x + y == 1 (and z == 0), from v1 towards v2,
            # y = (i + 1) / order.
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] + fe.points[ii][1] != 1 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][1] - (i + 1) / order) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # edge 3 coordinate: x == 0 (and z == 0), from v2 towards v0,
            # y = 1 - (i + 1) / order.
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] != 0 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            if dim == 3:
                # edge 4 coordinate: x == y == 0, from v0 towards v3,
                # z = (i + 1) / order.
                for i in range(0, order - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 0 or fe.points[ii][1] != 0:
                            continue

                        if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 5 coordinate: x + z == 1 and y == 0, from v1 towards
                # v3, x = 1 - (i + 1) / order.
                for i in range(0, order - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] + fe.points[ii][2] != 1 or fe.points[ii][1] != 0:
                            continue

                        if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 6 coordinate: y + z == 1 and x == 0, from v2 towards
                # v3, y = 1 - (i + 1) / order.
                for i in range(0, order - 1):
                    for ii in current_indices:
                        if fe.points[ii][1] + fe.points[ii][2] != 1 or fe.points[ii][0] != 0:
                            continue

                        if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

            if dim == 3:
                # Interior nodes per triangular face: nn * (nn + 1) / 2 with
                # nn = order - 2. Vertices and edge nodes were already removed,
                # so what remains on a face plane is face-interior.
                nn = max(0, order - 2)
                npts = int(nn * (nn + 1) / 2)

                # bottom: z = 0
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][2]) > 1e-10:
                            continue

                        indices.append(ii)
                        current_indices.remove(ii)
                        break

                # front: y = 0
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][1]) > 1e-10:
                            continue

                        indices.append(ii)
                        current_indices.remove(ii)
                        break

                # diagonal: no coordinate equal to zero and coordinates sum to 1
                tmp = []
                for i in range(0, npts):
                    for ii in current_indices:
                        if (abs(fe.points[ii][0]) < 1e-10) | (abs(fe.points[ii][1]) < 1e-10) | (abs(fe.points[ii][2]) < 1e-10):
                            continue

                        if abs((fe.points[ii][0] + fe.points[ii][1] + fe.points[ii][2]) - 1) > 1e-10:
                            continue

                        tmp.append(ii)
                        current_indices.remove(ii)
                        break
                # Emit the diagonal-face nodes rotated by two positions
                # (presumably to match that face's orientation; confirm).
                for i in range(0, len(tmp)):
                    indices.append(tmp[(i + 2) % len(tmp)])

                # side: x = 0 (emitted in decreasing point-index order)
                tmp = []
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][0]) > 1e-10:
                            continue

                        tmp.append(ii)
                        current_indices.remove(ii)
                        break
                tmp.sort(reverse=True)
                indices.extend(tmp)

            # Remaining face (2D) or volume (3D) indices; their order does not matter.
            for ii in current_indices:
                indices.append(ii)

            # nodes code gen: a C++ function that fills `res` with one row of
            # node coordinates per local index.
            nodes = f"void {bletter}_{order}_nodes_{suffix}(Eigen::MatrixXd &res) {{\n res.resize(" + str(
                len(indices)) + ", " + str(dim) + "); res << \n"
            unique_nodes = unique_nodes + f"\tcase {order}: " + \
                f"{bletter}_{order}_nodes_{suffix}(val); break;\n"

            for ii in indices:
                nodes = nodes + ccode(fe.points[ii][0]) + ", " + ccode(fe.points[ii][1]) + (
                    (", " + ccode(fe.points[ii][2])) if dim == 3 else "") + ",\n"
            # Drop the trailing ",\n" of the last row.
            nodes = nodes[:-2]
            nodes = nodes + ";\n}"

            # bases code gen
            # Generate two functions:
            # - "func" to eval basis value
            # - "dfunc" to eval basis gradient.
            # Both functions evaluate quadrature points in batch by dispatching
            # the "xxx_single" basis kernel inside a for loop. In this script,
            # codegen variables related to the single kernel use the
            # scalar_ prefix.
            func = f"void {bletter}_{order}_basis_value_{suffix}" + \
                "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &result_0)"
            dfunc = f"void {bletter}_{order}_basis_grad_value_{suffix}" + \
                "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
            scalar_func_name = f"{bletter}_{order}_basis_value_{suffix}_single"
            scalar_dfunc_name = f"{bletter}_{order}_basis_grad_value_{suffix}_single"

            unique_fun = unique_fun + \
                f"\tcase {order}: {bletter}_{order}_basis_value_{suffix}(local_index, uv, val); break;\n"
            dunique_fun = dunique_fun + \
                f"\tcase {order}: {bletter}_{order}_basis_grad_value_{suffix}(local_index, uv, val); break;\n"

            # hpp = hpp + func + ";\n"
            # hpp = hpp + dfunc + ";\n"

            if not args.bernstein:
                # Fallback statements for orders not in `orders`; used by the
                # `default:` branches after the order loop.
                default_base = "p_n_basis_value_3d(p, local_index, uv, val);" if dim == 3 else "p_n_basis_value_2d(p, local_index, uv, val);"
                default_dbase = "p_n_basis_grad_value_3d(p, local_index, uv, val);" if dim == 3 else "p_n_basis_grad_value_2d(p, local_index, uv, val);"
                default_nodes = "p_n_nodes_3d(p, val);" if dim == 3 else "p_n_nodes_2d(p, val);"

            # Single basis kernel.
            base = ""
            dbase = ""
            # Batch basis kernel. Basically a switch dispatch + for loop.
            base_cases = "switch(local_index){\n"
            dbase_cases = "switch(local_index){\n"

            # Local index i uses the basis function of the renumbered node
            # indices[i].
            for i in range(0, fe.nbf()):
                real_index = indices[i]
                value_name = f"{scalar_func_name}_{i}"
                grad_name = f"{scalar_dfunc_name}_{i}"
                basis = simplify(fe.N[real_index])

                base = base + pretty_print.C99_print_scalar_value_function(value_name, basis, dim)
                dbase = dbase + pretty_print.C99_print_scalar_gradient_function(grad_name, basis, dim)
                base_cases = base_cases + pretty_print.C99_print_scalar_value_case(i, value_name, dim)
                dbase_cases = dbase_cases + pretty_print.C99_print_scalar_gradient_case(i, grad_name, dim)

            base_cases = base_cases + "\tdefault: assert(false);\n}"
            dbase_cases = dbase_cases + "\tdefault: assert(false);\n}"

            cpp = cpp + base + "\n\n"
            cpp = cpp + func + "{\n"
            cpp = cpp + "result_0.resize(uv.rows(), 1);\n"
            cpp = cpp + base_cases + "\n}\n"

            cpp = cpp + dbase + "\n\n"
            cpp = cpp + dfunc + "{\n"
            cpp = cpp + f"val.resize(uv.rows(), {dim});\n"
            cpp = cpp + f"double gradient[{dim}];\n"
            cpp = cpp + dbase_cases + "\n}\n\n\n"

            if not args.bernstein:
                cpp = cpp + nodes + "\n\n\n"

        # Close the dispatchers: `default:` case, then "}}" closes the switch
        # and the function body.
        if args.bernstein:
            unique_nodes = ""
        else:
            unique_nodes = unique_nodes + "\tdefault: "+default_nodes+"\n}}"

        if args.bernstein:
            unique_fun = unique_fun + "\tdefault: assert(false); \n}}"
            dunique_fun = dunique_fun + "\tdefault: assert(false); \n}}"
        else:
            unique_fun = unique_fun + "\tdefault: "+default_base+"\n}}"
            dunique_fun = dunique_fun + "\tdefault: "+default_dbase+"\n}}"

        # Close the anonymous namespace, emit the public dispatchers, then
        # reopen an anonymous namespace for the next dimension's kernels.
        cpp = cpp + "}\n\n" + unique_nodes + "\n" + unique_fun + \
            "\n\n" + dunique_fun + "\n" + "\nnamespace " + "{\n"
        hpp = hpp + "\n"

    hpp = hpp + \
        f"\nstatic const int MAX_{bletter.capitalize()}_BASES = {max(orders)};\n"

    # Close anonymous, autogen and polyfem namespaces (.cpp); autogen and
    # polyfem (.hpp).
    cpp = cpp + "\n}}}\n"
    hpp = hpp + "\n}}\n"

    path = os.path.abspath(args.output)

    print("saving...")
    with open(os.path.join(path, f"auto_{bletter}_bases.cpp"), "w") as file:
        file.write(cpp)

    with open(os.path.join(path, f"auto_{bletter}_bases.hpp"), "w") as file:
        file.write(hpp)

    print("done!")
__init__(self, nsd, order, bernstein)
Definition p_bases.py:120
compute_basis(self)
Definition p_bases.py:130
create_matrix(equations, coeffs)
Definition p_bases.py:106
create_point_set(order, nsd)
Definition p_bases.py:81
bernstein_space(order, nsd)
Definition p_bases.py:39
parse_args()
Definition p_bases.py:167
C99_print_scalar_gradient_function(function_name, expr, dim)
C99_print_scalar_value_function(function_name, expr, dim)
C99_print_scalar_value_case(local_index, function_name, dim)
C99_print_scalar_gradient_case(local_index, function_name, dim)