diff --git a/local/bin/math b/local/bin/math new file mode 100755 index 0000000..b61e07b --- /dev/null +++ b/local/bin/math @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +import ast +import operator +import math +import sys + +_OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + ast.FloorDiv: operator.floordiv, +} + +_MATH_FUNCS = {name: getattr(math, name) for name in dir(math) if not name.startswith("_") and callable(getattr(math, name))} +_MATH_FUNCS.update({ + "abs": abs, + "round": round, + "factorial": math.factorial, + "degrees": math.degrees, + "radians": math.radians, + "isclose": math.isclose, + "isfinite": math.isfinite, + "isinf": math.isinf, + "isnan": math.isnan, + "log2": math.log2, + "log10": math.log10, + "log": math.log, + "exp": math.exp, + "expm1": math.expm1, + "fsum": math.fsum, + "gamma": math.gamma, + "lgamma": math.lgamma, + "trunc": math.trunc, + "max": max, + "min": min, + "sum": sum, +}) + +_CONSTANTS = { + "pi": math.pi, + "e": math.e, + "tau": math.tau, + "inf": math.inf, + "nan": math.nan, +} + +class MathEvaluator(ast.NodeVisitor): + def visit(self, node): + if isinstance(node, ast.Expression): + return self.visit(node.body) + return super().visit(node) + + def visit_BinOp(self, node): + left = self.visit(node.left) + right = self.visit(node.right) + op_type = type(node.op) + if op_type in _OPERATORS: + return _OPERATORS[op_type](left, right) + raise ValueError(f"Unsupported binary operator {op_type.__name__}") + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + op_type = type(node.op) + if op_type in _OPERATORS: + return _OPERATORS[op_type](operand) + raise ValueError(f"Unsupported unary operator {op_type.__name__}") + + def visit_Call(self, node): + if not isinstance(node.func, ast.Name): + raise ValueError("Only simple function calls allowed") + func_name = node.func.id + if func_name not in _MATH_FUNCS: + raise ValueError(f"Unknown function '{func_name}'") + args = [self.visit(arg) for arg in node.args] + return _MATH_FUNCS[func_name](*args) + + def visit_Name(self, node): + if node.id in _CONSTANTS: + return _CONSTANTS[node.id] + raise ValueError(f"Unknown identifier '{node.id}'") + + def visit_Constant(self, node): + if isinstance(node.value, (int, float)): + return node.value + raise ValueError(f"Unsupported constant type: {type(node.value).__name__}") + + def visit_List(self, node): + return [self.visit(el) for el in node.elts] + + def visit_Tuple(self, node): + return tuple(self.visit(el) for el in node.elts) + + def generic_visit(self, node): + raise ValueError(f"Unsupported expression: {type(node).__name__}") + +def main(): + if len(sys.argv) < 2: + print("Usage: math ''") + print("Example: math 'sin(pi/2) + 2^3'") + sys.exit(1) + + expr = " ".join(sys.argv[1:]) + expr = expr.replace("^", "**") + expr = expr.replace("\n", "") + + try: + tree = ast.parse(expr, mode="eval") + evaluator = MathEvaluator() + result = evaluator.visit(tree) + if isinstance(result, float) and result.is_integer(): + print(int(result)) + else: + print(result) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + +if __name__ == "__main__": + main()