#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from glob import glob
import json
from ffxiv_aku import *
from ffxiv_aku import print_color_red, readJsonFile
from natsort import natsorted
from collections import defaultdict
import zipfile
import mysql.connector
import re  # needed in split_mysql_script

# Connection settings read by executeScriptsFromFile().
# SECURITY(review): the MySQL account name and password below are committed in
# plain text. Rotate this password and load credentials from the environment or
# an untracked config file instead of keeping them in source control.
# The same value is passed as both the MySQL user name and the database name.
_user_and_db = "d03e908b"
_pass = "tEBQZecjSCe7ejD87k6P"
# SQL dialect flag consulted by insert_into() when escaping string values.
dbtype = "mysql"

# Language codes recognised both as `{lang: text}` dict keys and as `_lang`
# suffixes in flat keys such as `Name_de`.
LANGS = ["de", "en", "fr", "ja"]
# Base field names whose per-language values become table columns (see
# detect_fields()). Several contain underscores themselves (Class_Plural,
# Name_Action, Name_Short, Spot_0..Spot_2), so a column name cannot be split
# on "_" naively.
BASE_FIELDS = [
    "Name", "Text", "Description", "AchievementKind", "AcceptClassJobCategory", "AdditionalData", "Alias",
    "AlliedNames", "Category", "Class_Plural", "Command", "CompleteText", "CurrencyItem", "Destination", "Effect",
    "Masculine", "Feminine", "JournalCategory", "JournalSection", "MainCommandCategory", "MainOption", "Map",
    "Name_Action", "Name_Short", "NotCompleteText", "PlaceName", "Singular", "Plural", "ShopName", "SongName",
    "Spot_0", "Spot_1", "Spot_2", "Tag", "Target", "TextCommand", "Title", "Type"
]


def safe_str(val):
    """
    Render *val* as plain text for embedding in an SQL string literal.

    ``None`` becomes the empty string, dicts and lists are serialised as JSON
    (non-ASCII characters kept as-is), and anything else goes through ``str()``.
    Icon/Image objects get no special treatment here.
    """
    if isinstance(val, (dict, list)):
        return json.dumps(val, ensure_ascii=False)
    return "" if val is None else str(val)


def icon_value(icon_obj):
    """
    Return the Icon's texture path with a trailing ``.tex`` extension mapped to ``.png``.

    *icon_obj* may be an Icon object (a dict whose ``path`` key holds the path)
    or a bare path string. Only the file extension is rewritten, so directory
    or file names that merely contain ".tex" are left intact. Returns ''
    when no usable (non-empty string) path is available.
    """
    path = icon_obj.get("path", "") if isinstance(icon_obj, dict) else icon_obj
    if not isinstance(path, str) or not path:
        return ""
    if path.endswith(".tex"):
        return path[:-len(".tex")] + ".png"
    return path


def image_value(img_obj):
    """
    Return the Image's texture path with a trailing ``.tex`` extension mapped to ``.png``.

    *img_obj* may be an Image object (a dict whose ``path`` key holds the path)
    or a bare path string. Only the file extension is rewritten, so directory
    or file names that merely contain ".tex" are left intact. Returns ''
    when no usable (non-empty string) path is available.
    """
    path = img_obj.get("path", "") if isinstance(img_obj, dict) else img_obj
    if not isinstance(path, str) or not path:
        return ""
    if path.endswith(".tex"):
        return path[:-len(".tex")] + ".png"
    return path


def detect_fields(data):
    """
    Derive the ordered, duplicate-free column list for one sheet.

    Every row in *data* (row key -> row dict) is walked recursively through
    nested dicts; lists are not descended into. Columns are produced in order:

      1. ``_id`` and ``0xID``;
      2. ``Icon`` / ``Image`` when any row has such a key at its top level;
      3. dotted paths of icon-like objects (dicts with a ``path`` or ``id`` key)
         stored under an ``Icon``/``Image`` key, e.g. ``ScreenImage.Image``,
         sorted alphabetically;
      4. language columns: ``{base}_{lang}`` for root fields, then
         ``{parent}_{base}_{lang}`` for nested ones with parents sorted, each
         group ordered by BASE_FIELDS and then LANGS.

    A language column is recognised either from a ``{lang: text}`` dict stored
    under a BASE_FIELDS key or from a flat ``{base}_{lang}`` key.
    """
    media_keys = ("Icon", "Image")
    lang_hits = defaultdict(set)  # (parent path, base name) -> languages seen
    root_media = set()            # which of Icon/Image appear at a row's top level
    icon_paths = set()            # dotted paths of icon-like objects

    def looks_like_icon(candidate):
        return isinstance(candidate, dict) and ("path" in candidate or "id" in candidate)

    def walk(node, prefix=""):
        if not isinstance(node, dict):
            return
        for key, value in node.items():
            if not prefix and key in media_keys:
                root_media.add(key)
            if key in media_keys and looks_like_icon(value):
                icon_paths.add(f"{prefix}.{key}" if prefix else key)
            if key in BASE_FIELDS and isinstance(value, dict):
                present = [lang for lang in LANGS if lang in value]
                if present:
                    lang_hits[(prefix, key)].update(present)
            stem, sep, suffix = key.rpartition("_")
            if sep and stem in BASE_FIELDS and suffix in LANGS:
                lang_hits[(prefix, stem)].add(suffix)
            walk(value, f"{prefix}.{key}" if prefix else key)

    for row in data.values():
        walk(row)

    def ordered_columns():
        yield "_id"
        yield "0xID"
        yield from (name for name in media_keys if name in root_media)
        yield from sorted(icon_paths)
        nested_parents = sorted({parent for parent, _ in lang_hits if parent})
        for parent in ["", *nested_parents]:
            for base in BASE_FIELDS:
                seen_langs = lang_hits.get((parent, base), set())
                for lang in LANGS:
                    if lang in seen_langs:
                        yield f"{parent}_{base}_{lang}" if parent else f"{base}_{lang}"

    # dict.fromkeys acts as an ordered set: first occurrence wins, duplicates drop.
    return list(dict.fromkeys(ordered_columns()))


def _walk_dict(obj, keys):
    """Follow *keys* through nested dicts; return the dict reached, or None if the path breaks."""
    for key in keys:
        if not isinstance(obj, dict):
            return None
        obj = obj.get(key)
    return obj if isinstance(obj, dict) else None


def _resolve_lang_field(entry, field):
    """
    Resolve a ``{parent}_{base}_{lang}`` / ``{base}_{lang}`` column name against *entry*.

    Both the parent path (dot-separated) and the base name (e.g. ``Name_Short``,
    ``Spot_0``) may themselves contain underscores. Therefore every underscore
    before the language suffix is tried as the parent/base boundary: first the
    whole prefix as a root-level base, then progressively longer parents. The
    first split that exists in *entry* wins.

    Each split accepts both ``parent: {base: {lang: ...}}`` and
    ``parent: {base_lang: ...}`` layouts.

    Returns ``(True, value)`` on a hit, ``(False, "")`` otherwise.
    """
    prefix, sep, lang = field.rpartition("_")
    if not sep or not prefix:
        return False, ""

    splits = [("", prefix)]
    cut = prefix.find("_")
    while cut != -1:
        splits.append((prefix[:cut], prefix[cut + 1:]))
        cut = prefix.find("_", cut + 1)

    for parent, base in splits:
        target = _walk_dict(entry, parent.split(".")) if parent else entry
        if not isinstance(target, dict):
            continue
        node = target.get(base)
        if isinstance(node, dict):
            if lang in node:
                return True, node[lang]
        elif f"{base}_{lang}" in target:
            return True, target[f"{base}_{lang}"]
    return False, ""


def get_nested_value(entry, field):
    """
    Look up the value for column *field* (as named by detect_fields) in row *entry*.

    Supported column shapes:
      - ``Icon`` / ``Image``: top-level icon objects, normalised via
        icon_value() / image_value() to a ``.png`` path string;
      - ``A.B.Icon`` / ``A.B.Image``: nested icon objects, normalised likewise;
      - ``{base}_{lang}`` and ``{parent}_{base}_{lang}`` language columns, where
        the parent may be a dotted path and parent/base may contain underscores.
        Values are read from ``{base: {lang: ...}}`` or ``{base_lang: ...}``;
      - any other dotted path ``A.B.C`` (plain nested lookup);
      - any other plain key at the row's top level.

    Returns '' when the value is missing. Non-icon values are returned as stored
    (str, number, dict, ...); callers stringify them.
    """
    if field == "Icon":
        return icon_value(entry.get("Icon", ""))
    if field == "Image":
        return image_value(entry.get("Image", ""))

    path_parts = field.split(".")
    leaf = path_parts[-1]

    # Nested icon/image objects such as Name.Icon or ScreenImage.Image.
    if len(path_parts) > 1 and leaf in ("Icon", "Image"):
        target = _walk_dict(entry, path_parts[:-1])
        if target is None:
            return ""
        val = target.get(leaf, "")
        return icon_value(val) if leaf == "Icon" else image_value(val)

    # Language columns; checked before the generic dotted lookup because nested
    # parents are joined with "." (e.g. "Quest.Foo_Name_de").
    if "_" in field:
        found, value = _resolve_lang_field(entry, field)
        if found:
            return value

    # Generic nested value (non-icon/image, non-language).
    if len(path_parts) > 1:
        target = entry
        for part in path_parts:
            if not isinstance(target, dict):
                return ""
            target = target.get(part, "")
        return target

    return entry.get(field, "")


def insert_into(filename, fields, data):
    """
    Build a single multi-row INSERT statement covering every entry of a sheet.

    Rows are emitted in natural-sort order of their keys. ``_id`` is the integer
    part of the key, extended with ``.{subrow_id}`` when the entry has a
    ``subrow_id``. ``0xID`` is the upper-case hex form of the same id. Every
    remaining column (``fields[2:]``) is resolved with get_nested_value() and
    written as a quoted string literal.

    Escaping: single quotes are doubled (standard SQL). When ``dbtype`` is
    "mysql", backslashes are doubled as well, because MySQL treats ``\\`` as an
    escape character inside string literals. Without this, a value ending in a
    backslash would swallow its closing quote, and JSON dumps (``\\"``,
    ``\\n``) would be altered on load.

    Args:
        filename: Target table name (backtick-quoted in the output).
        fields: Column names as produced by detect_fields(). ``fields[0]`` is
            ``_id`` and ``fields[1]`` is ``0xID``.
        data: Mapping of row key (e.g. ``"12"`` or ``"12.0"``) to row dict.

    Returns:
        ``(sql, max_len)``: the INSERT statement (``""`` when *data* is empty)
        and the length of the longest unescaped column value. createTable()
        uses max_len to size the TEXT columns.

    Raises:
        ValueError: if a row key does not start with an integer.
    """
    # The column list is identical for every row, so build it once.
    header = "`, `".join(fields[1:])  # `_id` is spelled out separately below
    row_literals = []
    max_len = 0

    for _id in natsorted(data):
        entry = data[_id]

        # Build _id and 0xID (support subrow_id)
        int_id = int(_id.split(".")[0])
        full_id = str(int_id)
        hex_id = hex(int_id)[2:].upper()
        sub_id = entry.get("subrow_id")
        if sub_id is not None:
            full_id += f".{sub_id}"
            hex_id += f".{sub_id}"

        values = [hex_id]
        for field in fields[2:]:  # _id and 0xID are handled above
            raw = get_nested_value(entry, field)
            # Top-level Icon/Image are already normalised to plain strings.
            value = raw if field in ("Icon", "Image") else safe_str(raw)
            max_len = max(max_len, len(value))
            if dbtype == "mysql":
                value = value.replace("\\", "\\\\")
            values.append(value.replace("'", "''"))

        # A dotted (sub-row) id is not a numeric literal, so quote it.
        id_literal = f"'{full_id}'" if "." in full_id else full_id
        row_literals.append(f"\t({id_literal}, '" + "', '".join(values) + "')")

    if not row_literals:
        return "", max_len
    sql = (
        f"INSERT INTO `{filename}` (`_id`, `{header}`) VALUES \n"
        + ",\n".join(row_literals)
        + ";"
    )
    return sql, max_len


def createTable(filename, fields, max_len):
    """
    Return the DDL that (re)creates the table for one sheet.

    Emits ``DROP TABLE IF EXISTS`` followed by ``CREATE TABLE``:
      - ``_id`` is TEXT and the primary key, indexed on its first 100
        characters (MySQL requires a prefix length to index a TEXT column);
      - ``0xID`` is TEXT(100);
      - every other column in *fields* is TEXT(max_len).

    The table name is backtick-quoted in both statements, so sheet names that
    collide with MySQL keywords do not break the DROP.
    """
    column_defs = []
    for field in fields:
        if field == "_id":
            column_defs.append("`_id` TEXT")
        elif field == "0xID":
            column_defs.append("`0xID` TEXT(100)")
        else:
            column_defs.append(f"`{field}` TEXT({max_len})")
    table = f"`{filename}`"
    return (
        f"DROP TABLE IF EXISTS {table};\n"
        f"CREATE TABLE {table} ({', '.join(column_defs)}, PRIMARY KEY (`_id`(100)));\n"
    )


def work_on_category(_file):
    """
    Produce the CREATE TABLE + INSERT SQL for one JSON sheet file.

    The table name is the file's base name without ".json", with any remaining
    dots replaced by underscores. A KeyError raised while reading or converting
    the sheet is reported in red and the file is skipped (returns '').

    Returns the DDL followed by the INSERT statement and three blank lines.
    """
    try:
        data = readJsonFile(_file)
        # glob() returns OS-native separators (backslashes on Windows); normalise
        # them first so the directory never leaks into the table name.
        basename = _file.replace("\\", "/").split("/")[-1]
        filename = basename.replace(".json", "").replace(".", "_")
        fields = detect_fields(data)
        iresult, max_len = insert_into(filename, fields, data)
        cresult = createTable(filename, fields, max_len)
    except KeyError as e:
        print_color_red(e)
        print_color_red(_file)
        return ""
    return cresult + iresult + "\n\n\n"


def run():
    """
    Convert every JSON sheet in ../xivapi_data into one SQL script.

    Writes SQL_Translate.sql (UTF-8) and a zip-compressed copy,
    SQL_Translate.sql.zip. Files are processed in sorted order so repeated
    runs produce identical dumps.
    """
    print("[CJTSAF] Create Json SQL!")
    sql_path = "SQL_Translate.sql"
    files = sorted(glob("../xivapi_data/*.json"))
    result = "".join(work_on_category(file) for file in files)
    # Written verbatim: insert_into() already escapes quotes. Any further
    # rewriting of quote runs would corrupt values that begin or end with an
    # apostrophe.
    with open(sql_path, "w", encoding="utf8") as f:
        f.write(result)
    # zip the SQL file
    with zipfile.ZipFile(sql_path + ".zip", "w", zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(sql_path)
    print("[CJTSAF] DONE Create Json SQL!")


def split_mysql_script(sql_text: str):
    """
    Split a MySQL script into individual statements.

    Yields ``(statement, start_line, end_line)`` tuples. *statement* has its
    terminator removed and surrounding whitespace stripped. The line numbers
    are 1-based positions in *sql_text*. Empty statements are skipped, and a
    trailing statement without a terminator is still yielded at EOF.

    Handles:
    - Single/double-quoted strings with backslash or '' escapes
    - Backtick-quoted identifiers
    - --, #, and /* */ comments. Between statements they are dropped; inside
      a statement they are kept verbatim in the statement text.
    - DELIMITER <token> directives (case-insensitive). The directive line is
      consumed and never yielded.

    NOTE(review): '--' starts a comment here even when no whitespace follows
    it, whereas MySQL only recognises '-- ' (dash-dash-space). Confirm this is
    acceptable for any hand-written SQL fed through this function.
    """
    # i: read cursor into sql_text; n: total length.
    i, n = 0, len(sql_text)
    # 1-based line number of the character at position i.
    line_no = 1
    start_line = 1
    # Current statement terminator; replaced by DELIMITER directives.
    term = ";"

    def at(idx, token):
        # True when `token` occurs in sql_text starting exactly at idx.
        return sql_text.startswith(token, idx)

    def bump(c):
        # Step past the character c at position i, counting newlines.
        nonlocal i, line_no
        i += 1
        if c == "\n":
            line_no += 1

    while i < n:
        # Skip leading whitespace/comments between statements
        # (they are discarded, not attached to the next statement).
        while i < n:
            if sql_text[i].isspace():
                bump(sql_text[i])
                continue
            # -- comment (with or without trailing space)
            if at(i, "--"):
                # consume to end of line
                # (the "\n" is left for the whitespace branch, which counts it)
                while i < n and sql_text[i] != "\n":
                    i += 1
                continue
            # # comment
            if sql_text[i] == "#":
                while i < n and sql_text[i] != "\n":
                    i += 1
                continue
            # /* ... */ comment
            if at(i, "/*"):
                i += 2
                while i < n and not at(i, "*/"):
                    if sql_text[i] == "\n":
                        line_no += 1
                    i += 1
                # Step over "*/", clamped in case the comment is unterminated.
                i = min(n, i + 2)
                continue
            # DELIMITER directive
            # (matched against a short window so the regex stays cheap)
            if re.match(r"(?is)\s*DELIMITER\b", sql_text[i:i+16]):
                # read the whole line
                line_start = i
                while i < n and sql_text[i] != "\n":
                    i += 1
                line = sql_text[line_start:i]
                # The first non-blank token after DELIMITER becomes the terminator.
                m = re.search(r"(?is)DELIMITER\s+(\S+)", line)
                if m:
                    term = m.group(1)
                continue
            break

        if i >= n:
            break

        # Begin a statement
        start_line = line_no
        buf = []
        in_s = False  # single quote
        in_d = False  # double quote
        in_bt = False  # backtick identifier
        # Set right after a backslash inside a quoted string, so the next
        # character is taken literally and cannot close the string.
        esc = False

        while i < n:
            c = sql_text[i]
            nx = sql_text[i+1] if i+1 < n else ""

            # Handle strings / identifiers
            if in_s:
                buf.append(c)
                if esc:
                    esc = False
                elif c == "\\":
                    esc = True
                elif c == "'":
                    # also handle '' escape (two single quotes)
                    # -> consume both and stay inside the string
                    if nx == "'":
                        buf.append(nx)
                        i += 1
                    else:
                        in_s = False
                if c == "\n":
                    line_no += 1
                i += 1
                continue

            if in_d:
                buf.append(c)
                if esc:
                    esc = False
                elif c == "\\":
                    esc = True
                elif c == '"':
                    in_d = False
                if c == "\n":
                    line_no += 1
                i += 1
                continue

            if in_bt:
                # Backtick identifiers have no escape handling here.
                buf.append(c)
                if c == "`":
                    in_bt = False
                if c == "\n":
                    line_no += 1
                i += 1
                continue

            # Not inside quotes: check comments quickly
            # (comments inside a statement are copied into the statement text)
            if at(i, "/*"):
                buf.append("/*")
                i += 2
                while i < n and not at(i, "*/"):
                    if sql_text[i] == "\n":
                        line_no += 1
                    buf.append(sql_text[i])
                    i += 1
                if at(i, "*/"):
                    buf.append("*/")
                    i += 2
                continue

            if at(i, "--"):
                # consume to EOL, but keep it inside the statement buffer
                while i < n and sql_text[i] != "\n":
                    buf.append(sql_text[i])
                    i += 1
                if i < n:
                    buf.append("\n")
                    line_no += 1
                    i += 1
                continue

            if c == "#":
                while i < n and sql_text[i] != "\n":
                    buf.append(sql_text[i])
                    i += 1
                if i < n:
                    buf.append("\n")
                    line_no += 1
                    i += 1
                continue

            # Enter quotes
            if c == "'":
                in_s = True
                buf.append(c)
                i += 1
                continue
            if c == '"':
                in_d = True
                buf.append(c)
                i += 1
                continue
            if c == "`":
                in_bt = True
                buf.append(c)
                i += 1
                continue

            # Statement terminator (can be multi-char, e.g. $$ or //)
            # The terminator itself is not included in the yielded text.
            if term and sql_text.startswith(term, i):
                i += len(term)
                stmt = "".join(buf).strip()
                end_line = line_no
                if stmt:
                    yield stmt, start_line, end_line
                break

            buf.append(c)
            if c == "\n":
                line_no += 1
            i += 1
        else:
            # EOF without terminator: flush what's left
            # (runs only when the loop above ended without `break`)
            stmt = "".join(buf).strip()
            if stmt:
                yield stmt, start_line, line_no


def executeScriptsFromFile(path="SQL_Translate.sql"):
    """
    Execute every statement of the SQL script at *path* against the MySQL server.

    The script is read first (UTF-8, BOM tolerated) and split with
    split_mysql_script(). The statements then run on one connection with
    autocommit off and are committed once at the end. On the first failing
    statement, its number, line range and first 50 lines are printed, the
    transaction is rolled back and the mysql.connector.Error is re-raised.
    The cursor and connection are always closed, including on error.

    NOTE: MySQL implicitly commits DDL such as DROP TABLE / CREATE TABLE, so a
    rollback cannot undo tables that were already dropped or recreated.
    """
    # Read before connecting so a missing file does not open a connection.
    with open(path, "r", encoding="utf-8-sig") as f:  # -sig trims BOM if present
        sql = f.read()

    cnx = mysql.connector.connect(
        user=_user_and_db,
        password=_pass,
        host='w01dc079.kasserver.com',
        database=_user_and_db,
        charset='utf8mb4',
        collation='utf8mb4_unicode_ci',
        autocommit=False,
    )
    try:
        cur = cnx.cursor()
        try:
            count = 0
            for count, (stmt, s_line, e_line) in enumerate(split_mysql_script(sql), start=1):
                try:
                    cur.execute(stmt)
                except mysql.connector.Error as e:
                    # Show precise context
                    print(f"\nERROR in statement #{count} (lines {s_line}-{e_line})")
                    snippet = "\n".join(stmt.splitlines()[:50])
                    print("--- statement head ---\n", snippet, "\n--- end ---")
                    print("MySQL error:", e)
                    cnx.rollback()
                    raise
            cnx.commit()
        finally:
            cur.close()
    finally:
        cnx.close()
    print(f"Executed {count} statements successfully.")


if __name__ == "__main__":
    # Step 1: build SQL_Translate.sql (+ .zip) from the JSON sheets.
    run()
    # Step 2: load the generated script into the MySQL database.
    executeScriptsFromFile(path="SQL_Translate.sql")
