#!/usr/bin/env python3
# patch_napster.py
import sys

if len(sys.argv) < 4:
    print("Usage: python3 patch_napster.py <file> <old_string> <new_string>")
    sys.exit(1)

fn = sys.argv[1]
old = sys.argv[2]
new = sys.argv[3]

# An empty search string trivially "matches" at offset 0
# (bytes.find(b"") == 0), so no meaningful patch is possible; reject it
# before doing any work.
if not old:
    print("old_string must not be empty.")
    sys.exit(1)

# Load the whole target into memory once. The context manager closes the
# handle deterministically instead of leaving it to the garbage collector.
try:
    with open(fn, "rb") as f:
        data = f.read()
except OSError as e:
    print(f"Cannot read {fn}: {e}")
    sys.exit(1)

def patch_bytes(orig_bytes, new_bytes, suffix_label):
    """Patch the first occurrence of *orig_bytes* and write the result.

    The replacement is right-padded with NUL bytes to the length of
    *orig_bytes*, so the file size and every other offset stay unchanged.

    The search runs over the module-level ``data`` buffer. On success, that
    buffer is replaced by the patched bytes before returning. Successive
    calls therefore build on one another, and ``fn + ".patched"`` holds
    every edit made so far. Previously, each call started again from the
    original bytes and overwrote the output file, silently discarding the
    earlier patch.

    Args:
        orig_bytes: Byte pattern to look for. Must be non-empty.
        new_bytes: Replacement bytes, no longer than *orig_bytes*.
        suffix_label: Tag prefixed to every status message (e.g. "ASCII").

    Returns:
        The path of the written file, or None if nothing was patched
        (empty pattern, pattern not found, or replacement too long).
    """
    global data

    if not orig_bytes:
        # bytes.find(b"") is always 0, so an empty pattern would "match" at
        # the start of the buffer without meaning anything.
        print(f"{suffix_label}: empty pattern, nothing to patch.")
        return None
    i = data.find(orig_bytes)
    if i == -1:
        print(f"{suffix_label}: pattern not found.")
        return None
    if len(new_bytes) > len(orig_bytes):
        print(f"{suffix_label}: new string longer ({len(new_bytes)}) than old ({len(orig_bytes)}). Aborting.")
        return None
    end = i + len(orig_bytes)
    if data.find(orig_bytes, end) != -1:
        # Only the first match is rewritten; make the remaining ones visible
        # rather than leaving them unpatched without a word.
        print(f"{suffix_label}: further occurrences exist; only the first (offset 0x{i:x}) is patched.")
    padded = new_bytes.ljust(len(orig_bytes), b"\x00")
    newdata = data[:i] + padded + data[end:]
    out = fn + ".patched"
    with open(out, "wb") as fh:
        fh.write(newdata)
    # Carry the edit forward so the next call patches on top of this one.
    data = newdata
    print(f"{suffix_label}: patched and wrote {out} (offset 0x{i:x})")
    return out

# Try the patch once per encoding. Encode strictly: with errors='ignore',
# characters the codec cannot represent would be dropped silently, and the
# script would then search for (and overwrite) a different string than the
# one requested. If either string cannot be encoded, skip that pass instead.

# ASCII
try:
    old_a = old.encode('ascii')
    new_a = new.encode('ascii')
except UnicodeEncodeError:
    print("ASCII: strings contain non-ASCII characters; skipping.")
else:
    patch_bytes(old_a, new_a, "ASCII")

# UTF-16LE (strict encoding only fails on lone surrogates, e.g. from
# undecodable command-line bytes)
try:
    old_u = old.encode('utf-16le')
    new_u = new.encode('utf-16le')
except UnicodeEncodeError:
    print("UTF-16LE: strings contain unencodable characters; skipping.")
else:
    patch_bytes(old_u, new_u, "UTF-16LE")