From b3643b2f28e8689e2290cb715d3d1190b2f8804e Mon Sep 17 00:00:00 2001 From: Daniele Fucini Date: Sun, 29 Sep 2024 14:32:18 +0200 Subject: [PATCH] Add type hints to projecteuler.py --- Python/projecteuler.py | 81 ++++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/Python/projecteuler.py b/Python/projecteuler.py index a86785c..746f3d9 100644 --- a/Python/projecteuler.py +++ b/Python/projecteuler.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +from typing import Callable, List, ParamSpec, Tuple, TypeVar from functools import wraps from math import sqrt, floor, ceil, gcd from timeit import default_timer @@ -7,7 +8,25 @@ from timeit import default_timer from numpy import zeros -def is_prime(num): +P = ParamSpec('P') +R = TypeVar('R') + + +def timing(f: Callable[P, R]) -> Callable[P, R]: + @wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + start = default_timer() + result = f(*args, **kwargs) + end = default_timer() + + print(f'{f.__name__!r} took {end - start:.9f} seconds') + + return result + + return wrapper + + +def is_prime(num: int) -> bool: if num < 4: # If num is 2 or 3 then it's prime. return num in (2, 3) @@ -31,7 +50,7 @@ def is_prime(num): return True -def is_palindrome(num, base=10): +def is_palindrome(num: int, base: int = 10) -> bool: reverse = 0 tmp = num @@ -54,12 +73,12 @@ def is_palindrome(num, base=10): # Least common multiple algorithm using the greatest common divisor. -def lcm(a, b): +def lcm(a: int , b: int) -> int: return a * b // gcd(a, b) # Recursive function to calculate the least common multiple of more than 2 numbers. -def lcmm(values, n): +def lcmm(values: List[int], n: int) -> int: # If there are only two numbers, use the lcm function to calculate the lcm. if n == 2: return lcm(values[0], values[1]) @@ -71,7 +90,7 @@ def lcmm(values, n): # Function implementing the Sieve or Eratosthenes to generate primes up to a certain number. -def sieve(n): +def sieve(n: int) -> List[int]: primes = [1] * (n + 1) # 0 and 1 are not prime, 2 and 3 are prime. @@ -100,7 +119,7 @@ def sieve(n): return primes -def count_divisors(n): +def count_divisors(n: int) -> int: count = 0 # For every divisor below the square root of n, there is a corresponding one # above the square root, so it's sufficient to check up to the square root of n @@ -118,7 +137,7 @@ def count_divisors(n): return count -def find_max_path(triang, n): +def find_max_path(triang: List[List[int]], n: int) -> int: # Start from the second to last row and go up. for i in range(n-2, -1, -1): # For each element in the row, check the two adjacent elements @@ -133,7 +152,7 @@ def find_max_path(triang, n): return triang[0][0] -def sum_of_divisors(n): +def sum_of_divisors(n: int) -> int: # For each divisor of n smaller than the square root of n, # there is another one larger than the square root. If i is # a divisor of n, so is n/i. Checking divisors i up to square @@ -141,20 +160,20 @@ def sum_of_divisors(n): # all divisors. limit = floor(sqrt(n)) + 1 - sum_ = 1 + _sum = 1 for i in range(2, limit): if n % i == 0: - sum_ += i + _sum += i # If n is a perfect square, i=limit is a divisor and has to be counted only once. if n != i * i: - sum_ = sum_ + n // i + _sum = _sum + n // i - return sum_ + return _sum -def is_pandigital(value, n): +def is_pandigital(value: int, n: int) -> bool: i = 0 digits = [0] * (n + 1) @@ -180,7 +199,7 @@ def is_pandigital(value, n): return True -def is_pentagonal(n): +def is_pentagonal(n: int) -> bool: # A number n is pentagonal if p=(sqrt(24n+1)+1)/6 is an integer. # In this case, n is the pth pentagonal number. i = (sqrt(24*n+1) + 1) / 6 @@ -199,9 +218,9 @@ def is_pentagonal(n): # d_(n+1)=(S-m_(n+1)^2)/d_n # a_(n+1)=floor((sqrt(S)+m_(n+1))/d_(n+1))=floor((a_0+m_(n+1))/d_(n+1)) # if a_i=2*a_0, the algorithm ends. -def build_sqrt_cont_fraction(i, l): - mn = 0 - dn = 1 +def build_sqrt_cont_fraction(i: int, l: int) -> Tuple[List[int], int]: + mn = 0.0 + dn = 1.0 count = 0 fraction = [0] * l @@ -233,7 +252,7 @@ def build_sqrt_cont_fraction(i, l): # Function to solve the Diophantine equation in the form x^2-Dy^2=1 # (Pell equation) using continued fractions. -def pell_eq(d): +def pell_eq(d: int) -> int: # Find the continued fraction for sqrt(d). fraction, _ = build_sqrt_cont_fraction(d, 100) @@ -278,7 +297,7 @@ def pell_eq(d): # Function to check if a number is semiprime. Parameters include # pointers to p and q to return the factors values and a list of # primes. -def is_semiprime(n, primes): +def is_semiprime(n: int, primes: List[int]) -> Tuple[bool, int, int]: # If n is prime, it's not semiprime. if primes[n] == 1: return False, -1, -1 @@ -336,14 +355,14 @@ def is_semiprime(n, primes): # If n=pq is semiprime, phi(n)=(p-1)(q-1)=pq-p-q+1=n-(p+4)+1 # if p!=q. If p=q (n is a square), phi(n)=n-p. -def phi_semiprime(n, p, q): +def phi_semiprime(n: int, p: int, q: int) -> int: if p == q: return n - p return n - (p + q) + 1 -def phi(n, primes): +def phi(n: int, primes: List[int]) -> float: # If n is primes, phi(n)=n-1. if primes[n] == 1: return n - 1 @@ -354,7 +373,7 @@ def phi(n, primes): if semi_p: return phi_semiprime(n, p, q) - ph = n + ph = float(n) # If 2 is a factor of n, multiply the current ph (which now is n) # by 1-1/2, then divide all factors 2. @@ -417,7 +436,7 @@ def phi(n, primes): # Function implementing the partition function. -def partition_fn(n, partitions, mod=-1): +def partition_fn(n: int, partitions: List[int], mod: int = -1) -> int: # The partition function for negative numbers is 0 by definition. if n < 0: return 0 @@ -452,7 +471,7 @@ def partition_fn(n, partitions, mod=-1): return int(res) -def dijkstra(matrix, distances, m, n, up=False, back=False, start=0): +def dijkstra(matrix: List[List[int]], distances: List[List[int]], m: int, n: int, up: bool = False, back: bool = False, start: int = 0) -> None: visited = zeros((m, n), int) for i in range(m): @@ -495,17 +514,3 @@ def dijkstra(matrix, distances, m, n, up=False, back=False, start=0): if i == m - 1 and j == n - 1: break - - -def timing(f): - @wraps(f) - def wrapper(*args, **kwargs): - start = default_timer() - result = f(*args, **kwargs) - end = default_timer() - - print(f'{f.__name__!r} took {end - start:.9f} seconds') - - return result - - return wrapper