PolyFEM
Loading...
Searching...
No Matches
pretty_print.py
Go to the documentation of this file.
1from sympy import *
2from sympy.matrices import *
3from sympy.printing import ccode
4import numpy as np
5
# Reference-space coordinates exposed as the sympy symbols x, y and z.
# Generated C functions take the first `dim` of these as `double` parameters.
SCALAR_COORDS = symbols("x y z")
8
# pretty print
def C99_print(expr):
    """Print ``expr`` as C/C++ statements after common-subexpression elimination.

    Every subexpression eliminated by sympy's ``cse`` becomes a ``helper_<n>``
    definition, and entry ``i`` of the reduced expression list is assigned to
    ``result_<i>``.

    Scalar helpers are emitted as ``const auto helper_<n> = ...;``. When a
    helper is matrix-valued, sympy's ``cse`` replaces it with a
    ``MatrixSymbol`` (the replacement *symbol* carries the matrix type, not
    the value). Such a helper is emitted as ``double helper_<n>[rows * cols];``
    followed by the element-wise assignments produced by ``ccode``.

    Returns:
        str: the generated statements, one per line.
    """
    replacements, reduced_exprs = cse(
        expr, numbered_symbols("helper_"), optimizations='basic')
    lines = []
    for symbol, value in replacements:
        # cse swaps the plain helper Symbol for a MatrixSymbol when the
        # eliminated subexpression is matrix-valued, so test the symbol; the
        # value is never a bare MatrixSymbol because cse does not eliminate
        # those.
        if isinstance(symbol, MatrixSymbol):
            # An array cannot be declared `const auto` without an initializer,
            # so declare a plain double array and let ccode fill each entry.
            lines.append(f"double {symbol}[{symbol.rows * symbol.cols}];")
            lines.append(ccode(value, symbol))
        else:
            lines.append('const auto ' + ccode(value, symbol))

    for i, result in enumerate(reduced_exprs):
        lines.append(ccode(result, "result_%d" % i))
    return '\n'.join(lines)
24
25
def C99_print_scalar(expr, result_name="result"):
    """Emit C code that assigns the scalar ``expr`` to ``result_name``.

    Common subexpressions are hoisted into ``double helper_<n> = ...;``
    declarations, and the last line is ``<result_name> = ...;``. Only the
    first reduced expression returned by ``cse`` is printed.
    """
    helper_defs, reduced = cse(
        expr, numbered_symbols("helper_"), optimizations='basic')
    code = []
    for helper, definition in helper_defs:
        code.append("double " + ccode(definition, helper))
    code.append(ccode(reduced[0], result_name))
    return "\n".join(code)
33
34
def scalar_args(dim):
    """Build a C parameter list such as 'double x, double y' for ``dim`` coordinates.

    ``dim`` must lie between 1 and the number of entries in SCALAR_COORDS;
    otherwise an AssertionError is raised.
    """
    assert 1 <= dim <= len(SCALAR_COORDS)
    params = []
    for coord in SCALAR_COORDS[:dim]:
        params.append("double " + coord.name)
    return ", ".join(params)
39
40
def scalar_call_args(dim):
    """Return the C call arguments that unpack one quadrature point.

    Produces ``uv(i, 0), uv(i, 1), ...`` with ``dim`` entries, used to pass
    row ``i`` of the Eigen quadrature-point matrix ``uv`` to a generated
    scalar basis function. ``dim == 0`` yields an empty string.
    """
    return ", ".join(f"uv(i, {d})" for d in range(dim))
47
48
def C99_print_scalar_value_function(function_name, expr, dim):
    """Emit a C function ``double <function_name>(double x, ...)`` returning ``expr``.

    The function takes the first ``dim`` reference coordinates as arguments,
    evaluates ``expr`` into a local ``result`` and returns it. The generated
    text ends with a blank line.
    """
    signature = f"double {function_name}({scalar_args(dim)}) {{"
    statements = [
        "double result;",
        C99_print_scalar(expr, 'result'),
        "return result;",
    ]
    return signature + "\n" + "".join(stmt + "\n" for stmt in statements) + "}\n\n"
57
58
def C99_print_scalar_gradient_function(function_name, expr, dim):
    """Emit a C function writing the gradient of ``expr`` into ``val``.

    The generated signature is
    ``void <function_name>(double x, ..., double *val)``. Component ``d``
    receives the simplified partial derivative of ``expr`` with respect to
    coordinate ``d``.
    """
    assert 1 <= dim <= len(SCALAR_COORDS)
    coords = SCALAR_COORDS[:dim]
    signature = f"void {function_name}({scalar_args(dim)}, double *val) {{"
    # Each component is wrapped in its own { } scope: every C99_print_scalar
    # call restarts helper numbering at helper_0, so unscoped declarations
    # would collide.
    components = [
        "{" + C99_print_scalar(simplify(diff(expr, coord)), f"val[{d}]") + "}"
        for d, coord in enumerate(coords)
    ]
    return "\n".join([signature, *components, "}\n"]) + "\n"
68
69
def C99_print_scalar_value_case(local_index, function_name, dim):
    """Emit the ``case <local_index>:`` branch that fills basis values.

    The branch loops over the rows of ``uv`` and stores
    ``<function_name>(uv(i, 0), ...)`` into ``result_0(i, 0)``.
    NOTE(review): ``uv`` and ``result_0`` are presumably declared by the
    enclosing generated C++ function; confirm against the caller.
    """
    call = f"{function_name}({scalar_call_args(dim)})"
    return "".join([
        f"\tcase {local_index}:\n",
        "\t\tfor (Eigen::Index i = 0; i < uv.rows(); ++i)\n",
        f"\t\t\tresult_0(i, 0) = {call};\n",
        "\t\tbreak;\n",
    ])
77
78
def C99_print_scalar_gradient_case(local_index, function_name, dim):
    """Emit the ``case <local_index>:`` branch that fills basis gradients.

    For every row ``i`` of ``uv`` the generated code calls
    ``<function_name>(uv(i, 0), ..., gradient)`` and copies the first ``dim``
    entries of ``gradient`` into ``val(i, d)``.
    NOTE(review): ``uv``, ``val`` and the ``gradient`` buffer are presumably
    declared by the enclosing generated C++ function; confirm against caller.
    """
    out = f"\tcase {local_index}:\n"
    out += "\t\tfor (Eigen::Index i = 0; i < uv.rows(); ++i) {\n"
    out += f"\t\t\t{function_name}({scalar_call_args(dim)}, gradient);\n"
    for d in range(dim):
        out += f"\t\t\tval(i, {d}) = gradient[{d}];\n"
    out += "\t\t}\n"
    out += "\t\tbreak;\n"
    return out
89
90
# Pretty print a matrix or tensor expression.
def C99_print_tensor(expr, result_name="result"):
    """Print C code assigning a rank-2 or rank-4 sympy expression to ``result_name``.

    Common subexpressions are hoisted into ``const double helper_<n> = ...;``
    declarations.

    - Rank-2 result: written entry by entry as ``<result_name>[i, j] = ...;``.
    - Rank-4 result ``T``: reshaped row-major into a 2-D matrix. Entry
      ``T[i, j, k, m]`` is written to ``(i * n1 + j, k * n3 + m)``, where
      ``n1`` and ``n3`` are the sizes of axes 1 and 3.

    NOTE(review): the subscript is emitted as ``name[i, j]``. In C, and in
    C++ before C++23, that parses as the comma operator (i.e. ``name[j]``)
    unless the consumer rewrites it. Confirm against the code that uses this
    output.

    Raises:
        ValueError: if the CSE-reduced result is not rank 2 or rank 4.
    """
    subs, result = cse(expr, numbered_symbols("helper_"), optimizations='basic')
    # cse returns a list of reduced expressions; unwrap a single one so the
    # tensor itself is indexed below.
    if len(result) == 1:
        result = result[0]

    lines = [f"const double {ccode(value, symbol)}" for symbol, value in subs]

    shape = np.array(result).shape
    if len(shape) == 2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                lines.append(ccode(result[i, j], f"{result_name}[{i}, {j}]"))
    elif len(shape) == 4:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for m in range(shape[3]):
                        row = i * shape[1] + j
                        col = k * shape[3] + m
                        lines.append(ccode(result[i, j, k, m],
                                           f"{result_name}[{row}, {col}]"))
    else:
        # Without this, unsupported ranks silently produced only the helper
        # declarations and no assignment to result_name.
        raise ValueError(
            f"C99_print_tensor: expected a rank-2 or rank-4 result, got shape {shape}")

    return "\n".join(lines)
scalar_call_args(dim)
C99_print_scalar_gradient_function(function_name, expr, dim)
C99_print_scalar_value_function(function_name, expr, dim)
C99_print_scalar(expr, result_name="result")
C99_print_scalar_value_case(local_index, function_name, dim)
C99_print_scalar_gradient_case(local_index, function_name, dim)