import json
import sys
import time
from collections import namedtuple
from functools import cache
from pathlib import Path
from typing import Optional

import psycopg

from com import eval_config, progressbar


# Per-note metadata as returned by get_note: parent ids and author id.
Note = namedtuple("Note", "renote_id reply_id user_id")
# A node of the reply/renote forest: a note id plus lists of child Tree nodes.
Tree = namedtuple("Tree", "id replies renotes")

config = eval_config()
# config["connect"] is a zero-argument factory that yields the DB connection.
conn: psycopg.Connection = config["connect"]()
# Author whose notes seed the graph.
user_id: str = config["user_id"]
# Optional cap on how many note ids to fetch; a missing, None or 0 value
# disables the cap (it is only tested for truthiness).
early_exit: Optional[int] = config.get("early_exit", None)


print("fetching note ids", file=sys.stderr)
note_ids = set()
# Collect the ids of every note by user_id, excluding rows that have a
# renoteId but no text. The cursor is closed when the block exits, which
# releases the client-side result set instead of keeping it alive in a
# module-level global for the rest of the run.
with conn.execute(
    'select id from note where "userId" = %s and not ("renoteId" is not null and text is null)',
    [user_id],
) as cur:
    while rows := cur.fetchmany(0xFF):
        time.sleep(0.0001)  # short pause per batch, presumably to throttle DB load
        note_ids.update(row[0] for row in rows)
        # early_exit caps the run when set to a truthy value. The check runs
        # once per batch, so the set may overshoot the cap by up to one batch.
        if early_exit and len(note_ids) > early_exit:
            break


@cache
def get_note(id: str) -> Note:
    """Fetch the renote target, reply parent and author of note ``id``.

    Results are memoised for the lifetime of the process by
    ``functools.cache``, so each id is queried at most once. Failed lookups
    raise and are therefore not cached.

    Raises:
        LookupError: if the ``note`` table has no row with this id. Previously
            this surfaced as an opaque ``TypeError`` from ``Note(*None)``.
    """
    time.sleep(0.0001)  # short pause per query, presumably to throttle DB load
    row = conn.execute(
        'select "renoteId", "replyId", "userId" from note where id = %s', [id]
    ).fetchone()
    if row is None:
        raise LookupError(f"note {id!r} not found")
    return Note(*row)


# note id -> Tree, for notes that have neither a reply parent nor a renote target.
roots = dict()
# note id -> Tree, for every note placed in the forest so far.
trees = dict()


def tree_init(id: str, seek: bool = True) -> Tree:
    """Return the Tree node for note ``id``, creating and linking it on first use.

    On first use the function first builds the note's reply parent and/or
    renote target by recursing upward. It then appends the new node to that
    parent's ``replies`` or that target's ``renotes``. A note with neither
    parent is registered in ``roots``. Every node created here is stored in
    ``trees``, so later calls return the same object.

    ``seek`` is accepted but not read anywhere in the body; the renote branch
    passes False. Recursion depth grows with the length of the ancestor
    chain, so very deep threads can hit Python's recursion limit.
    """
    existing = trees.get(id)
    if existing is not None:
        return existing

    node = Tree(id, [], [])
    note = get_note(id)
    if not (note.reply_id or note.renote_id):
        roots[id] = node
    else:
        if note.reply_id:
            tree_init(note.reply_id).replies.append(node)
        if note.renote_id:
            tree_init(note.renote_id, False).renotes.append(node)
    trees[id] = node
    return node


def make_widgets(msg, trees, roots):
    """Build the widget list for a ``progressbar.ProgressBar``.

    Layout: label, percentage, bar, "value/max" counter, then an optional
    "trees" and/or "roots" variable readout (when the matching flag is
    truthy), and finally an ETA. Widgets are separated by single spaces.
    """
    parts = [
        f"{msg} ",
        progressbar.Percentage(),
        " ",
        progressbar.Bar(),
        " ",
        progressbar.SimpleProgress("%(value_s)s/%(max_value_s)s"),
        " ",
    ]
    for name, wanted in (("trees", trees), ("roots", roots)):
        if wanted:
            parts.extend((progressbar.Variable(name), " "))
    parts.append(progressbar.ETA())
    return parts


# Stage 1: place every fetched note (plus its ancestors) into the forest,
# showing running counts of known trees and roots.
pb = progressbar.ProgressBar(
    0, len(note_ids), widgets=make_widgets("building trees", trees=True, roots=True)
)
for note_id in note_ids:
    tree_init(note_id)
    pb.increment(trees=len(trees), roots=len(roots))
pb.finish()


def traverse(tree: Tree):
    """Walk reply edges down from ``tree`` until reaching a note by ``user_id``.

    Each such note's subtree is handed to ``expand`` and is not walked any
    further here. Renote edges are not followed.
    """
    if get_note(tree.id).user_id == user_id:
        expand(tree)
        return
    for reply in tree.replies:
        traverse(reply)


def expand(tree: Tree):
    """Attach database-discovered children of ``tree``, then recurse into its replies.

    Every id returned by ``note_replies(tree.id, 1, 1000)`` that is not yet in
    ``trees`` is loaded. It is attached to ``tree.replies`` if its reply
    parent is ``tree``, to ``tree.renotes`` if its renote target is ``tree``,
    or to both; it is then registered in ``trees``. Ids matching neither
    relation are ignored. Afterwards the function recurses into every entry
    of ``tree.replies``, but not into renotes.

    NOTE(review): the exact result set of the ``note_replies`` SQL function
    (direct children only? includes the start id?) is defined in the
    database; presumably depth 1 with up to 1000 rows — confirm.
    """
    time.sleep(0.0001)  # short pause per query, presumably to throttle DB load
    found = conn.execute(
        "select id from note_replies(%s, 1, 1000)", [tree.id]
    ).fetchall()
    for row in found:
        child_id = row[0]
        if child_id in trees:
            continue
        child_note = get_note(child_id)
        child = Tree(child_id, [], [])
        is_reply = child_note.reply_id == tree.id
        is_renote = child_note.renote_id == tree.id
        if is_reply:
            tree.replies.append(child)
        if is_renote:
            tree.renotes.append(child)
        if is_reply or is_renote:
            trees[child_id] = child
    for reply in tree.replies:
        expand(reply)


# Stage 2: from every root, walk down to the user's own notes and pull in
# their replies/renotes from the database, tracking the total tree count.
roots_len = len(roots)
pb = progressbar.ProgressBar(
    0, roots_len, widgets=make_widgets("expanding roots", trees=True, roots=False)
)
for root in roots.values():
    traverse(root)
    pb.increment(trees=len(trees))
pb.finish()


# Stage 3: dump the forest to "graph.db" as one tab-separated line per note:
#   <note id> \t <reply ids> \t <renote ids> \t <flags>
# The id lists are comma-separated. flags is a comma-separated subset of
# "root" (the note is in `roots`) and "self" (the note is authored by user_id).
with Path("graph.db").open("w", encoding="utf-8") as f:
    pb = progressbar.ProgressBar(
        0, len(trees), widgets=make_widgets("saving graph", False, False)
    )
    # Every entry is stored as trees[tree.id] = tree, so the values suffice.
    for tree in trees.values():
        note = get_note(tree.id)
        flags = []
        if tree.id in roots:
            flags.append("root")
        if note.user_id == user_id:
            flags.append("self")
        fields = (
            tree.id,
            ",".join(reply.id for reply in tree.replies),
            ",".join(renote.id for renote in tree.renotes),
            ",".join(flags),
        )
        f.write("\t".join(fields) + "\n")
        pb.increment()
    pb.finish()