In [1]:
# Make IPython display the value of every top-level expression in a cell,
# not only the last one (InteractiveShell.ast_node_interactivity).
setattr(get_ipython(), 'ast_node_interactivity', 'all')
In [2]:
import z3
In [7]:
def solve_with_ops(sh, N, max_iterations=2 ** 10):
    """Synthesize an N-step xorshift inverse of ``x -> x ^ (x >> sh)`` on 64-bit words.

    NOTE: another cell in this notebook defines a different function with this
    same name; whichever cell ran most recently wins.

    Looks for shift amounts ``shift_0 .. shift_{N-1}`` (each in ``[1, 63]``) such
    that applying ``inv = inv ^ (inv >> shift_i)`` for every ``i`` (logical
    right shift) undoes the forward map for all 64-bit ``x``.

    The search is counterexample-driven. The solver is asked for a surviving
    candidate together with an input ``x`` it fails on. That concrete ``x`` is
    then added as a constraint every candidate must satisfy, and the step
    repeats. When no such pair exists, every surviving candidate is a correct
    inverse.

    Args:
        sh: right-shift amount used by the forward map.
        N: number of xorshift steps allowed in the inverse.
        max_iterations: cap on counterexample rounds (default ``2 ** 10``).

    Returns:
        The z3 model assigning every ``shift_i``. It is also printed after the
        final ``sat`` line.

    Raises:
        z3.Z3Exception: no choice of N shifts inverts the forward map. ``unsat``
            (or the solver's status) is printed first. The driver loop in this
            notebook relies on this exception to move on to the next N.
        RuntimeError: refinement did not converge within ``max_iterations``
            rounds, so no verified candidate is available.
    """
    solver = z3.Solver()

    shifts = [z3.BitVec(f"shift_{i}", 64) for i in range(N)]
    for shift in shifts:
        # z3py's ``>``/``<`` on BitVecs are signed; for the range 1..63 this is
        # the same as the unsigned bound.
        solver.add(shift > 0, shift < 64)

    def forward(x):
        return x ^ z3.LShR(x, sh)

    def inverse(x):
        inv = x
        for shift in shifts:
            inv = inv ^ z3.LShR(inv, shift)
        return inv

    x = z3.BitVec("x", 64)
    for _ in range(max_iterations):
        # Is there a surviving candidate plus an input it fails to invert?
        if solver.check(x != inverse(forward(x))) != z3.sat:
            break
        # Pin that counterexample: every surviving candidate must invert it.
        counterexample = solver.model().eval(x, model_completion=True)
        solver.add(counterexample == inverse(forward(counterexample)))
    else:
        raise RuntimeError(
            f"counterexample refinement did not converge within "
            f"{max_iterations} rounds (sh={sh}, N={N})"
        )

    status = solver.check()
    print(status)
    if status != z3.sat:
        raise z3.Z3Exception(f"no {N}-step xorshift inverse for shift {sh}")
    model = solver.model()
    print(model)
    return model

# Find the smallest number of xorshift steps (1..15) that inverts x ^ (x >> 7).
# solve_with_ops raises z3.Z3Exception when no inverse with i steps exists.
for i in range(1, 16):
    print(i)
    try:
        solve_with_ops(7, i)
        break
    except z3.Z3Exception:
        # No inverse with i steps: try one more step.
        pass
    print()
else:
    print("no inverse found with up to 15 steps")
Out:
1
unsat

2
unsat

3
unsat

4
sat
[shift_1 = 14, shift_2 = 28, shift_3 = 56, shift_0 = 7]
In [6]:
def solve_with_ops(sh, N, max_iterations=2 ** 10):
    """Synthesize an N-step shift-add inverse of ``x -> x + (x << sh)`` on 64-bit words.

    NOTE: this reuses the name of a different solver defined in another cell
    of this notebook; whichever cell ran most recently wins.

    The candidate inverse has a fixed shape. The first step is
    ``inv = inv - (inv << shift_0)``, and every later step is
    ``inv = inv + (inv << shift_i)``. The search looks for shift amounts in
    ``[1, 63]`` that make this undo the forward map for all 64-bit ``x``.

    The search is counterexample-driven. The solver is asked for a surviving
    candidate plus an input it fails on, and that concrete input is added as a
    constraint all candidates must satisfy. When no such pair exists, every
    surviving candidate is correct.

    Args:
        sh: left-shift amount used by the forward map.
        N: number of shift-add steps allowed in the inverse.
        max_iterations: cap on counterexample rounds (default ``2 ** 10``).

    Returns:
        The z3 model assigning every ``shift_i``. It is also printed after the
        final ``sat`` line.

    Raises:
        z3.Z3Exception: no choice of N shifts inverts the forward map. ``unsat``
            (or the solver's status) is printed first. The driver loop in this
            notebook relies on this exception to move on to the next N.
        RuntimeError: refinement did not converge within ``max_iterations``
            rounds, so no verified candidate is available.
    """
    solver = z3.Solver()

    shifts = [z3.BitVec(f"shift_{i}", 64) for i in range(N)]
    for shift in shifts:
        # z3py's ``>``/``<`` on BitVecs are signed; for the range 1..63 this is
        # the same as the unsigned bound.
        solver.add(shift > 0, shift < 64)

    def forward(x):
        return x + (x << sh)

    def inverse(x):
        inv = x
        for i, shift in enumerate(shifts):
            if i == 0:
                inv = inv - (inv << shift)
            else:
                inv = inv + (inv << shift)
        return inv

    x = z3.BitVec("x", 64)
    for _ in range(max_iterations):
        # Is there a surviving candidate plus an input it fails to invert?
        if solver.check(z3.simplify(x != inverse(forward(x)))) != z3.sat:
            break
        # Pin that counterexample: every surviving candidate must invert it.
        counterexample = solver.model().eval(x, model_completion=True)
        solver.add(z3.simplify(counterexample == inverse(forward(counterexample))))
    else:
        raise RuntimeError(
            f"counterexample refinement did not converge within "
            f"{max_iterations} rounds (sh={sh}, N={N})"
        )

    status = solver.check()
    print(status)
    if status != z3.sat:
        raise z3.Z3Exception(f"no {N}-step shift-add inverse for shift {sh}")
    model = solver.model()
    print(model)
    return model

# Find the smallest number of shift-add steps (1..7) that inverts x + (x << 13).
# solve_with_ops raises z3.Z3Exception when no inverse with i steps exists.
for i in range(1, 8):
    print(i)
    try:
        solve_with_ops(13, i)
        break
    except z3.Z3Exception:
        # No inverse with i steps: try one more step.
        pass
    print()
else:
    print("no inverse found with up to 7 steps")
Out:
1
unsat

2
unsat

3
sat
[shift_1 = 26, shift_2 = 52, shift_0 = 13]