Add type hints to projecteuler.py

This commit is contained in:
daniele 2024-09-29 14:32:18 +02:00
parent b6b10cdd12
commit b3643b2f28
Signed by: fuxino
GPG Key ID: 981A2B2A3BBF5514

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Callable, List, ParamSpec, Tuple, TypeVar
from functools import wraps from functools import wraps
from math import sqrt, floor, ceil, gcd from math import sqrt, floor, ceil, gcd
from timeit import default_timer from timeit import default_timer
@ -7,7 +8,25 @@ from timeit import default_timer
from numpy import zeros from numpy import zeros
def is_prime(num): P = ParamSpec('P')
R = TypeVar('R')
def timing(f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start = default_timer()
result = f(*args, **kwargs)
end = default_timer()
print(f'{f.__name__!r} took {end - start:.9f} seconds')
return result
return wrapper
def is_prime(num: int) -> bool:
if num < 4: if num < 4:
# If num is 2 or 3 then it's prime. # If num is 2 or 3 then it's prime.
return num in (2, 3) return num in (2, 3)
@ -31,7 +50,7 @@ def is_prime(num):
return True return True
def is_palindrome(num, base=10): def is_palindrome(num: int, base: int = 10) -> bool:
reverse = 0 reverse = 0
tmp = num tmp = num
@ -54,12 +73,12 @@ def is_palindrome(num, base=10):
# Least common multiple algorithm using the greatest common divisor. # Least common multiple algorithm using the greatest common divisor.
def lcm(a, b): def lcm(a: int , b: int) -> int:
return a * b // gcd(a, b) return a * b // gcd(a, b)
# Recursive function to calculate the least common multiple of more than 2 numbers. # Recursive function to calculate the least common multiple of more than 2 numbers.
def lcmm(values, n): def lcmm(values: List[int], n: int) -> int:
# If there are only two numbers, use the lcm function to calculate the lcm. # If there are only two numbers, use the lcm function to calculate the lcm.
if n == 2: if n == 2:
return lcm(values[0], values[1]) return lcm(values[0], values[1])
@ -71,7 +90,7 @@ def lcmm(values, n):
# Function implementing the Sieve or Eratosthenes to generate primes up to a certain number. # Function implementing the Sieve or Eratosthenes to generate primes up to a certain number.
def sieve(n): def sieve(n: int) -> List[int]:
primes = [1] * (n + 1) primes = [1] * (n + 1)
# 0 and 1 are not prime, 2 and 3 are prime. # 0 and 1 are not prime, 2 and 3 are prime.
@ -100,7 +119,7 @@ def sieve(n):
return primes return primes
def count_divisors(n): def count_divisors(n: int) -> int:
count = 0 count = 0
# For every divisor below the square root of n, there is a corresponding one # For every divisor below the square root of n, there is a corresponding one
# above the square root, so it's sufficient to check up to the square root of n # above the square root, so it's sufficient to check up to the square root of n
@ -118,7 +137,7 @@ def count_divisors(n):
return count return count
def find_max_path(triang, n): def find_max_path(triang: List[List[int]], n: int) -> int:
# Start from the second to last row and go up. # Start from the second to last row and go up.
for i in range(n-2, -1, -1): for i in range(n-2, -1, -1):
# For each element in the row, check the two adjacent elements # For each element in the row, check the two adjacent elements
@ -133,7 +152,7 @@ def find_max_path(triang, n):
return triang[0][0] return triang[0][0]
def sum_of_divisors(n): def sum_of_divisors(n: int) -> int:
# For each divisor of n smaller than the square root of n, # For each divisor of n smaller than the square root of n,
# there is another one larger than the square root. If i is # there is another one larger than the square root. If i is
# a divisor of n, so is n/i. Checking divisors i up to square # a divisor of n, so is n/i. Checking divisors i up to square
@ -141,20 +160,20 @@ def sum_of_divisors(n):
# all divisors. # all divisors.
limit = floor(sqrt(n)) + 1 limit = floor(sqrt(n)) + 1
sum_ = 1 _sum = 1
for i in range(2, limit): for i in range(2, limit):
if n % i == 0: if n % i == 0:
sum_ += i _sum += i
# If n is a perfect square, i=limit is a divisor and has to be counted only once. # If n is a perfect square, i=limit is a divisor and has to be counted only once.
if n != i * i: if n != i * i:
sum_ = sum_ + n // i _sum = _sum + n // i
return sum_ return _sum
def is_pandigital(value, n): def is_pandigital(value: int, n: int) -> bool:
i = 0 i = 0
digits = [0] * (n + 1) digits = [0] * (n + 1)
@ -180,7 +199,7 @@ def is_pandigital(value, n):
return True return True
def is_pentagonal(n): def is_pentagonal(n: int) -> bool:
# A number n is pentagonal if p=(sqrt(24n+1)+1)/6 is an integer. # A number n is pentagonal if p=(sqrt(24n+1)+1)/6 is an integer.
# In this case, n is the pth pentagonal number. # In this case, n is the pth pentagonal number.
i = (sqrt(24*n+1) + 1) / 6 i = (sqrt(24*n+1) + 1) / 6
@ -199,9 +218,9 @@ def is_pentagonal(n):
# d_(n+1)=(S-m_(n+1)^2)/d_n # d_(n+1)=(S-m_(n+1)^2)/d_n
# a_(n+1)=floor((sqrt(S)+m_(n+1))/d_(n+1))=floor((a_0+m_(n+1))/d_(n+1)) # a_(n+1)=floor((sqrt(S)+m_(n+1))/d_(n+1))=floor((a_0+m_(n+1))/d_(n+1))
# if a_i=2*a_0, the algorithm ends. # if a_i=2*a_0, the algorithm ends.
def build_sqrt_cont_fraction(i, l): def build_sqrt_cont_fraction(i: int, l: int) -> Tuple[List[int], int]:
mn = 0 mn = 0.0
dn = 1 dn = 1.0
count = 0 count = 0
fraction = [0] * l fraction = [0] * l
@ -233,7 +252,7 @@ def build_sqrt_cont_fraction(i, l):
# Function to solve the Diophantine equation in the form x^2-Dy^2=1 # Function to solve the Diophantine equation in the form x^2-Dy^2=1
# (Pell equation) using continued fractions. # (Pell equation) using continued fractions.
def pell_eq(d): def pell_eq(d: int) -> int:
# Find the continued fraction for sqrt(d). # Find the continued fraction for sqrt(d).
fraction, _ = build_sqrt_cont_fraction(d, 100) fraction, _ = build_sqrt_cont_fraction(d, 100)
@ -278,7 +297,7 @@ def pell_eq(d):
# Function to check if a number is semiprime. Parameters include # Function to check if a number is semiprime. Parameters include
# pointers to p and q to return the factors values and a list of # pointers to p and q to return the factors values and a list of
# primes. # primes.
def is_semiprime(n, primes): def is_semiprime(n: int, primes: List[int]) -> Tuple[bool, int, int]:
# If n is prime, it's not semiprime. # If n is prime, it's not semiprime.
if primes[n] == 1: if primes[n] == 1:
return False, -1, -1 return False, -1, -1
@ -336,14 +355,14 @@ def is_semiprime(n, primes):
# If n=pq is semiprime, phi(n)=(p-1)(q-1)=pq-p-q+1=n-(p+4)+1 # If n=pq is semiprime, phi(n)=(p-1)(q-1)=pq-p-q+1=n-(p+4)+1
# if p!=q. If p=q (n is a square), phi(n)=n-p. # if p!=q. If p=q (n is a square), phi(n)=n-p.
def phi_semiprime(n, p, q): def phi_semiprime(n: int, p: int, q: int) -> int:
if p == q: if p == q:
return n - p return n - p
return n - (p + q) + 1 return n - (p + q) + 1
def phi(n, primes): def phi(n: int, primes: List[int]) -> float:
# If n is primes, phi(n)=n-1. # If n is primes, phi(n)=n-1.
if primes[n] == 1: if primes[n] == 1:
return n - 1 return n - 1
@ -354,7 +373,7 @@ def phi(n, primes):
if semi_p: if semi_p:
return phi_semiprime(n, p, q) return phi_semiprime(n, p, q)
ph = n ph = float(n)
# If 2 is a factor of n, multiply the current ph (which now is n) # If 2 is a factor of n, multiply the current ph (which now is n)
# by 1-1/2, then divide all factors 2. # by 1-1/2, then divide all factors 2.
@ -417,7 +436,7 @@ def phi(n, primes):
# Function implementing the partition function. # Function implementing the partition function.
def partition_fn(n, partitions, mod=-1): def partition_fn(n: int, partitions: List[int], mod: int = -1) -> int:
# The partition function for negative numbers is 0 by definition. # The partition function for negative numbers is 0 by definition.
if n < 0: if n < 0:
return 0 return 0
@ -452,7 +471,7 @@ def partition_fn(n, partitions, mod=-1):
return int(res) return int(res)
def dijkstra(matrix, distances, m, n, up=False, back=False, start=0): def dijkstra(matrix: List[List[int]], distances: List[List[int]], m: int, n: int, up: bool = False, back: bool = False, start: int = 0) -> None:
visited = zeros((m, n), int) visited = zeros((m, n), int)
for i in range(m): for i in range(m):
@ -495,17 +514,3 @@ def dijkstra(matrix, distances, m, n, up=False, back=False, start=0):
if i == m - 1 and j == n - 1: if i == m - 1 and j == n - 1:
break break
def timing(f):
@wraps(f)
def wrapper(*args, **kwargs):
start = default_timer()
result = f(*args, **kwargs)
end = default_timer()
print(f'{f.__name__!r} took {end - start:.9f} seconds')
return result
return wrapper