PolyFEM
Loading...
Searching...
No Matches
q_bases.py
Go to the documentation of this file.
1# https://raw.githubusercontent.com/sympy/sympy/master/examples/advanced/fem.py
2from sympy import *
3import os
4import argparse
5from sympy.printing import ccode
6
7
8import pretty_print
9
# Module-level coordinate symbols; the basis construction and the code
# generation below build and differentiate expressions in terms of these.
x, y, z = symbols('x,y,z')
11
12
14 def __init__(self, nsd):
15 self.nsd = nsd
16 if nsd <= 3:
17 coords = symbols('x,y,z')[:nsd]
18 else:
19 coords = [Symbol("x_%d" % d) for d in range(nsd)]
20 self.coords = coords
21
22 def integrate(self, f):
23 coords = self.coords
24 nsd = self.nsd
25
26 limit = 1
27 for p in coords:
28 limit -= p
29
30 intf = f
31 for d in range(0, nsd):
32 p = coords[d]
33 limit += p
34 intf = integrate(intf, (p, 0, limit))
35 return intf
36
37
def create_point_set(order, nsd):
    """Return the equally spaced grid points for a given order.

    Points are tuples of exact ``Rational`` multiples of ``1/order`` with
    ``order + 1`` values per axis. The first coordinate varies slowest and the
    last fastest. Only ``nsd`` equal to 2 or 3 produces points; any other
    value yields an empty list.
    """
    step = Rational(1, order)

    if nsd not in (2, 3):
        return []

    ticks = [k * step for k in range(0, order + 1)]
    if nsd == 2:
        return [(px, py) for px in ticks for py in ticks]
    return [(px, py, pz) for px in ticks for py in ticks for pz in ticks]
59
60
62 def __init__(self, nsd, order):
63 self.nsd = nsd
64 self.order = order
65 self.points = []
66 self.compute_basis()
67
68 def nbf(self):
69 return len(self.N)
70
71 def compute_basis(self):
72 order = self.order
73 nsd = self.nsd
74 N = []
75 self.points = create_point_set(order, nsd)
76
77 if nsd == 2:
78 Ntmpx = []
79 Ntmpy = []
80
81 for j in range(order + 1):
82 vx = 1
83 vy = 1
84 xj = 1./(order)*j
85 for m in range(order+1):
86 if m == j:
87 continue
88 xm = 1./(order)*m
89 vx *= (x - xm)/(xj - xm)
90 vy *= (y - xm)/(xj - xm)
91
92 Ntmpx.append(vx)
93 Ntmpy.append(vy)
94
95 for i in range(order + 1):
96 for j in range(order + 1):
97 N.append(Ntmpx[i]*Ntmpy[j])
98 elif nsd == 3:
99 Ntmpx = []
100 Ntmpy = []
101 Ntmpz = []
102
103 for j in range(order + 1):
104 vx = 1
105 vy = 1
106 vz = 1
107 xj = 1./(order)*j
108 for m in range(order+1):
109 if m == j:
110 continue
111 xm = 1./(order)*m
112 vx *= (x - xm)/(xj - xm)
113 vy *= (y - xm)/(xj - xm)
114 vz *= (z - xm)/(xj - xm)
115
116 Ntmpx.append(vx)
117 Ntmpy.append(vy)
118 Ntmpz.append(vz)
119
120 for i in range(order + 1):
121 for j in range(order + 1):
122 for l in range(order + 1):
123 N.append(Ntmpx[i]*Ntmpy[j]*Ntmpz[l])
124
125 self.N = N
126
127
def parse_args():
    """Parse the command-line arguments of this generator script.

    Returns:
        argparse.Namespace: has one attribute, ``output`` (str), the path to
        the output folder given as the single positional argument.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks in the --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("output", type=str, help="path to the output folder")
    return parser.parse_args()
134
135
if __name__ == "__main__":
    # Script entry point: for every dimension and every requested order, build
    # the symbolic basis, reorder its nodes (corners, then edges, then faces,
    # then the rest), and emit C++ sources/headers into the output folder.
    args = parse_args()
    path = os.path.abspath(args.output)

    dims = [2, 3]
    # Order -2 selects the hand-written 8-node (2D) / 20-node (3D) basis built
    # below rather than the tensor-product Lagrange basis.
    # NOTE(review): this looks like the serendipity element - confirm.
    orders = [0, 1, 2, 3, -2]

    for dim in dims:
        # Base names of the generated files: values, nodes and gradients.
        namev = f"auto_q_bases_{dim}d_val"
        namen = f"auto_q_bases_{dim}d_nodes"
        nameg = f"auto_q_bases_{dim}d_grad"

        # Each .cpp starts with its own header include and opens
        # polyfem::autogen plus a nested unnamed namespace.
        cppv = f"#include \"{namev}.hpp\"\n\n\n"
        cppv = cppv + \
            "namespace polyfem {\nnamespace autogen " + "{\nnamespace " + "{\n"

        cppn = f"#include \"{namen}.hpp\"\n\n\n"
        cppn = cppn + \
            "namespace polyfem {\nnamespace autogen " + "{\nnamespace " + "{\n"

        cppg = f"#include \"{nameg}.hpp\"\n\n\n"
        cppg = cppg + \
            "namespace polyfem {\nnamespace autogen " + "{\nnamespace " + "{\n"
        if dim == 3:
            # In 3D the gradient code is written to one file per order (see
            # the write at the end of the order loop), so it gets a
            # self-contained preamble without the unnamed namespace.
            cppg = "#include <Eigen/Dense>\n#include <cassert>\n namespace polyfem {\nnamespace autogen {"

        # Declarations of the per-order 3D gradient functions; only used when
        # assembling the 3D gradient dispatcher file after the order loop.
        eextern = ""

        hppv = "#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n\n"
        hppv = hppv + "namespace polyfem {\nnamespace autogen " + "{\n"

        hppn = "#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n\n"
        hppn = hppn + "namespace polyfem {\nnamespace autogen " + "{\n"

        hppg = "#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n\n"
        hppg = hppg + "namespace polyfem {\nnamespace autogen " + "{\n"

        print(str(dim) + "D")
        suffix = "_2d" if dim == 2 else "_3d"

        # Signatures of the public dispatchers that switch on the order q.
        unique_nodes = "void q_nodes" + suffix + \
            "(const int q, Eigen::MatrixXd &val)"

        unique_fun = "void q_basis_value" + suffix + \
            "(const int q, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
        dunique_fun = "void q_grad_basis_value" + suffix + \
            "(const int q, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"

        # The headers only declare the dispatchers.
        hppn = hppn + unique_nodes + ";\n\n"

        hppv = hppv + unique_fun + ";\n\n"
        hppg = hppg + dunique_fun + ";\n\n"

        # From here on the same variables accumulate the dispatcher bodies.
        unique_nodes = unique_nodes + "{\nswitch(q)" + "{\n"

        unique_fun = unique_fun + "{\nswitch(q)" + "{\n"
        dunique_fun = dunique_fun + "{\nswitch(q)" + "{\n"

        # Corner coordinates of the unit square / cube, in the order in which
        # the corner nodes are placed first in the output numbering.
        if dim == 2:
            vertices = [[0, 0], [1, 0], [1, 1], [0, 1]]
        elif dim == 3:
            vertices = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
                        [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]]

        for order in orders:
            print("\t-processing " + str(order))

            if order == 0:
                # Order 0: one constant basis function with its node at the
                # cell center. A bare function object serves as an attribute
                # container exposing the same nbf / N / points members that
                # the Lagrange object provides.
                def fe(): return None
                fe.nbf = lambda: 1

                fe.N = [1]

                if dim == 2:
                    fe.points = [[1./2., 1./2.]]
                else:
                    fe.points = [[1./2., 1./2., 1./2.]]
            elif order == -2:
                # Hand-written basis: corner functions first, then one
                # function per edge midpoint (8 in 2D, 20 in 3D).
                def fe(): return None
                if dim == 2:
                    fe.points = []
                    fe.N = []
                    fe.nbf = lambda: 8

                    # Corner nodes; xi_a / eta_a = -1, 1 map to coordinate 0, 1.
                    for xi_a in [-1, 1]:
                        for eta_a in [-1, 1]:
                            tmp = (1/4*(2*x*xi_a-xi_a+1))*(2*eta_a*y -
                                                          eta_a+1)*(2*eta_a*y+2*x*xi_a-eta_a-xi_a-1)
                            fe.N.append(tmp)
                            fe.points.append([(xi_a+1)/2, (eta_a+1)/2])

                    # Midpoints of the edges y = 0 and y = 1.
                    for eta_a in [-1, 1]:
                        tmp = -2*x*(x-1)*(2*eta_a*y-eta_a+1)
                        fe.N.append(tmp)
                        fe.points.append([1/2, (eta_a+1)/2])

                    # Midpoints of the edges x = 0 and x = 1.
                    for xi_a in [-1, 1]:
                        tmp = -2*y*(y-1)*(2*x*xi_a-xi_a+1)
                        fe.N.append(tmp)
                        fe.points.append([(xi_a+1)/2, 1/2])

                    assert (len(fe.points) == 8)
                    assert (len(fe.N) == 8)

                elif dim == 3:
                    fe.points = []
                    fe.N = []
                    fe.nbf = lambda: 20

                    # Corner nodes of the cube.
                    for xi_a in [-1, 1]:
                        for eta_a in [-1, 1]:
                            for zeta_a in [-1, 1]:
                                tmp = (1/8*(2*x*xi_a-xi_a+1))*(2*eta_a*y-eta_a+1)*(2*z*zeta_a-zeta_a+1)*(
                                    2*eta_a*y+2*x*xi_a+2*z*zeta_a-eta_a-xi_a-zeta_a-2)
                                fe.N.append(tmp)
                                fe.points.append(
                                    [(xi_a+1)/2, (eta_a+1)/2, (zeta_a+1)/2])

                    # Midpoints of the four edges parallel to the x axis.
                    for eta_a in [-1, 1]:
                        for zeta_a in [-1, 1]:
                            tmp = -x*(x-1)*(2*eta_a*y-eta_a+1) * \
                                (2*z*zeta_a-zeta_a+1)
                            fe.N.append(tmp)
                            fe.points.append([1/2, (eta_a+1)/2, (zeta_a+1)/2])

                    # Midpoints of the four edges parallel to the y axis.
                    for xi_a in [-1, 1]:
                        for zeta_a in [-1, 1]:
                            tmp = -y*(y-1)*(2*x*xi_a-xi_a+1) * \
                                (2*z*zeta_a-zeta_a+1)
                            fe.N.append(tmp)
                            fe.points.append([(xi_a+1)/2, 1/2, (zeta_a+1)/2])

                    # Midpoints of the four edges parallel to the z axis.
                    for xi_a in [-1, 1]:
                        for eta_a in [-1, 1]:
                            tmp = -z*(z-1)*(2*x*xi_a-xi_a+1) * \
                                (2*eta_a*y-eta_a+1)
                            fe.N.append(tmp)
                            fe.points.append([(xi_a+1)/2, (eta_a+1)/2, 1/2])

                    assert (len(fe.points) == 20)
                    assert (len(fe.N) == 20)
                else:
                    assert (False)

            else:
                fe = Lagrange(dim, order)

            # Node reordering. `indices` collects positions into fe.points in
            # output order; `current_indices` holds the positions not yet
            # placed. Every search removes its match and then breaks, so the
            # list is never iterated further after being mutated.
            current_indices = list(range(0, len(fe.points)))
            indices = []

            # vertex coordinate
            for i in range(0, 4*(dim-1)):
                vv = vertices[i]
                for ii in current_indices:
                    # Squared distance between the corner and the candidate.
                    norm = 0
                    for dd in range(0, dim):
                        norm = norm + (vv[dd] - fe.points[ii][dd]) ** 2

                    if norm < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # Edge nodes: each edge contributes abs(order) - 1 interior nodes,
            # visited in the direction given by the target coordinate
            # ((i + 1) / |order| ascending, 1 - (i + 1) / |order| descending).

            # edge 0 coordinate
            for i in range(0, abs(order) - 1):
                for ii in current_indices:
                    if fe.points[ii][1] != 0 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][0] - (i + 1) / abs(order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # edge 1 coordinate
            for i in range(0, abs(order) - 1):
                for ii in current_indices:
                    if fe.points[ii][0] != 1 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][1] - (i + 1) / abs(order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # edge 2 coordinate
            for i in range(0, abs(order) - 1):
                for ii in current_indices:
                    if fe.points[ii][1] != 1 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][0] - (1 - (i + 1) / abs(order))) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # edge 3 coordinate
            for i in range(0, abs(order) - 1):
                for ii in current_indices:
                    if fe.points[ii][0] != 0 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][1] - (1 - (i + 1) / abs(order))) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            if dim == 3:
                # Edges 4-7 run along z; edges 8-11 lie in the plane z = 1.

                # edge 4 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 0 or fe.points[ii][1] != 0:
                            continue

                        if abs(fe.points[ii][2] - (i + 1) / abs(order)) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 5 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 1 or fe.points[ii][1] != 0:
                            continue

                        if abs(fe.points[ii][2] - (1 - (i + 1) / abs(order))) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 6 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 1 or fe.points[ii][1] != 1:
                            continue

                        if abs(fe.points[ii][2] - (1 - (i + 1) / abs(order))) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 7 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 0 or fe.points[ii][1] != 1:
                            continue

                        if abs(fe.points[ii][2] - (1 - (i + 1) / abs(order))) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 8 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][1] != 0 or fe.points[ii][2] != 1:
                            continue

                        if abs(fe.points[ii][0] - (i + 1) / abs(order)) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 9 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 1 or fe.points[ii][2] != 1:
                            continue

                        if abs(fe.points[ii][1] - (i + 1) / abs(order)) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 10 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][1] != 1 or fe.points[ii][2] != 1:
                            continue

                        if abs(fe.points[ii][0] - (1 - (i + 1) / abs(order))) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

                # edge 11 coordinate
                for i in range(0, abs(order) - 1):
                    for ii in current_indices:
                        if fe.points[ii][0] != 0 or fe.points[ii][2] != 1:
                            continue

                        if abs(fe.points[ii][1] - (1 - (i + 1) / abs(order))) < 1e-10:
                            indices.append(ii)
                            current_indices.remove(ii)
                            break

            if dim == 3:
                # Face nodes: up to (|order| - 1)^2 remaining nodes are taken
                # per face. Corners and edge nodes were removed above, so only
                # nodes that are still unplaced can match.
                nn = max(0, abs(order) - 1)
                npts = int(nn * nn)

                # side: x = 0
                tmp = []
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][0]) > 1e-10:
                            continue

                        tmp.append(ii)
                        current_indices.remove(ii)
                        break
                # Unlike the other faces, this face's nodes are emitted in
                # descending index order.
                # NOTE(review): presumably to match the node ordering expected
                # by the code consuming these tables - confirm.
                tmp.sort(reverse=True)
                indices.extend(tmp)

                # side: x = 1
                tmp = []
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][0] - 1) > 1e-10:
                            continue

                        tmp.append(ii)
                        current_indices.remove(ii)
                        break
                indices.extend(tmp)

                # front: y = 0
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][1]) > 1e-10:
                            continue

                        indices.append(ii)
                        current_indices.remove(ii)
                        break

                # back: y = 1
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][1]-1) > 1e-10:
                            continue

                        indices.append(ii)
                        current_indices.remove(ii)
                        break

                # bottom: z = 0
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][2]) > 1e-10:
                            continue

                        indices.append(ii)
                        current_indices.remove(ii)
                        break

                # top: z = 1
                for i in range(0, npts):
                    for ii in current_indices:
                        if abs(fe.points[ii][2]-1) > 1e-10:
                            continue

                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # either face or volume indices, order do not matter
            for ii in current_indices:
                indices.append(ii)

            # Name fragment for generated functions: negative orders become
            # "m<abs>" (e.g. -2 -> "m2"). The switch case labels below still
            # use the raw signed order.
            orderN = str(order) if order >= 0 else "m"+str(-order)
            # nodes code gen
            nodes = "void q_" + orderN + "_nodes" + suffix + \
                "(Eigen::MatrixXd &res) {\n res.resize(" + \
                str(len(indices)) + ", " + str(dim) + "); res << \n"
            unique_nodes = unique_nodes + "\tcase " + \
                str(order) + ": " + "q_" + orderN + \
                "_nodes" + suffix + "(val); break;\n"

            # Appended for every dim, but the accumulated text is only
            # consumed for dim == 3 (the name is hard-coded to _3d).
            eextern = eextern + \
                f"extern \"C++\" void q_{orderN}_basis_grad_value_3d(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val);\n"

            # One row of node coordinates per reordered node.
            for ii in indices:
                nodes = nodes + ccode(fe.points[ii][0]) + ", " + ccode(fe.points[ii][1]) + (
                    (", " + ccode(fe.points[ii][2])) if dim == 3 else "") + ",\n"
            # Drop the trailing ",\n" left by the last row.
            nodes = nodes[:-2]
            nodes = nodes + ";\n}"

            # bases code gen
            func = "void q_" + orderN + "_basis_value" + suffix + \
                "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &result_0)"
            dfunc = "void q_" + orderN + "_basis_grad_value" + suffix + \
                "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"

            unique_fun = unique_fun + "\tcase " + \
                str(order) + ": " + "q_" + orderN + "_basis_value" + \
                suffix + "(local_index, uv, val); break;\n"
            dunique_fun = dunique_fun + "\tcase " + \
                str(order) + ": " + "q_" + orderN + "_basis_grad_value" + \
                suffix + "(local_index, uv, val); break;\n"

            # hpp = hpp + func + ";\n"
            # hpp = hpp + dfunc + ";\n"

            # Common prologue of the generated bodies: bind x, y (and z) to
            # the columns of uv.
            base = "auto x=uv.col(0).array();\nauto y=uv.col(1).array();"
            if dim == 3:
                base = base + "\nauto z=uv.col(2).array();"
            base = base + "\n\n"
            dbase = base

            if order == 0:
                # The constant basis is emitted as result_0.setOnes(), so the
                # generated code sizes result_0 explicitly first.
                base = base + "result_0.resize(x.size(),1);\n"

            base = base + "switch(local_index){\n"
            dbase = dbase + \
                "val.resize(uv.rows(), uv.cols());\n Eigen::ArrayXd result_0(uv.rows());\n" + \
                "switch(local_index){\n"

            for i in range(0, fe.nbf()):
                # Output case i corresponds to the basis function stored at
                # the reordered position indices[i].
                real_index = indices[i]
                # real_index = i

                # The replace() calls rewrite constant assignments
                # (" = 0;", " = 1;", " = -1;") in the printed code into
                # setZero / setOnes / setConstant calls.
                # NOTE(review): presumably pretty_print.C99_print emits C
                # statements assigning to result_0 - confirm in pretty_print.
                if dim == 3:
                    base = base + "\tcase " + str(i) + ": {" + pretty_print.C99_print(
                        simplify(fe.N[real_index])).replace(" = 1;", ".setOnes();") + "} break;\n"
                    dbase = dbase + "\tcase " + str(i) + ": {" + \
                        "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], x))).replace(" = 0;", ".setZero();").replace(" = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(0) = result_0; }" \
                        "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], y))).replace(" = 0;", ".setZero();").replace(
                            " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(1) = result_0; }"
                    dbase = dbase + "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], z))).replace(" = 0;", ".setZero();").replace(
                        " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(2) = result_0; }"
                else:
                    base = base + "\tcase " + str(i) + ": {" + pretty_print.C99_print(
                        simplify(fe.N[real_index])).replace(" = 1;", ".setOnes();") + "} break;\n"
                    dbase = dbase + "\tcase " + str(i) + ": {" + \
                        "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], x))).replace(" = 0;", ".setZero();").replace(" = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(0) = result_0; }" \
                        "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], y))).replace(" = 0;", ".setZero();").replace(
                            " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(1) = result_0; }"

                # Close the gradient case opened above.
                dbase = dbase + "} break;\n"

            base = base + "\tdefault: assert(false);\n}"
            dbase = dbase + "\tdefault: assert(false);\n}"

            cppv = cppv + func + "{\n\n"
            cppv = cppv + base + "}\n"

            cppg = cppg + dfunc + "{\n\n"
            cppg = cppg + dbase + "}\n\n"
            cppn = cppn + nodes + "\n\n"

            if dim == 3:
                # Write this order's 3D gradient function to its own file
                # (closing both namespaces), then reset the buffer to a fresh
                # preamble for the next order.
                with open(os.path.join(path, f"{nameg}_{order}.cpp"), "w") as file:
                    file.write(cppg+"}}")
                cppg = "#include <Eigen/Dense>\n#include <cassert>\n namespace polyfem {\nnamespace autogen {"

        if dim == 3:
            # All 3D gradient bodies were already written per order; the main
            # 3D gradient file only gets the dispatcher assembled below.
            cppg = ""
        unique_nodes = unique_nodes + "\tdefault: assert(false);\n}}"

        unique_fun = unique_fun + "\tdefault: assert(false);\n}}"
        dunique_fun = dunique_fun + "\tdefault: assert(false);\n}}"

        # Close the unnamed namespace that holds the per-order functions.
        cppv = cppv + "}\n\n"
        cppn = cppn + "}\n\n"
        if dim != 3:
            cppg = cppg + "}\n\n"

        # Append the dispatchers and close polyfem::autogen.
        cppn = cppn + unique_nodes + "\n}}\n"
        cppv = cppv + unique_fun + "\n}}\n"
        cppg = cppg + dunique_fun + "\n}}\n"
        hppv = hppv + "\n}}\n"
        hppn = hppn + "\n}}\n"
        hppg = hppg + "\n}}\n"

        if dim == 3:
            # The 3D gradient dispatcher file declares the per-order functions
            # (defined in the per-order files) before the dispatcher itself.
            tcppg = f"#include \"{nameg}.hpp\"\n\n\n"
            tcppg = tcppg + "namespace polyfem {\nnamespace autogen {\n"
            tcppg = tcppg + eextern + "\n"
            cppg = tcppg+cppg

        print("saving...")
        with open(os.path.join(path, f"{namev}.cpp"), "w") as file:
            file.write(cppv)
        with open(os.path.join(path, f"{namen}.cpp"), "w") as file:
            file.write(cppn)
        with open(os.path.join(path, f"{nameg}.cpp"), "w") as file:
            file.write(cppg)

        with open(os.path.join(path, f"{namev}.hpp"), "w") as file:
            file.write(hppv)
        with open(os.path.join(path, f"{namen}.hpp"), "w") as file:
            file.write(hppn)
        with open(os.path.join(path, f"{nameg}.hpp"), "w") as file:
            file.write(hppg)

    # Umbrella header: includes every per-dimension header and records the
    # largest generated order as MAX_Q_BASES.
    hpp = "#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n\n"
    for dim in dims:
        hpp = hpp + f"#include \"auto_q_bases_{dim}d_val.hpp\"\n"
        hpp = hpp + f"#include \"auto_q_bases_{dim}d_nodes.hpp\"\n"
        hpp = hpp + f"#include \"auto_q_bases_{dim}d_grad.hpp\"\n"
    hpp = hpp + "\n\nnamespace polyfem {\nnamespace autogen " + "{\n"
    hpp = hpp + "\nstatic const int MAX_Q_BASES = " + str(max(orders)) + ";\n"
    hpp = hpp + "\n}}\n"

    print("saving...")
    with open(os.path.join(path, "auto_q_bases.hpp"), "w") as file:
        file.write(hpp)

    print("done!")
__init__(self, nsd, order)
Definition q_bases.py:62
compute_basis(self)
Definition q_bases.py:71
C99_print(expr)
create_point_set(order, nsd)
Definition q_bases.py:38
parse_args()
Definition q_bases.py:128