124 lines
3.5 KiB
Python
Executable file
124 lines
3.5 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
import ast
|
|
import operator
|
|
import math
|
|
import sys
|
|
|
|
# Dispatch table from AST operator node types to the callables implementing
# them. Binary operators (called with two operands) and unary sign operators
# (called with one operand) share this table; the evaluator looks entries up
# by ``type(node.op)``.
_OPERATORS = dict(
    [
        (ast.Add, operator.add),            # a + b
        (ast.Sub, operator.sub),            # a - b
        (ast.Mult, operator.mul),           # a * b
        (ast.Div, operator.truediv),        # a / b (true division)
        (ast.Mod, operator.mod),            # a % b
        (ast.Pow, operator.pow),            # a ** b
        (ast.USub, operator.neg),           # -a
        (ast.UAdd, operator.pos),           # +a
        (ast.FloorDiv, operator.floordiv),  # a // b
    ]
)
|
|
|
|
# Functions callable by name from an expression: every public callable in the
# ``math`` module (sqrt, sin, log, factorial, gamma, fsum, ...), plus a few
# builtins that are convenient in a calculator. Entries such as
# ``math.factorial`` or ``math.log10`` are already picked up by the
# comprehension, so they are not listed again here.
_MATH_FUNCS = {
    name: getattr(math, name)
    for name in dir(math)
    if not name.startswith("_") and callable(getattr(math, name))
}
_MATH_FUNCS.update(abs=abs, round=round, max=max, min=min, sum=sum)
|
|
|
|
# Named constants that may appear as bare identifiers in an expression,
# e.g. ``2*pi`` or ``e**2``. Values are taken straight from ``math``.
_CONSTANTS = {
    const_name: getattr(math, const_name)
    for const_name in ("pi", "e", "tau", "inf", "nan")
}
|
|
|
|
class matheval(ast.NodeVisitor):
    """Evaluate a parsed arithmetic expression without calling ``eval``.

    Walks an AST produced by ``ast.parse(expr, mode="eval")`` and computes its
    value. Only a whitelist of node kinds is understood:

    * int/float literals (bool is an int subclass, so ``True``/``False`` pass);
    * list and tuple literals (e.g. for ``sum([1, 2])``);
    * bare identifiers found in ``_CONSTANTS``;
    * calls to names found in ``_MATH_FUNCS``, with positional and keyword
      arguments;
    * the unary and binary operators listed in ``_OPERATORS``.

    Any other node type reaches ``generic_visit`` and raises ``ValueError``.

    NOTE(review): the lowercase class name is not PEP 8 style but is kept so
    existing callers keep working.
    NOTE(review): ``**`` on ints is unbounded, so inputs like ``9**9**9`` can
    run for a very long time; no size or time limit is enforced here.
    """

    def visit_Expression(self, node):
        """Unwrap the top-level ``ast.Expression`` and evaluate its body."""
        return self.visit(node.body)

    def visit_BinOp(self, node):
        """Evaluate ``left <op> right`` with the callable from ``_OPERATORS``.

        Raises:
            ValueError: if the operator type has no entry in ``_OPERATORS``.
        """
        # Operands are evaluated first, so an unknown name inside an operand
        # is reported before an unsupported operator.
        left = self.visit(node.left)
        right = self.visit(node.right)
        op_type = type(node.op)
        if op_type in _OPERATORS:
            return _OPERATORS[op_type](left, right)
        raise ValueError(f"Unsupported binary operator {op_type.__name__}")

    def visit_UnaryOp(self, node):
        """Evaluate ``<op> operand`` with the callable from ``_OPERATORS``.

        Raises:
            ValueError: for unary operators not in ``_OPERATORS``
                (e.g. ``not`` or ``~``).
        """
        operand = self.visit(node.operand)
        op_type = type(node.op)
        if op_type in _OPERATORS:
            return _OPERATORS[op_type](operand)
        raise ValueError(f"Unsupported unary operator {op_type.__name__}")

    def visit_Call(self, node):
        """Call a whitelisted function with evaluated positional/keyword args.

        Raises:
            ValueError: if the callee is not a plain name, is not in
                ``_MATH_FUNCS``, or the call uses ``**`` unpacking.
        """
        # Only ``name(...)`` calls; attribute calls like ``x.y()`` are refused.
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple function calls allowed")
        func_name = node.func.id
        if func_name not in _MATH_FUNCS:
            raise ValueError(f"Unknown function '{func_name}'")
        args = [self.visit(arg) for arg in node.args]
        # Keyword arguments must be evaluated and forwarded; dropping them
        # would silently change results (e.g. round(x, ndigits=2)).
        kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:  # ``**mapping`` unpacking
                raise ValueError("Keyword argument unpacking (**) is not allowed")
            kwargs[keyword.arg] = self.visit(keyword.value)
        return _MATH_FUNCS[func_name](*args, **kwargs)

    def visit_Name(self, node):
        """Resolve a bare identifier against ``_CONSTANTS``.

        Raises:
            ValueError: if the identifier is not a known constant.
        """
        if node.id in _CONSTANTS:
            return _CONSTANTS[node.id]
        raise ValueError(f"Unknown identifier '{node.id}'")

    def visit_Constant(self, node):
        """Return numeric literals; reject strings, bytes, None, complex, etc.

        Raises:
            ValueError: for any non-int/float constant.
        """
        if isinstance(node.value, (int, float)):
            return node.value
        raise ValueError(f"Unsupported constant type: {type(node.value).__name__}")

    def visit_List(self, node):
        """Evaluate each element of a list literal, returning a ``list``."""
        return [self.visit(el) for el in node.elts]

    def visit_Tuple(self, node):
        """Evaluate each element of a tuple literal, returning a ``tuple``."""
        return tuple(self.visit(el) for el in node.elts)

    def generic_visit(self, node):
        """Reject every node type that has no dedicated ``visit_*`` method."""
        raise ValueError(f"Unsupported expression: {type(node).__name__}")
|
|
|
|
def main():
    """Evaluate the expression given on the command line and print the result.

    All command-line arguments are joined with spaces, so both
    ``math '1 + 2'`` and ``math 1 + 2`` work. ``^`` is rewritten to ``**``,
    so it means exponentiation (bitwise XOR is not available). A float with
    an integral value is printed without its trailing ``.0``.

    Exits with status 1 after writing a usage or error message to stderr
    when no expression is given or evaluation fails.
    """
    if len(sys.argv) < 2:
        # This is an error path (exit status 1), so usage goes to stderr.
        print("Usage: math '<expression>'", file=sys.stderr)
        print("Example: math 'sin(pi/2) + 2^3'", file=sys.stderr)
        sys.exit(1)

    expr = " ".join(sys.argv[1:])
    expr = expr.replace("^", "**")
    # Collapse every whitespace run (newlines included) to one space and trim
    # the ends. ast.parse does not strip leading whitespace the way builtin
    # eval does, so " 1+2" would fail with an indentation error. Deleting
    # newlines outright would instead merge tokens ("12\n34" -> 1234).
    expr = " ".join(expr.split())

    try:
        tree = ast.parse(expr, mode="eval")
        evaluator = matheval()
        result = evaluator.visit(tree)
        if isinstance(result, float) and result.is_integer():
            print(int(result))
        else:
            print(result)
    # Top-level CLI boundary: report any parse/evaluation failure
    # (SyntaxError, ValueError, ZeroDivisionError, OverflowError, ...) as a
    # one-line message instead of a traceback.
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly; importing this module merely
    # defines the evaluator and its lookup tables.
    main()
|