PolyFEM
Loading...
Searching...
No Matches
p_bases.py
Go to the documentation of this file.
1# https://raw.githubusercontent.com/sympy/sympy/master/examples/advanced/fem.py
2from sympy import *
3import os
4import numpy as np
5import argparse
6from sympy.printing import ccode
7
8
9import pretty_print
10
# Module-level coordinate symbols; bernstein_space builds its polynomials in
# these and compute_basis substitutes the lattice points into them.
11x, y, z = symbols('x,y,z')
12
13
# Store the spatial dimension and the coordinate symbols used by integrate():
# x, y, z (truncated to nsd) when nsd <= 3, otherwise generic symbols
# x_0 ... x_{nsd-1}.
15 def __init__(self, nsd):
16 self.nsd = nsd
17 if nsd <= 3:
18 coords = symbols('x,y,z')[:nsd]
19 else:
20 coords = [Symbol("x_%d" % d) for d in range(nsd)]
21 self.coords = coords
22
# Integrate f over the unit simplex {coords >= 0, sum(coords) <= 1}.
# `limit` starts at 1 - sum(coords). Each pass adds the current coordinate
# back, so the first (innermost) integral over coords[0] runs from 0 to
# 1 - coords[1] - ..., and the last one over coords[-1] runs from 0 to 1.
# NOTE: the `integrate` called below is the module-level sympy function
# (from the star import), not this method.
23 def integrate(self, f):
24 coords = self.coords
25 nsd = self.nsd
26
27 limit = 1
28 for p in coords:
29 limit -= p
30
31 intf = f
32 for d in range(0, nsd):
33 p = coords[d]
34 limit += p
35 intf = integrate(intf, (p, 0, limit))
36 return intf
37
38
def bernstein_space(order, nsd):
    """Return a generic degree-``order`` polynomial in the Bernstein basis.

    The Bernstein polynomials are built from the barycentric coordinates of
    the reference simplex: ``(x, 1 - x)`` in 1D, ``(x, y, 1 - x - y)`` in 2D
    and ``(x, y, z, 1 - x - y - z)`` in 3D. Exponent tuples are enumerated in
    lexicographic order (the first barycentric exponent varies slowest), which
    matches the original nested-loop ordering for 2D and 3D.

    Args:
        order: polynomial degree (non-negative int).
        nsd: number of spatial dimensions (1, 2 or 3).

    Returns:
        ``(pol, coeff, basis)``. ``basis`` is the list of Bernstein
        polynomials, ``coeff`` holds the matching unknown symbols
        ``a_<e1>_<e2>_...``, and ``pol = sum(coeff[k] * basis[k])``.

    Raises:
        RuntimeError: if ``nsd`` is not 1, 2 or 3.
    """
    import itertools

    # The message always advertised 1D; 1D is now actually supported instead
    # of silently returning an empty space.
    if nsd < 1 or nsd > 3:
        raise RuntimeError("Bernstein only implemented in 1D, 2D, and 3D")

    coords = (x, y, z)[:nsd]
    # Barycentric coordinates: the spatial coordinates plus 1 - sum(coords).
    bary = list(coords) + [1 - Add(*coords)]

    basis = []
    coeff = []
    for exps in itertools.product(range(order + 1), repeat=nsd + 1):
        if sum(exps) != order:
            continue
        # Multinomial coefficient order! / (e_1! * ... * e_{nsd+1}!).
        fac = factorial(order) / Mul(*[factorial(e) for e in exps])
        basis.append(fac * Mul(*[b ** e for b, e in zip(bary, exps)]))
        coeff.append(Symbol("a_" + "_".join(str(e) for e in exps)))

    pol = Add(*[c * b for c, b in zip(coeff, basis)])
    return pol, coeff, basis
79
80
def create_point_set(order, nsd):
    """Return the equispaced lattice points of the unit reference simplex.

    The points are ``(i_1*h, ..., i_nsd*h)`` with ``h = 1/order`` and
    non-negative integers ``i_k`` summing to at most ``order``. Coordinates
    are exact sympy ``Rational`` values, so callers can compare them exactly
    (e.g. ``== 0`` or ``== 1``). Points are listed in lexicographic order of
    ``(i_1, ..., i_nsd)``, which matches the original nested loops for 2D and
    3D.

    Args:
        order: lattice resolution / polynomial degree (int >= 1).
        nsd: number of spatial dimensions (int >= 1).

    Returns:
        A list of ``nsd``-tuples of ``Rational`` coordinates.

    Raises:
        ValueError: if ``order < 1`` (h = 1/order is undefined) or ``nsd < 1``.
    """
    import itertools

    if order < 1:
        raise ValueError("create_point_set needs order >= 1 (got %r)" % (order,))
    if nsd < 1:
        raise ValueError("create_point_set needs nsd >= 1 (got %r)" % (nsd,))

    h = Rational(1, order)
    # i_1 + ... + i_nsd <= order  <=>  the point lies inside the simplex
    # (exact integer test, equivalent to sum(coords) <= 1).
    return [
        tuple(i * h for i in idx)
        for idx in itertools.product(range(order + 1), repeat=nsd)
        if sum(idx) <= order
    ]
104
105
def create_matrix(equations, coeffs):
    """Return the coefficient matrix of a linear system in ``coeffs``.

    Entry ``A[i, j]`` is the coefficient of the unknown ``coeffs[j]`` in
    ``equations[i]``.

    Args:
        equations: sympy expressions, one per matrix row. They are assumed
            linear in ``coeffs``, as for the evaluated Bernstein polynomial
            built in compute_basis.
        coeffs: unknown symbols, one per matrix column.

    Returns:
        A ``len(equations) x len(coeffs)`` sympy matrix.
    """
    # Size by both lists. The former square zeros(len(equations)) raised an
    # IndexError (or left zero columns) whenever the two counts differed.
    A = zeros(len(equations), len(coeffs))
    for j, c in enumerate(coeffs):
        for i, e in enumerate(equations):
            # Dividing e by the monomial c yields, for e linear in c, the
            # coefficient of c as the quotient.
            quotients, _ = reduced(e, [c])
            A[i, j] = quotients[0]
    return A
117
118
# Store the dimension and polynomial order, then eagerly build the nodal
# basis (compute_basis fills self.points and self.N).
120 def __init__(self, nsd, order):
121 self.nsd = nsd
122 self.order = order
123 self.points = []
124 self.compute_basis()
125
# Number of basis functions (compute_basis builds one per lattice point).
126 def nbf(self):
127 return len(self.N)
128
# Build the nodal basis of degree self.order on the reference simplex:
#   1. write a generic polynomial `pol` in the Bernstein basis with unknown
#      coefficients `coeffs` (bernstein_space);
#   2. evaluate it at the equispaced lattice points (create_point_set);
#   3. collect A with A[k, j] = coefficient of coeffs[j] in pol(points[k])
#      (create_matrix);
#   4. basis function i uses column i of A^{-1} as its coefficient vector,
#      so that N_i(points[k]) = delta_ik.
# Sets self.points and self.N.
129 def compute_basis(self):
130 order = self.order
131 nsd = self.nsd
132 N = []
133 pol, coeffs, basis = bernstein_space(order, nsd)
134 self.points = create_point_set(order, nsd)
135
136 equations = []
# Evaluate pol at each point; y and z are substituted only when nsd needs them.
137 for p in self.points:
138 ex = pol.subs(x, p[0])
139 if nsd > 1:
140 ex = ex.subs(y, p[1])
141 if nsd > 2:
142 ex = ex.subs(z, p[2])
143 equations.append(ex)
144
145 A = create_matrix(equations, coeffs)
146
# NOTE(review): the disabled evalf path below presumably trades exactness for
# speed on large systems; A.inv() on the exact matrix is what runs now.
147 # if A.shape[0] > 25:
148 # A = A.evalf()
149
150 Ainv = A.inv()
151
# NOTE(review): b is the identity, so xx == Ainv; the product is a no-op.
152 b = eye(len(equations))
153
154 xx = Ainv * b
155
# N_i = pol with coeffs[j] replaced by xx[j, i] (column i of A^{-1}).
156 for i in range(0, len(equations)):
157 Ni = pol
158 for j in range(0, len(coeffs)):
159 Ni = Ni.subs(coeffs[j], xx[j, i])
160 N.append(Ni)
161
162 self.N = N
163
164
# Body of parse_args() (its `def` line, listing 165, is absent from this
# listing). It parses one positional argument, `output`: the folder that
# receives auto_p_bases.cpp and auto_p_bases.hpp.
# NOTE(review): the file starts with a comment, not a docstring, so
# __doc__ (the description) is presumably None.
166 parser = argparse.ArgumentParser(
167 description=__doc__,
168 formatter_class=argparse.RawDescriptionHelpFormatter)
169 parser.add_argument("output", type=str, help="path to the output folder")
170 return parser.parse_args()
171
172
# Script entry point. It generates auto_p_bases.cpp / auto_p_bases.hpp with
# hard-coded nodal bases (values, gradients and node coordinates) on the
# reference triangle (2D) and tetrahedron (3D) for each order in `orders`.
173if __name__ == "__main__":
174 args = parse_args()
175
176 dims = [2, 3]
177
178 orders = [0, 1, 2, 3, 4]
179 # orders = [4]
180
# C++ preambles: the .cpp opens polyfem::autogen plus an anonymous namespace
# for the per-order helpers; the .hpp declares the public dispatchers.
181 cpp = "#include \"auto_p_bases.hpp\"\n\n\n"
182 cpp = cpp + \
183 "namespace polyfem {\nnamespace autogen " + "{\nnamespace " + "{\n"
184
185 hpp = "#pragma once\n\n#include <Eigen/Dense>\n#include \"p_n_bases.hpp\"\n#include <cassert>\n\n"
186 hpp = hpp + "namespace polyfem {\nnamespace autogen " + "{\n"
187
# For each dimension, declare (in the .hpp) and start (switch on p) the public
# dispatchers p_nodes_*, p_basis_value_* and p_grad_basis_value_*; each order
# generated below adds one `case`.
188 for dim in dims:
189 print(str(dim) + "D")
190 suffix = "_2d" if dim == 2 else "_3d"
191
192 unique_nodes = "void p_nodes" + suffix + \
193 "(const int p, Eigen::MatrixXd &val)"
194
195 unique_fun = "void p_basis_value" + suffix + \
196 "(const int p, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
197 dunique_fun = "void p_grad_basis_value" + suffix + \
198 "(const int p, const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
199
200 hpp = hpp + unique_nodes + ";\n\n"
201
202 hpp = hpp + unique_fun + ";\n\n"
203 hpp = hpp + dunique_fun + ";\n\n"
204
205 unique_nodes = unique_nodes + "{\nswitch(p)" + "{\n"
206
207 unique_fun = unique_fun + "{\nswitch(p)" + "{\n"
208 dunique_fun = dunique_fun + "{\nswitch(p)" + "{\n"
209
# Reference-simplex vertices in local order; the vertex nodes are placed
# first, in this order, when reordering the nodes below.
210 if dim == 2:
211 vertices = [[0, 0], [1, 0], [0, 1]]
212 elif dim == 3:
213 vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
214
# Build the element for this order. Order 0 has no lattice (create_point_set
# needs order >= 1), so it uses a stand-in: one constant basis function whose
# node is the centroid. The code below only reads fe.nbf(), fe.N and fe.points.
215 for order in orders:
216 print("\t-processing " + str(order))
217
218 if order == 0:
# A throwaway function object is used purely as an attribute namespace.
219 def fe(): return None
220 fe.nbf = lambda: 1
221
222 fe.N = [1]
223
224 if dim == 2:
225 fe.points = [[1./3., 1./3.]]
226 else:
227 fe.points = [[1./3., 1./3., 1./3.]]
228 else:
229 fe = Lagrange(dim, order)
230
# Reorder the nodes into the local numbering of the generated code:
# vertices, then edge nodes edge by edge (each in a fixed direction), then
# (3D) face nodes face by face, then whatever remains. indices[k] is the
# position in fe.points / fe.N of local node k. Every selection loop removes
# the chosen point from current_indices and breaks immediately, so mutating
# the list while iterating it is safe here.
231 current_indices = list(range(0, len(fe.points)))
232 indices = []
233
# Vertices: the point within squared distance 1e-10 of each reference vertex.
234 # vertex coordinate
235 for i in range(0, dim + 1):
236 vv = vertices[i]
237 for ii in current_indices:
238 norm = 0
239 for dd in range(0, dim):
240 norm = norm + (vv[dd] - fe.points[ii][dd]) ** 2
241
242 if norm < 1e-10:
243 indices.append(ii)
244 current_indices.remove(ii)
245 break
246
# Each edge has order - 1 interior nodes, taken at parameter (i + 1) / order.
# Edge 1: y = 0 (and z = 0), x increasing (vertex 0 -> vertex 1).
247 # edge 1 coordinate
248 for i in range(0, order - 1):
249 for ii in current_indices:
250 if fe.points[ii][1] != 0 or (dim == 3 and fe.points[ii][2] != 0):
251 continue
252
253 if abs(fe.points[ii][0] - (i + 1) / order) < 1e-10:
254 indices.append(ii)
255 current_indices.remove(ii)
256 break
257
# Edge 2: x + y = 1 (and z = 0), y increasing (vertex 1 -> vertex 2).
258 # edge 2 coordinate
259 for i in range(0, order - 1):
260 for ii in current_indices:
261 if fe.points[ii][0] + fe.points[ii][1] != 1 or (dim == 3 and fe.points[ii][2] != 0):
262 continue
263
264 if abs(fe.points[ii][1] - (i + 1) / order) < 1e-10:
265 indices.append(ii)
266 current_indices.remove(ii)
267 break
268
# Edge 3: x = 0 (and z = 0), y decreasing (vertex 2 -> vertex 0).
269 # edge 3 coordinate
270 for i in range(0, order - 1):
271 for ii in current_indices:
272 if fe.points[ii][0] != 0 or (dim == 3 and fe.points[ii][2] != 0):
273 continue
274
275 if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
276 indices.append(ii)
277 current_indices.remove(ii)
278 break
279
280 if dim == 3:
# Edge 4: x = y = 0, z increasing (vertex 0 -> vertex 3).
281 # edge 4 coordinate
282 for i in range(0, order - 1):
283 for ii in current_indices:
284 if fe.points[ii][0] != 0 or fe.points[ii][1] != 0:
285 continue
286
287 if abs(fe.points[ii][2] - (i + 1) / order) < 1e-10:
288 indices.append(ii)
289 current_indices.remove(ii)
290 break
291
# Edge 5: x + z = 1, y = 0, x decreasing (vertex 1 -> vertex 3).
292 # edge 5 coordinate
293 for i in range(0, order - 1):
294 for ii in current_indices:
295 if fe.points[ii][0] + fe.points[ii][2] != 1 or fe.points[ii][1] != 0:
296 continue
297
298 if abs(fe.points[ii][0] - (1 - (i + 1) / order)) < 1e-10:
299 indices.append(ii)
300 current_indices.remove(ii)
301 break
302
# Edge 6: y + z = 1, x = 0, y decreasing (vertex 2 -> vertex 3).
303 # edge 6 coordinate
304 for i in range(0, order - 1):
305 for ii in current_indices:
306 if fe.points[ii][1] + fe.points[ii][2] != 1 or fe.points[ii][0] != 0:
307 continue
308
309 if abs(fe.points[ii][1] - (1 - (i + 1) / order)) < 1e-10:
310 indices.append(ii)
311 current_indices.remove(ii)
312 break
313
314 if dim == 3:
# npts = nn * (nn + 1) / 2 with nn = order - 2, i.e. the number of interior
# nodes of one triangular face. Vertices and edges are already taken, so any
# remaining point on a face plane is a face-interior node.
315 nn = max(0, order - 2)
316 npts = int(nn * (nn + 1) / 2)
317
318 # bottom: z = 0
319 for i in range(0, npts):
320 for ii in current_indices:
321 if abs(fe.points[ii][2]) > 1e-10:
322 continue
323
324 indices.append(ii)
325 current_indices.remove(ii)
326 break
327
328 # front: y = 0
329 for i in range(0, npts):
330 for ii in current_indices:
331 if abs(fe.points[ii][1]) > 1e-10:
332 continue
333
334 indices.append(ii)
335 current_indices.remove(ii)
336 break
337
# Face x + y + z = 1: its nodes are collected in tmp and then appended rotated
# by two positions. NOTE(review): presumably this matches the face orientation
# expected on the C++ side; confirm there.
338 # diagonal: none equal to zero and sum 1
339 tmp = []
340 for i in range(0, npts):
341 for ii in current_indices:
342 if (abs(fe.points[ii][0]) < 1e-10) | (abs(fe.points[ii][1]) < 1e-10) | (abs(fe.points[ii][2]) < 1e-10):
343 continue
344
345 if abs((fe.points[ii][0] + fe.points[ii][1] + fe.points[ii][2]) - 1) > 1e-10:
346 continue
347
348 tmp.append(ii)
349 current_indices.remove(ii)
350 break
351 for i in range(0, len(tmp)):
352 indices.append(tmp[(i + 2) % len(tmp)])
353
# Face x = 0: nodes appended by decreasing point index.
354 # side: x = 0
355 tmp = []
356 for i in range(0, npts):
357 for ii in current_indices:
358 if abs(fe.points[ii][0]) > 1e-10:
359 continue
360
361 tmp.append(ii)
362 current_indices.remove(ii)
363 break
364 tmp.sort(reverse=True)
365 indices.extend(tmp)
366
# Remaining nodes (triangle interior in 2D, volume interior in 3D, or the
# single centroid node for order 0) are appended in their original order.
367 # either face or volume indices, order do not matter
368 for ii in current_indices:
369 indices.append(ii)
370
# Emit p_<order>_nodes_<2d|3d>(res) with one row of node coordinates per
# local node (in `indices` order), and register it in the p_nodes dispatcher.
371 # nodes code gen
372 nodes = "void p_" + str(order) + "_nodes" + suffix + "(Eigen::MatrixXd &res) {\n res.resize(" + str(
373 len(indices)) + ", " + str(dim) + "); res << \n"
374 unique_nodes = unique_nodes + "\tcase " + \
375 str(order) + ": " + "p_" + str(order) + \
376 "_nodes" + suffix + "(val); break;\n"
377
378 for ii in indices:
379 nodes = nodes + ccode(fe.points[ii][0]) + ", " + ccode(fe.points[ii][1]) + (
380 (", " + ccode(fe.points[ii][2])) if dim == 3 else "") + ",\n"
# Drop the trailing ",\n" after the last row before closing the initializer.
381 nodes = nodes[:-2]
382 nodes = nodes + ";\n}"
383
# Emit p_<order>_basis_value_* (values written to result_0) and
# p_<order>_basis_grad_value_* (one gradient component per column of val),
# and register both in the dispatchers.
384 # bases code gen
385 func = "void p_" + str(order) + "_basis_value" + suffix + \
386 "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &result_0)"
387 dfunc = "void p_" + str(order) + "_basis_grad_value" + suffix + \
388 "(const int local_index, const Eigen::MatrixXd &uv, Eigen::MatrixXd &val)"
389
390 unique_fun = unique_fun + "\tcase " + str(order) + ": " + "p_" + str(
391 order) + "_basis_value" + suffix + "(local_index, uv, val); break;\n"
392 dunique_fun = dunique_fun + "\tcase " + str(order) + ": " + "p_" + str(
393 order) + "_basis_grad_value" + suffix + "(local_index, uv, val); break;\n"
394
395 # hpp = hpp + func + ";\n"
396 # hpp = hpp + dfunc + ";\n"
397
# Dispatcher fallbacks for orders not generated here (generic p_n_* code).
398 default_base = "p_n_basis_value_3d(p, local_index, uv, val);" if dim == 3 else "p_n_basis_value_2d(p, local_index, uv, val);"
399 default_dbase = "p_n_basis_grad_value_3d(p, local_index, uv, val);" if dim == 3 else "p_n_basis_grad_value_2d(p, local_index, uv, val);"
400 default_nodes = "p_n_nodes_3d(p, val);" if dim == 3 else "p_n_nodes_2d(p, val);"
401
# Generated preamble: x, y (and z) alias the columns of uv as arrays.
402 base = "auto x=uv.col(0).array();\nauto y=uv.col(1).array();"
403 if dim == 3:
404 base = base + "\nauto z=uv.col(2).array();"
405 base = base + "\n\n"
406 dbase = base
407
# Order 0: the constant basis is emitted as result_0.setOnes(), presumably
# needing an explicit size because no expression sizes result_0 here.
408 if order == 0:
409 base = base + "result_0.resize(x.size(),1);\n"
410
411 base = base + "switch(local_index){\n"
412 dbase = dbase + \
413 "val.resize(uv.rows(), uv.cols());\n Eigen::ArrayXd result_0(uv.rows());\n" + \
414 "switch(local_index){\n"
415
# One `case` per local node i; real_index maps it back to fe.N. Constant
# results printed as " = 0;", " = 1;" or " = -1;" are rewritten into Eigen
# setZero/setOnes/setConstant calls.
# NOTE(review): other constants are not rewritten; confirm C99_print never
# emits them for these bases.
416 for i in range(0, fe.nbf()):
417 real_index = indices[i]
418 # real_index = i
419
420 base = base + "\tcase " + str(i) + ": {" + pretty_print.C99_print(
421 simplify(fe.N[real_index])).replace(" = 1;", ".setOnes();") + "} break;\n"
422 dbase = dbase + "\tcase " + str(i) + ": {" + \
423 "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], x))).replace(" = 0;", ".setZero();").replace(" = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(0) = result_0; }" \
424 "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], y))).replace(" = 0;", ".setZero();").replace(
425 " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(1) = result_0; }"
426
427 if dim == 3:
428 dbase = dbase + "{" + pretty_print.C99_print(simplify(diff(fe.N[real_index], z))).replace(" = 0;", ".setZero();").replace(
429 " = 1;", ".setOnes();").replace(" = -1;", ".setConstant(-1);") + "val.col(2) = result_0; }"
430
431 dbase = dbase + "} break;\n"
432
# Close the per-order switches; an out-of-range local_index hits assert(false).
433 base = base + "\tdefault: assert(false);\n}"
434 dbase = dbase + "\tdefault: assert(false);\n}"
435
436 cpp = cpp + func + "{\n\n"
437 cpp = cpp + base + "}\n"
438
439 cpp = cpp + dfunc + "{\n\n"
440 cpp = cpp + dbase + "}\n\n\n" + nodes + "\n\n\n"
441
# Close the dispatcher switches; the default case falls back to the generic
# p_n_* implementation.
442 unique_nodes = unique_nodes + "\tdefault: "+default_nodes+"\n}}"
443
444 unique_fun = unique_fun + "\tdefault: "+default_base+"\n}}"
445 dunique_fun = dunique_fun + "\tdefault: "+default_dbase+"\n}}"
446
# Leave the anonymous namespace to define the dispatchers declared in the
# .hpp, then reopen it for the helpers that follow.
447 cpp = cpp + "}\n\n" + unique_nodes + "\n" + unique_fun + \
448 "\n\n" + dunique_fun + "\n" + "\nnamespace " + "{\n"
449 hpp = hpp + "\n"
450
# MAX_P_BASES: the highest order generated by this script.
451 hpp = hpp + "\nstatic const int MAX_P_BASES = " + str(max(orders)) + ";\n"
452
# Close the anonymous, autogen and polyfem namespaces (.cpp) and the autogen
# and polyfem namespaces (.hpp).
453 cpp = cpp + "\n}}}\n"
454 hpp = hpp + "\n}}\n"
455
# NOTE(review): the output folder is not created here; it must already exist.
456 path = os.path.abspath(args.output)
457
458 print("saving...")
459 with open(os.path.join(path, "auto_p_bases.cpp"), "w") as file:
460 file.write(cpp)
461
462 with open(os.path.join(path, "auto_p_bases.hpp"), "w") as file:
463 file.write(hpp)
464
465 print("done!")
__init__(self, nsd, order)
Definition p_bases.py:120
compute_basis(self)
Definition p_bases.py:129
create_matrix(equations, coeffs)
Definition p_bases.py:106
create_point_set(order, nsd)
Definition p_bases.py:81
bernstein_space(order, nsd)
Definition p_bases.py:39
parse_args()
Definition p_bases.py:165
C99_print(expr)