import os
import struct
import sys
from array import array

def ReadWord(f):
    """Read two bytes from *f* and return them as a big-endian unsigned int."""
    raw = f.read(2)
    (word,) = struct.unpack(">H", raw)
    return word

def ReadLong(f):
    """Read four bytes from *f* and return them as a big-endian unsigned int."""
    raw = f.read(4)
    (longword,) = struct.unpack(">I", raw)
    return longword

# Positions of the emulated condition-code flags inside the FLAGS list:
# eXtend, Negative, Zero, oVerflow, Carry.
(ixCCExtend,
 ixCCNegative,
 ixCCZero,
 ixCCOverflow,
 ixCCCarry) = range(5)

# Emulated flag list and data-register file.  Both start unset;
# DaxIndexEntry.DecompressData assigns fresh lists before each decompression.
FLAGS = DR = None

def SetZNFlags(val):
    """Set the emulated N and Z flags from a 32-bit result.

    N mirrors bit 31 of the result.  Z is set when the low 32 bits are all
    zero.  The value is taken modulo 2**32, so this gives the right answer
    for both the unsigned register values the decompressor works with and
    for intermediates that have not been masked yet.  (The previous
    ``val < 0`` test could never fire for those unsigned values.)
    """
    result = val & 0xFFFFFFFF
    FLAGS[ixCCNegative] = bool(result & 0x80000000)
    FLAGS[ixCCZero] = (result == 0)

def hex2(x):
    """Format *x*, reduced to 32 bits, as a ``0x``-prefixed lowercase hex string.

    Uses ``%#x`` rather than ``hex()`` on a Python 2 long.  That avoids the
    trailing ``L`` hex() adds on Python 2, and it avoids the Python-2-only
    ``L`` integer literal, which is a syntax error on Python 3.
    """
    return "%#x" % (x & 0xFFFFFFFF)

class DaxIndexEntry:
    """One entry of a DAX file index, plus the decompressor for its data.

    The decompressor is a line-by-line port of Amiga (68k) assembly.  The
    ``lbC0412xx`` comments appear to be labels from the original disassembly.
    The ``lsr``/``roxl``/``roxr`` helpers emulate those instructions,
    including their effect on the X and C flags.

    Emulated register and flag state lives in the module-level ``DR`` and
    ``FLAGS`` lists.  DecompressData re-creates them on every call, so two
    decompressions must not run at the same time.
    """

    def __init__(self, fileName, headerSize):
        """Create an entry for *fileName*.

        headerSize is the file offset that this entry's dataOffset is
        relative to.  ProcessDaxFile passes the index size plus the 2-byte
        size word.
        """
        self.fileName = fileName
        self.headerSize = headerSize

    def PopulateFromFileIndex(self, f):
        """Read this entry's 10-byte index record from *f*'s current position."""
        self.indexID = ReadWord(f)             # 2 bytes
        self.dataOffset = ReadLong(f)          # 4 bytes, relative to headerSize
        self.compressedLength = ReadWord(f)    # 2 bytes
        self.decompressedLength = ReadWord(f)  # 2 bytes

    def DecompressData(self):
        """Read this entry's compressed block and return the decompressed bytes.

        Returns:
            array('B') holding exactly self.decompressedLength bytes.

        Raises:
            ValueError: one of the following.
                - The block is too short.
                - Its stored length disagrees with the index.
                - The bit stream is malformed or runs out.
                - The checksum does not cancel to zero.
                ValueError derives from StandardError on Python 2, so
                existing handlers for that still catch it.
        """
        global FLAGS, DR

        # The block is a sequence of 32-bit longs.  array('L') is 8 bytes on
        # LP64 platforms (e.g. 64-bit Linux), so pick a 4-byte typecode.
        typecode = 'I' if array('I').itemsize == 4 else 'L'
        cData = array(typecode)
        if cData.itemsize != 4:
            raise RuntimeError("no 4-byte unsigned array typecode available")
        dData = array('B')

        with open(self.fileName, "rb") as f:
            f.seek(self.headerSize + self.dataOffset)
            cData.fromfile(f, self.compressedLength // cData.itemsize)

        # The data comes from an Amiga and is big-endian.  Swap only when
        # the host's native byte order is little-endian.
        if sys.byteorder == "little":
            cData.byteswap()

        # Longs are consumed from the END of the block.  The order is:
        # decompressed length, then a checksum long, then the first
        # bit-buffer long.
        if len(cData) < 3:
            raise ValueError(self.fileName + ": compressed block too short (%d bytes)" % self.compressedLength)
        decompressedLength = cData.pop()
        if decompressedLength != self.decompressedLength:
            raise ValueError(self.fileName +": mismatched decompressed length (%d != %d)" % (decompressedLength, self.decompressedLength))

        FLAGS = [0] * 5
        DR = [0] * 8

        # D5 accumulates the XOR of every long consumed.  For a valid block
        # it ends at zero (checked below).
        DR[5] = int(cData.pop())
        DR[0] = int(cData.pop())
        DR[5] ^= DR[0]

        # '<' rather than '!=' so a token that overshoots cannot loop on
        # until the stream is exhausted.
        while len(dData) < self.decompressedLength:
            self.decompresschunk(cData, dData)

        if len(dData) != self.decompressedLength:
            raise ValueError(self.fileName + ": decompressed data overran expected length (%d > %d)" % (len(dData), self.decompressedLength))

        if DR[5]:
            raise ValueError("Bad decompression checksum", hex2(DR[5]))

        # dData was built in reverse order; flip it into file order.
        dData.reverse()
        return dData

    def decompresschunk(self, cData, dData):
        """Decode one token from the bit stream and append its bytes to dData.

        Token formats, as decoded below (n is the count read from the stream):
          1 + 2-bit code 0/1 -> copy 3/4 bytes, 9/10-bit offset
          1 + code 2         -> 8-bit n, copy n+1 bytes, 12-bit offset
          1 + code 3         -> 8-bit n, n+9 literal bytes
          0 1                -> copy 2 bytes, 8-bit offset
          0 0                -> 3-bit n, n+1 literal bytes
        """
        # lbC041210: shift the token's first bit into carry.
        DR[0] = self.lsr(DR[0], 1)
        if DR[0] == 0:
            self.getnextlong(cData)

        # lbC041218
        if FLAGS[ixCCCarry]:
            # lbC041254: read the 2-bit sub-code into D2.
            DR[1] = 2
            self.copynbitnumber(cData)

            D2b = DR[2] & 0xFF
            if D2b >= 2:
                if D2b == 3:
                    # lbC04124E: long literal run (count bias D4 = 8).
                    DR[1] = 8
                    DR[4] = 8
                    self.unpackbytes(cData, dData)
                    return
                # Sub-code 2: explicit 8-bit copy count, then 12-bit offset.
                DR[1] = 8
                self.copynbitnumber(cData)

                DR[3] = DR[2]
                DR[1] = 12
            else:
                # lbC041274: copy count and offset width derive from sub-code.
                DR[1] = DR[2] + 9
                DR[2] += 2
                DR[3] = DR[2]
            self.copybytes(cData, dData)
            return

        # First bit was 0: prepare a 2-byte copy with an 8-bit offset in case
        # the second bit selects it.
        DR[1] = 8
        DR[3] = 1
        DR[0] = self.lsr(DR[0], 1)
        if DR[0] == 0:
            self.getnextlong(cData)

        # lbC041226
        if FLAGS[ixCCCarry]:
            self.copybytes(cData, dData)
        else:
            # Short literal run with a 3-bit count and no bias.
            DR[1] = 3
            DR[4] = 0
            self.unpackbytes(cData, dData)

    def unpackbytes(self, cData, dData):
        """Emit a literal run.

        Reads a DR[1]-bit count n.  Then reads n + DR[4] + 1 bytes, each
        taken 8 bits at a time from the bit stream.
        """
        # lbC04122C
        self.copynbitnumber(cData)

        DR[3] = DR[2] + DR[4]
        # lbC041234: one iteration per output byte (dbra-style count).
        while DR[3] >= 0:
            DR[1] = 7
            # lbC041236: rotate 8 stream bits into D2.
            while DR[1] >= 0:
                DR[0] = self.lsr(DR[0], 1)
                if DR[0] == 0:
                    self.getnextlong(cData)

                DR[2] = self.roxl(DR[2], 1)
                DR[1] -= 1
            dData.append(DR[2] & 0xFF)
            DR[3] -= 1

    def copybytes(self, cData, dData):
        """Emit a back-reference.

        Reads a DR[1]-bit offset, then copies DR[3] + 1 bytes from that far
        back in the output.
        """
        # lbC04127C
        self.copynbitnumber(cData)
        offset = DR[2]
        # Offset 0 would silently alias dData[0].  Any offset beyond the
        # output produced so far has no source byte.  Both mean a corrupt
        # stream, so fail clearly instead.
        if offset <= 0 or offset > len(dData):
            raise ValueError(self.fileName + ": bad back-reference offset %d at output length %d" % (offset, len(dData)))
        while DR[3] >= 0:
            dData.append(dData[-offset])
            DR[3] -= 1

    # ===============================================

    def getnextlong(self, cData):
        """Load the next stream long into D0 and XOR it into the checksum D5."""
        if not cData:
            raise ValueError(self.fileName + ": compressed data exhausted before output was complete")
        DR[0] = int(cData.pop())
        DR[5] ^= DR[0]
        # roxr moves X (the bit lsr just shifted out) into bit 31 and the
        # long's bit 0 into C/X.  Presumably X is 1 here, acting as an
        # end-of-buffer marker; confirm against the original routine.
        DR[0] = self.roxr(DR[0], 1)

    def copynbitnumber(self, cData):
        """Read DR[1] bits from the stream, first bit most significant, into DR[2].

        Leaves DR[1] at -1.
        """
        DR[1] -= 1
        DR[2] = 0

        while DR[1] >= 0:
            DR[0] = self.lsr(DR[0], 1)
            if DR[0] == 0:
                self.getnextlong(cData)
            DR[2] = self.roxl(DR[2], 1)
            DR[1] -= 1

    # ===============================================

    def lsr(self, val, cnt):
        """Emulate LSR #cnt on an unsigned value (cnt must be >= 1).

        C and X receive the last bit shifted out.  N and Z are set from the
        result.
        """
        # Shift all but the last bit, which we want to carry.
        val >>= cnt - 1
        FLAGS[ixCCExtend] = FLAGS[ixCCCarry] = val & 1
        # Now we are ready to shift the last bit.
        val >>= 1
        SetZNFlags(val)
        return val

    def roxl(self, val, cnt):
        """Emulate ROXL.L #cnt: rotate a 32-bit value left through X (cnt >= 1).

        The first bit rotated in comes from X.  C and X receive the last bit
        rotated out.
        """
        cnt -= 1
        loval = val >> (31 - cnt)
        FLAGS[ixCCCarry] = loval & 1
        val = (((val << 1) | FLAGS[ixCCExtend]) << cnt) | (loval >> 1)
        FLAGS[ixCCExtend] = FLAGS[ixCCCarry]
        # Mask before setting N/Z so the flags describe the 32-bit result.
        val &= 0xFFFFFFFF
        SetZNFlags(val)
        return val

    def roxr(self, val, cnt):
        """Emulate ROXR.L #cnt: rotate a 32-bit value right through X (cnt >= 1).

        The first bit rotated in comes from X.  C and X receive the last bit
        rotated out.
        """
        cnt -= 1
        # Compute the incoming high bits from the OLD X before it is replaced.
        hival = (val << 1) | FLAGS[ixCCExtend]
        hival <<= (31 - cnt)
        val >>= cnt
        FLAGS[ixCCExtend] = FLAGS[ixCCCarry] = val & 1
        val >>= 1
        val |= hival
        # Mask before setting N/Z so the flags describe the 32-bit result.
        val &= 0xFFFFFFFF
        SetZNFlags(val)
        return val

    def __len__(self):
        """Size in bytes of one index record on disk."""
        return 10

    def __repr__(self):
        return "<DaxIndexEntry id=%s offset=%s cLength=%s dLength=%s>" % (self.indexID, self.dataOffset, self.compressedLength, self.decompressedLength)

def ProcessDaxFile(filePath):
    """Parse the index at the start of the DAX file at *filePath*.

    The file begins with a big-endian 16-bit byte count for the index.
    Consecutive 10-byte index records follow it.  Each entry is given a
    header size that includes the 2-byte count word, because DecompressData
    seeks to headerSize + dataOffset.

    Returns:
        list of DaxIndexEntry, in file order.
    """
    indexEntryList = []
    with open(filePath, "rb") as f:
        headerSize = ReadWord(f)
        # Entry data offsets are relative to the end of the index, which
        # sits after the size word itself.
        dataStart = headerSize + 2
        readSize = 0
        while readSize < headerSize:
            entry = DaxIndexEntry(filePath, dataStart)
            entry.PopulateFromFileIndex(f)
            indexEntryList.append(entry)
            readSize += len(entry)
    return indexEntryList

if __name__ == "__main__":
    # Unpack every .dax archive in inPath.  Entry N of foo.dax is written
    # to outPath/foo.dax.N.
    inPath = "pooldata"
    outPath = "pooldata-unpacked"
    # The open() below fails if the output directory does not exist yet.
    if not os.path.isdir(outPath):
        os.makedirs(outPath)
    # sorted() makes processing (and the printed log) deterministic.
    for fileName in sorted(os.listdir(inPath)):
        if not fileName.endswith(".dax"):
            continue
        fileReadPath = os.path.join(inPath, fileName)
        for idx, die in enumerate(ProcessDaxFile(fileReadPath)):
            fileWritePath = os.path.join(outPath, "%s.%d" % (fileName, idx))
            dData = die.DecompressData()
            with open(fileWritePath, "wb") as f:
                dData.tofile(f)
            print("%s %d" % (fileName, die.decompressedLength))
