PolyFEM
Loading...
Searching...
No Matches
pyramid_bases.py
Go to the documentation of this file.
1# https://raw.githubusercontent.com/sympy/sympy/master/examples/advanced/fem.py
2from sympy import *
3import os
4import numpy as np
5import argparse
6from sympy.printing import ccode
7
8import pretty_print
9
# Reference coordinates of the pyramid; the basis polynomials built below
# (pyramid_space, Pyramid.N) are SymPy expressions in these three symbols.
x, y, z = symbols('x,y,z')
11
def shuffle(a, b):
    """Return a new list holding the items of ``a`` selected by the indices in ``b``.

    Entry ``k`` of the result is ``a[b[k]]``. ``b`` may repeat or omit indices;
    any lookup error raised by ``a`` propagates unchanged.
    """
    picked = []
    for index in b:
        picked.append(a[index])
    return picked
14
def pyramid_space(order):
    """Return the generic element of the order-``order`` pyramid space.

    The space is spanned by the rational monomials
    ``(x/(1-z))**i * (y/(1-z))**j * (1-z)**k`` with ``0 <= k <= order`` and
    ``0 <= i, j <= k``, i.e. sum_{k=0}^{order} (k+1)**2 functions.

    Returns:
        ``(poly, coeff, basis)``: ``basis`` lists the monomials ordered by k
        (outermost), then i, then j (innermost); ``coeff`` holds one fresh
        symbol ``a_i_j_k`` per monomial in the same order; ``poly`` is the sum
        of ``coeff[n] * basis[n]`` over all n.
    """
    u = x / (1 - z)
    v = y / (1 - z)
    w = 1 - z

    exponents = [(i, j, k)
                 for k in range(order + 1)
                 for i in range(k + 1)
                 for j in range(k + 1)]

    poly = 0
    coeff = []
    basis = []
    for i, j, k in exponents:
        a = Symbol("a_%d_%d_%d" % (i, j, k))
        poly += a * u**i * v**j * w**(k)
        basis.append(u**i * v**j * w**(k))
        coeff.append(a)
    return poly, coeff, basis
29
def create_point_set(order):
    """Return the interpolation nodes of the order-``order`` pyramid.

    The reference pyramid has the unit square [0, 1]^2 as its base (z = 0) and
    its apex at (0, 0, 1); its cross-section at height z is [0, 1 - z]^2.
    Nodes are emitted in this order: base grid, apex, interior nodes of the
    four side edges, strictly interior nodes of the four triangular side
    faces, and finally the cell-interior nodes. Every coordinate is an exact
    SymPy ``Rational`` (or a plain ``int``), so the interpolation system built
    from these nodes can be solved exactly.

    Args:
        order: polynomial order; must be >= 1.

    Returns:
        List of ``(x, y, z)`` tuples whose length equals the dimension of
        ``pyramid_space(order)``, i.e. (order+1)(order+2)(2*order+3)/6.

    Raises:
        ValueError: if ``order < 1``.
        AssertionError: if the node count does not match the space dimension.
    """
    if order < 1:
        raise ValueError(f"order must be >= 1, got {order}")

    h = Rational(1, order)
    points = []

    # Base: (order + 1)^2 grid on the unit square at z = 0.
    for i in range(order + 1):
        for j in range(order + 1):
            points.append((i * h, j * h, 0))

    # Apex.
    points.append((0, 0, 1))

    # Side edges: interior nodes on the segment from each base vertex to the apex.
    for i in range(1, order):
        zi = i * h
        for bx, by in [(0, 0), (1, 0), (1, 1), (0, 1)]:
            points.append((bx * (1 - zi), by * (1 - zi), zi))

    # Side faces: strictly interior lattice nodes of each triangle, obtained by
    # barycentric interpolation of its three corners.
    side_faces = [
        [(0, 0, 1), (0, 0, 0), (1, 0, 0)],
        [(0, 0, 1), (1, 0, 0), (1, 1, 0)],
        [(0, 0, 1), (1, 1, 0), (0, 1, 0)],
        [(0, 0, 1), (0, 1, 0), (0, 0, 0)],
    ]
    for f_a, f_b, f_c in side_faces:
        for i in range(1, order):
            alpha = i * h
            for j in range(1, order):
                beta = j * h
                gamma = 1 - alpha - beta
                # alpha and beta are > 0 by construction; gamma > 0 keeps the
                # node off the third edge.
                if gamma > 0:
                    points.append(tuple(
                        alpha * a + beta * b + gamma * c
                        for a, b, c in zip(f_a, f_b, f_c)))

    # Interior: layers at z = 1 - k/(order - 1) for k = 1..order-2 (e.g. 1/2 for
    # order 3; 2/3 and 1/3 for order 4). Layer k holds a k x k grid with spacing
    # (1 - z)/(k + 1). The k = 1 layer is the single node at the centre of its
    # cross-section; it is computed exactly like the others, with Rationals,
    # so no Float coordinates leak into the node set.
    for k in range(1, order - 1):
        zk = 1 - Rational(k, order - 1)
        h_k = Rational(1, k + 1)
        for i in range(1, k + 1):
            xk = i * h_k * (1 - zk)
            for j in range(1, k + 1):
                points.append((xk, j * h_k * (1 - zk), zk))

    expected = (order + 1) * (order + 2) * (2 * order + 3) // 6
    assert len(points) == expected, f"Expected {expected} points, but got {len(points)}"

    return points
79
def create_matrix(equations, coeffs):
    """Return the coefficient matrix of a list of linear forms.

    Each entry of ``equations`` is expected to be a linear combination of the
    symbols in ``coeffs`` (with numeric weights and no constant term). Entry
    ``A[i, j]`` is the coefficient of ``coeffs[j]`` in ``equations[i]``, so
    ``A * Matrix(coeffs)`` reproduces ``equations``.

    Args:
        equations: list of SymPy expressions, one per matrix row.
        coeffs: list of SymPy symbols, one per matrix column.

    Returns:
        A ``len(equations) x len(coeffs)`` SymPy Matrix. It is square only when
        the two lists have the same length; a mismatch is left for the caller
        (e.g. ``Matrix.inv``) to report.
    """
    A = zeros(len(equations), len(coeffs))
    for i, e in enumerate(equations):
        # Expand once per row so ``coeff`` sees a flat sum of weight * symbol
        # terms; this replaces one polynomial division per matrix entry.
        e = expand(e)
        for j, c in enumerate(coeffs):
            A[i, j] = e.coeff(c)
    return A
91
class Pyramid:
    """Nodal (Lagrange) basis of the order-``order`` pyramid space.

    After construction:
        order:  the polynomial order passed to the constructor.
        points: interpolation nodes returned by ``create_point_set(order)``.
        N:      list of SymPy expressions in ``x, y, z``. ``N[i]`` is the
                combination of ``pyramid_space(order)`` monomials that equals
                1 at ``points[i]`` and 0 at every other node.
    """

    def __init__(self, order):
        self.order = order
        self.points = []
        self.compute_basis()

    def nbf(self):
        """Return the number of basis functions (one per node)."""
        return len(self.N)

    def compute_basis(self):
        """Fill ``self.points`` and ``self.N``.

        Row ``r`` of ``A`` holds the monomials of ``pyramid_space`` evaluated
        at node ``r``. The nodal coefficients ``C`` solve ``A C = I``, so
        ``C = A^-1``; column ``i`` of ``C`` gives the coefficients of ``N[i]``.
        """
        order = self.order
        self.points = create_point_set(order)
        generic, coeff, _ = pyramid_space(order)

        # Evaluate the generic polynomial at every node. x and y are
        # substituted before z: at the apex (0, 0, 1), terms containing
        # x/(1-z) or y/(1-z) are already 0 when z = 1 is plugged in.
        equations = []
        for p in self.points:
            ex = generic.subs(x, p[0])
            ex = ex.subs(y, p[1])
            ex = ex.subs(z, p[2])
            equations.append(ex)

        A = create_matrix(equations, coeff)
        # Solving A C = I only requires inverting A.
        Ainv = A.inv()

        N = []
        for i in range(len(equations)):
            # Replace every a_* symbol at once (one tree traversal) instead of
            # one full-expression subs() per coefficient.
            values = {coeff[j]: Ainv[j, i] for j in range(len(coeff))}
            N.append(generic.xreplace(values))

        self.N = N
126
127
def parse_args(argv=None):
    """Parse the command line of this generator script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the script always has.

    Returns:
        ``argparse.Namespace`` with a single ``output`` attribute: the path of
        the folder the generated files are written to.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("output", type=str, help="path to the output folder")
    return parser.parse_args(argv)
134
135
# Script entry point. For every order in `orders` it builds the pyramid nodal
# basis, reorders the nodes (vertices, base edges, side edges, triangular side
# faces, base quad, interior) and writes auto_pyramid_bases.cpp/.hpp, which
# contain per-order value/gradient/node functions and order-dispatching wrappers.
if __name__ == "__main__":
    args = parse_args()

    # Only the 3D pyramid exists; the loop below asserts this.
    dims = [3]

    # Orders to generate; order 0 is a hand-written constant element.
    orders = [0, 1, 2, 3, 4]

    # Prefix used in the generated file names and C++ function names.
    bletter = "pyramid"

    # .cpp preamble: includes, then polyfem::autogen plus an anonymous
    # namespace that holds the per-order helper functions.
    cpp = f"#include \"auto_{bletter}_bases.hpp\""
    cpp = cpp + "\n#include \"auto_b_bases.hpp\""
    cpp = cpp + "\n#include \"p_n_bases.hpp\""
    cpp = cpp + "\n\n\n" \
        "namespace polyfem {\nnamespace autogen " + "{\nnamespace " + "{\n"

    # .hpp preamble; only the public dispatchers are declared in the header.
    hpp = "#pragma once\n\n#include <Eigen/Dense>\n#include <cassert>\n"

    hpp = hpp + "\nnamespace polyfem {\nnamespace autogen " + "{\n"

    for dim in dims:
        assert dim == 3, "Only 3D pyramid is supported"
        print(str(dim) + "D " + bletter)
        suffix = "3d"

        # Signatures of the public dispatchers. Their `pyramid` argument is
        # the order: it is switched on below with one `case` per order.
        unique_nodes = f"void {bletter}_nodes_{suffix}" + \
            f"(const int {bletter}, Eigen::MatrixXd &val)"

        unique_fun = f"void {bletter}_basis_value_{suffix}" + \
            f"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
        dunique_fun = f"void {bletter}_grad_basis_value_{suffix}" + \
            f"(const int {bletter}, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"

        hpp = hpp + unique_nodes + ";\n\n"

        hpp = hpp + unique_fun + ";\n\n"
        hpp = hpp + dunique_fun + ";\n\n"

        # Open the dispatcher bodies; the per-order cases are appended below.
        unique_nodes = unique_nodes + f"{{\nswitch({bletter})" + "{\n"

        unique_fun = unique_fun + "{\n"
        dunique_fun = dunique_fun + "{\n"

        unique_fun = unique_fun + f"\nswitch({bletter})" + "{\n"
        dunique_fun = dunique_fun + f"\nswitch({bletter})" + "{\n"

        # Reference vertices: v0..v3 span the unit base square, v4 is the apex.
        vertices = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1]]

        for order in orders:
            print("\t-processing " + str(order))

            if order == 0:
                # Constant element: a function object stands in for a Pyramid
                # instance, with the single basis function N = 1 and one node
                # at the mean of the five vertices.
                def fe(): return None
                fe.nbf = lambda: 1

                fe.N = [1]

                fe.points = [[2./5., 2./5., 1./5.]]
            else:
                fe = Pyramid(order)

            # `indices` becomes a permutation of range(len(fe.points)) listing
            # the nodes in the order the generated C++ exposes them. Each stage
            # moves the nodes it matches from `current_indices` to `indices`,
            # so later stages only see nodes not claimed earlier.
            current_indices = list(range(0, len(fe.points)))
            indices = []

            # vertex nodes v0..v4, matched by squared distance
            for i in range(0, 5):
                vv = vertices[i]
                for ii in current_indices:
                    norm = 0
                    for dd in range(0, dim):
                        norm = norm + (vv[dd] - fe.points[ii][dd]) ** 2

                    if norm < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # base edge 1 (v0 -> v1): y = 0, z = 0, x increasing
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][1] != 0 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][0] - (i + 1) / order) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # base edge 2 (v1 -> v2): x = 1, z = 0, y increasing
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] != 1 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][1] - (i + 1) / order) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # base edge 3 (v2 -> v3): y = 1, z = 0, x decreasing
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][1] != 1 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # base edge 4 (v3 -> v0): x = 0, z = 0, y decreasing
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] != 0 or (dim == 3 and fe.points[ii][2] != 0):
                        continue

                    if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break


            # side edge 1 (v0 -> v4): x = 0, y = 0, z increasing
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] != 0 or fe.points[ii][1] != 0:
                        continue

                    if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # side edge 2 (v1 -> v4): x + z = 1, y = 0, z increasing (x decreasing)
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] + fe.points[ii][2] != 1 or fe.points[ii][1] != 0:
                        continue

                    if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # side edge 3 (v2 -> v4): x + z = 1, y + z = 1, z increasing
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][0] + fe.points[ii][2] != 1 or fe.points[ii][1] + fe.points[ii][2] != 1:
                        continue


                    if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # side edge 4 (v3 -> v4): y + z = 1, x = 0, z increasing (y decreasing)
            for i in range(0, order - 1):
                for ii in current_indices:
                    if fe.points[ii][1] + fe.points[ii][2] != 1 or fe.points[ii][0] != 0:
                        continue

                    if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
                        indices.append(ii)
                        current_indices.remove(ii)
                        break

            # Expected node counts: npts_b on the base quad ((order-1)^2 for
            # order >= 2) and npts per triangular face ((order-1)(order-2)/2
            # for order >= 1). Each face stage below takes at most that many
            # nodes; if fewer match, it appends fewer without complaint.
            nn = max(0, order - 2)
            npts_b = (nn + 1)**2
            npts = int(nn * (nn + 1) / 2)

            # The face tests rely on the vertex and edge nodes having been
            # removed already. Each face's nodes are collected in `tmp` and
            # appended rotated left by one (tmp[1], ..., tmp[-1], tmp[0]).
            # NOTE(review): presumably to match the face-node ordering the C++
            # side expects; confirm.

            # front: y = 0 (f[0]: v0,v1,v4)
            tmp = []
            for i in range(0, npts):
                for ii in current_indices:
                    if abs(fe.points[ii][1]) > 1e-10:
                        continue
                    tmp.append(ii); current_indices.remove(ii); break
            for i in range(len(tmp)):
                indices.append(tmp[(i + 1) % len(tmp)])

            # right: x + z = 1 (f[1]: v1,v2,v4)
            tmp = []
            for i in range(0, npts):
                for ii in current_indices:
                    if abs(fe.points[ii][0]) < 1e-10 or abs(fe.points[ii][1]) < 1e-10 or abs(fe.points[ii][2]) < 1e-10:
                        continue
                    if abs((fe.points[ii][0] + fe.points[ii][2]) - 1) > 1e-10:
                        continue
                    tmp.append(ii); current_indices.remove(ii); break
            for i in range(len(tmp)):
                indices.append(tmp[(i + 1) % len(tmp)])

            # back: y + z = 1 (f[2]: v2,v3,v4)
            tmp = []
            for i in range(0, npts):
                for ii in current_indices:
                    if abs(fe.points[ii][0]) < 1e-10 or abs(fe.points[ii][1]) < 1e-10 or abs(fe.points[ii][2]) < 1e-10:
                        continue
                    if abs((fe.points[ii][1] + fe.points[ii][2]) - 1) > 1e-10:
                        continue
                    tmp.append(ii); current_indices.remove(ii); break
            for i in range(len(tmp)):
                indices.append(tmp[(i + 1) % len(tmp)])

            # left: x = 0 (f[3]: v3,v0,v4)
            tmp = []
            for i in range(0, npts):
                for ii in current_indices:
                    if abs(fe.points[ii][0]) > 1e-10:
                        continue
                    tmp.append(ii); current_indices.remove(ii); break
            for i in range(len(tmp)):
                indices.append(tmp[(i + 1) % len(tmp)])

            # bottom: z = 0 (f[4]: base quad, (p-1)^2 nodes), emitted after the
            # four triangular faces and without rotation
            for i in range(0, npts_b):
                for ii in current_indices:
                    if abs(fe.points[ii][2]) > 1e-10:
                        continue
                    indices.append(ii); current_indices.remove(ii); break

            # interior unshared indices, order does not matter
            for ii in current_indices:
                indices.append(ii)

            # debug listing: local index, index into fe.points, coordinates
            for i in range(0, fe.nbf()):
                print(i, indices[i], fe.points[indices[i]])

            # nodes code gen: one row per node, in `indices` order, via an
            # Eigen comma initializer
            nodes = f"void {bletter}_{order}_nodes_{suffix}(Eigen::MatrixXd &res) {{\n res.resize(" + str(
                len(indices)) + ", " + str(dim) + "); res << \n"
            unique_nodes = unique_nodes + f"\tcase {order}: " + \
                f"{bletter}_{order}_nodes_{suffix}(val); break;\n"

            for ii in indices:
                nodes = nodes + ccode(fe.points[ii][0]) + ", " + ccode(fe.points[ii][1]) + (
                    (", " + ccode(fe.points[ii][2])) if dim == 3 else "") + ",\n"
            # drop the trailing ",\n" before closing the initializer
            nodes = nodes[:-2]
            nodes = nodes + ";\n}"

            # bases code gen: per-order value and gradient functions with a
            # switch over local_index, plus the dispatcher cases forwarding to them
            func = f"void {bletter}_{order}_basis_value_{suffix}" + \
                "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &result_0)"
            dfunc = f"void {bletter}_{order}_basis_grad_value_{suffix}" + \
                "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"

            unique_fun = unique_fun + \
                f"\tcase {order}: {bletter}_{order}_basis_value_{suffix}(local_index, uv, val); break;\n"
            dunique_fun = dunique_fun + \
                f"\tcase {order}: {bletter}_{order}_basis_grad_value_{suffix}(local_index, uv, val); break;\n"

            # hpp = hpp + func + ";\n"
            # hpp = hpp + dfunc + ";\n"

            # Bind the uv columns to x, y, z so the printed expressions compile.
            base = "auto x=uv.col(0).array();\nauto y=uv.col(1).array();"
            base = base + "\nauto z=uv.col(2).array();"
            base = base + "\n\n"
            dbase = base

            # The order-0 constant is emitted as `.setOnes()`, which needs
            # result_0 sized first.
            if order == 0:
                base = base + "result_0.resize(x.size(),1);\n"

            base = base + "switch(local_index){\n"
            dbase = dbase + \
                "val.resize(uv.rows(), uv.cols());\n Eigen::ArrayXd result_0(uv.rows());\n" + \
                "switch(local_index){\n"

            # Case i emits basis function N[indices[i]], so C++ local indices
            # follow the node ordering built above.
            for i in range(0, fe.nbf()):
                real_index = indices[i]
                # real_index = i

                # Constant results are rewritten into Eigen array calls
                # (setOnes / setZero / setConstant(-1)).
                base = base + "\tcase " + str(i) + ": {" + pretty_print.C99_print(
                    simplify(fe.N[real_index])).replace(" = 1;", ".setOnes();") + "} break;\n"
                # Gradient: d/dx into val.col(0), d/dy into val.col(1) ...
                dbase = dbase + "\tcase " + str(i) + ": {" + \
                    "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], x))).replace(" = 0;", ".setZero();").replace(" = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(0) = result_0; }" \
                    "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], y))).replace(" = 0;", ".setZero();").replace(
                    " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(1) = result_0; }"

                # ... and d/dz into val.col(2).
                if dim == 3:
                    dbase = dbase + "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], z))).replace(" = 0;", ".setZero();").replace(
                        " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(2) = result_0; }"

                dbase = dbase + "} break;\n"

            base = base + "\tdefault: assert(false);\n}"
            dbase = dbase + "\tdefault: assert(false);\n}"

            cpp = cpp + func + "{\n\n"
            cpp = cpp + base + "}\n"

            cpp = cpp + dfunc + "{\n\n"
            cpp = cpp + dbase + "}\n\n\n"

            cpp = cpp + nodes + "\n\n\n"

        # Close the dispatcher switches and function bodies.
        unique_nodes = unique_nodes + "\tdefault: assert(false);\n}}"

        unique_fun = unique_fun + "\tdefault: assert(false); \n}}"
        dunique_fun = dunique_fun + "\tdefault: assert(false); \n}}"

        # Close the anonymous namespace, emit the public dispatchers in
        # polyfem::autogen, then reopen an anonymous namespace; it is closed
        # together with autogen and polyfem by the final "}}}".
        cpp = cpp + "}\n\n" + unique_nodes + "\n" + unique_fun + \
            "\n\n" + dunique_fun + "\n" + "\nnamespace " + "{\n"
        hpp = hpp + "\n"

    hpp = hpp + \
        f"\nstatic const int MAX_{bletter.capitalize()}_BASES = {max(orders)};\n"

    cpp = cpp + "\n}}}\n"
    hpp = hpp + "\n}}\n"

    # The output folder is not created here; it must already exist.
    path = os.path.abspath(args.output)

    print("saving...")
    with open(os.path.join(path, f"auto_{bletter}_bases.cpp"), "w") as file:
        file.write(cpp)

    with open(os.path.join(path, f"auto_{bletter}_bases.hpp"), "w") as file:
        file.write(hpp)

    print("done!")
456
457
458 # print("Creating point set...")
459 # point_set = create_point_set(5)
460 # # plot the point set
461 # import matplotlib.pyplot as plt
462 # from mpl_toolkits.mplot3d import Axes3D
463 # fig = plt.figure()
464 # ax = fig.add_subplot(111, projection='3d')
465 # xs = [p[0] for p in point_set]
466 # ys = [p[1] for p in point_set]
467 # zs = [p[2] for p in point_set]
468 # colors = [p[3] for p in point_set]
469 # ax.scatter(xs, ys, zs, c=colors)
470 # ax.set_xlabel('X')
471 # ax.set_ylabel('Y')
472 # ax.set_zlabel('Z')
473
474 # plt.show()
475
476 # python_pyramid_basis_0 = Pyramid(0)
477 # print(python_pyramid_basis_0.N, len(python_pyramid_basis_0.N))
478
479 # python_pyramid_basis_1 = Pyramid(1)
480 # print(python_pyramid_basis_1.N, len(python_pyramid_basis_1.N))
481
482 # python_pyramid_basis_2 = Pyramid(2)
483 # print(python_pyramid_basis_2.N, len(python_pyramid_basis_2.N))
C99_print(expr)
pyramid_space(order)
create_point_set(order)
create_matrix(equations, coeffs)