https://projecteuler.net/problem=14
Longest Collatz sequence
Problem 14
The following iterative sequence is defined for the set of positive integers:
n → n/2 (n is even)
n → 3n + 1 (n is odd)
Using the rule above and starting with 13, we generate the following sequence:
13 → 40 → 20 → 10 → 5 → 16 → 8 → 4 → 2 → 1
It can be seen that this sequence (starting at 13 and finishing at 1) contains 10 terms. Although it has not been proved yet (Collatz Problem), it is thought that all starting numbers finish at 1.
Which starting number, under one million, produces the longest chain?
NOTE: Once the chain starts the terms are allowed to go above one million.
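As a quick check of the rule, here is a small helper that generates the terms for a given start. The name `collatz_sequence` is introduced here for illustration and is not part of the problem statement. For 13 it reproduces the 10-term chain shown above.

```python
def collatz_sequence(n):
    """Return the list of terms from n down to 1 (n must be a positive integer)."""
    terms = [n]
    while n != 1:
        # Even: halve it. Odd: 3n + 1.
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        terms.append(n)
    return terms


print(collatz_sequence(13))       # [13, 40, 20, 10, 5, 16, 8, 4, 2, 1]
print(len(collatz_sequence(13)))  # 10
```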
The first version is a direct recursive implementation and is pretty slow.
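%%time
# First version: plain recursion, no caching.
def chain_length(x):
    """Return the number of terms in the chain from x down to 1."""
    if x == 1:
        return 1
    if x % 2 == 0:
        y = x // 2
    else:
        y = x * 3 + 1
    return chain_length(y) + 1

# Try every start below one million and keep the longest chain.
maxchain = 1
beststart = 1
for i in range(2, 1000000):
    c = chain_length(i)
    if c > maxchain:
        beststart = i
        maxchain = c
print("Max Chain:", maxchain)
print("From Start Number", beststart)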
The second version adds memoization (each computed chain length is stored in a dictionary and reused by later starts) and is much faster.
%%time
# Second version: cache every computed chain length in a dictionary.
_MEMO_ = dict()

def chain_length(x):
    """Return the number of terms in the chain from x down to 1 (memoized)."""
    if x in _MEMO_:
        return _MEMO_[x]
    if x == 1:
        return 1
    if x % 2 == 0:
        y = x // 2
    else:
        y = x * 3 + 1
    result = chain_length(y) + 1
    _MEMO_[x] = result  # remember the answer for later starts
    return result

maxchain = 1
beststart = 1
for i in range(2, 1000000):
    c = chain_length(i)
    if c > maxchain:
        beststart = i
        maxchain = c
print("Max Chain:", maxchain)
print("From Start Number", beststart)
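Python's standard library offers the same kind of caching through `functools.lru_cache`. The sketch below swaps the hand-written `_MEMO_` dictionary for that decorator while keeping the same recursion; `longest_chain` and its `limit` parameter are illustrative names, and its speed relative to the dictionary version is not measured here.

```python
from functools import lru_cache


# Sketch: same recursion as the second version, memoized by lru_cache
# instead of a hand-written dictionary.
@lru_cache(maxsize=None)  # unbounded cache: every computed length is kept
def chain_length(x):
    """Return the number of terms in the chain from x down to 1."""
    if x == 1:
        return 1
    if x % 2 == 0:
        return chain_length(x // 2) + 1
    return chain_length(x * 3 + 1) + 1


def longest_chain(limit):
    """Return (start, length) for the longest chain with 1 <= start < limit."""
    best_start, best_len = 1, 1
    for i in range(2, limit):
        c = chain_length(i)
        if c > best_len:
            best_start, best_len = i, c
    return best_start, best_len


print(longest_chain(1000))  # (871, 179)
```

The decorator keeps the caching logic out of the function body, at the cost of hiding the cache from direct inspection (`chain_length.cache_info()` reports hits and misses).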
At this point I tried replacing the integer division by 2 (x//2) with a right bit shift (x>>1), and the mod 2 test (x%2) with a bitwise AND (x&1).
Neither change produced any speedup. Python 3 already handles these operations efficiently, so I went back to the original code.
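To repeat that comparison, a micro-benchmark along these lines could be used. It is only a sketch: the function name `time_ops` and the test value 999999 are arbitrary choices, and timings depend on the machine and Python version, so no numbers are given here. The loop at the top confirms that the bitwise forms give the same results as the arithmetic ones for non-negative integers.

```python
import timeit

# The bitwise forms agree with the arithmetic ones for non-negative integers.
for x in range(10_000):
    assert (x >> 1) == x // 2
    assert (x & 1) == x % 2


def time_ops(x=999_999, number=1_000_000):
    """Time each expression on the same integer; returns {expression: seconds}."""
    exprs = ["x // 2", "x >> 1", "x % 2", "x & 1"]
    return {e: timeit.timeit(e, globals={"x": x}, number=number) for e in exprs}


for expr, seconds in time_ops().items():
    print(f"{expr:8} {seconds:.3f} s per 1,000,000 evaluations")
```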
If x is odd, the next term 3*x+1 is always even, so it is always followed by a division by 2.
The version below combines these two operations into one step, which saves a recursive call; because the even term 3*x+1 is skipped, that step adds 2 to the chain length instead of 1. This version runs slightly faster.
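The parity argument can be written out explicitly. Writing an odd x as 2k + 1 (k is notation introduced only for this derivation):

```latex
x = 2k + 1 \;\Rightarrow\; 3x + 1 = 6k + 4 = 2(3k + 2),
\qquad\text{so}\qquad
x \mapsto \frac{3x + 1}{2} = 3k + 2 \quad \text{covers two terms of the chain.}
```

For example, 13 → 40 → 20 becomes the single step 13 → 20, and the chain from 20 has 8 terms, so chain_length(13) = 8 + 2 = 10, matching the 10-term chain in the problem statement.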
%%time
# Third version: an odd step 3x+1 is always followed by /2, so do both at once.
_MEMO_ = dict()

def chain_length(x):
    """Return the number of terms in the chain from x down to 1 (memoized)."""
    if x in _MEMO_:
        return _MEMO_[x]
    if x == 1:
        return 1
    if x % 2 == 0:
        result = chain_length(x // 2) + 1
    else:
        # Skip the even term 3x+1, so this step accounts for two terms.
        result = chain_length((x * 3 + 1) // 2) + 2
    _MEMO_[x] = result
    return result

maxchain = 1
beststart = 1
for i in range(2, 1000000):
    c = chain_length(i)
    if c > maxchain:
        beststart = i
        maxchain = c
print("Max Chain:", maxchain)
print("From Start Number", beststart)
At this point, I have run out of ideas for simple optimizations. The next step would be to try writing the code without recursion.
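One way to follow up on that idea is sketched below; it is not the code from the versions above. The function name `longest_collatz_start` is hypothetical, and the sketch swaps the recursive dictionary memo for a plain loop plus a list cache that only stores lengths for numbers below the limit (terms above the limit are walked through but not cached). Each start walks forward until it reaches a number whose length is already known, then fills in the lengths along the recorded path in reverse.

```python
def longest_collatz_start(limit):
    """Return (start, length) of the longest chain for 1 <= start < limit.

    Iterative sketch: no recursion. Chain lengths (counted in terms) are cached
    in a list indexed by the number itself, only for numbers below `limit`.
    Assumes limit >= 2.
    """
    cache = [0] * limit  # 0 means "not computed yet"
    cache[1] = 1
    best_start, best_len = 1, 1
    for start in range(2, limit):
        path = []
        x = start
        # Walk forward until we reach a number whose length is cached.
        while x >= limit or cache[x] == 0:
            path.append(x)
            x = x // 2 if x % 2 == 0 else 3 * x + 1
        length = cache[x]
        # Walk the recorded path backwards, one more term per step.
        for y in reversed(path):
            length += 1
            if y < limit:
                cache[y] = length
        if cache[start] > best_len:
            best_start, best_len = start, cache[start]
    return best_start, best_len


print(longest_collatz_start(1000))  # (871, 179)
```

Using a list instead of a dictionary avoids hashing, which may help, but the speed of this sketch relative to the recursive versions has not been measured here.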