import { memo, useCallback, useEffect, useMemo, useRef } from 'react'
import 'live2d-cubism-core'
import { CubismFramework, LogLevel } from 'live2d-cubism-framework'
import { CubismMatrix44 } from 'live2d-cubism-framework/dist/math/cubismmatrix44'
import { CubismViewMatrix } from 'live2d-cubism-framework/dist/math/cubismviewmatrix'
import { useLatest, useLatestCallback } from './hooks'
import { CubismTargetPoint } from 'live2d-cubism-framework/dist/math/cubismtargetpoint'
import { loadLive2DModel, prepareLive2DModel } from './live2d/live2DModel'
import { useKeyboardShortcuts } from './shortcuts'

/** The awaited return type of function type `F` (unwraps a returned Promise, if any). */
type ResolvedReturn<F extends (...args: never[]) => unknown> = Awaited<ReturnType<F>>

/** A model as produced by `loadLive2DModel`; this is the value later handed to `prepareLive2DModel` together with a WebGL context. */
export type Live2DModel = ResolvedReturn<typeof loadLive2DModel>
/** A model as produced by `prepareLive2DModel`, i.e. the form the renderer draws from. */
export type RenderedLive2DModel = ResolvedReturn<typeof prepareLive2DModel>

/** Width, in pixels, of the image captured when a screenshot is requested. */
const SCREENSHOT_WIDTH = 1920
/** Height, in pixels, of the image captured when a screenshot is requested. */
const SCREENSHOT_HEIGHT = 1080

/** Number of fixed-step model updates per second of elapsed time. */
const UPDATES_PER_SECOND = 60
/** Length, in seconds, of one fixed model update step. */
const SECONDS_PER_FRAME = 1 / UPDATES_PER_SECOND

/**
 * An image stamped into one of the model's textures.
 */
export interface Live2DDecal {
    /** Image to draw; it is rasterised at its own `width` × `height`. */
    data: HTMLImageElement
    /** Index into the rendered model's `textures` array that receives the decal. */
    textureIndex: number
    /** Horizontal blit position inside the target texture (presumably texels — confirm against `blit`). */
    x: number
    /** Vertical blit position inside the target texture (presumably texels — confirm against `blit`). */
    y: number
}

/** An RGB(A) colour used for per-part multiply/screen tinting; `a` is optional. */
export interface Live2DColor {
    r: number
    g: number
    b: number
    a?: number
}

/** Props accepted by the `Live2D` component. */
export interface Live2DProps {
    /** Loaded model to render; a new value rebuilds the rendering state. */
    model: Live2DModel
    /** Offset added to the user's pan position every frame. */
    position?: { x: number, y: number }
    /** Extra factor multiplied into the view scale every frame (defaults to 1). */
    scale?: number
    /** Multiply colours forwarded to the model update each tick (keys presumably part/drawable IDs — confirm). */
    multiplyColors?: Partial<Record<string, Live2DColor>>
    /** Screen colours forwarded to the model update each tick (keys presumably part/drawable IDs — confirm). */
    screenColors?: Partial<Record<string, Live2DColor>>
    /** Parameter values forwarded to the model update each tick, keyed by parameter ID. */
    parameterValues?: Partial<Record<string, number>>
    /** Per-parameter smoothing forwarded to the model update each tick, keyed by parameter ID. */
    parameterSmoothing?: Partial<Record<string, number>>
    /** Decals to stamp into model textures, keyed by an arbitrary stable ID; removing a key erases that decal. */
    decals?: Partial<Record<string, Live2DDecal>>
    /** Mutable flag: set `current` to `true` to capture the next frame; the component resets it to `false`. */
    requestScreenShot?: { current: boolean }
    /** Image shown behind the model (CSS `cover`) and composited into screenshots. */
    backgroundImage?: HTMLImageElement | null
    /** Solid background colour as a hex string (presumably without the leading `#` — confirm against callers). */
    backgroundColor?: string | null
    /** Called with any error thrown while setting up the WebGL canvas. */
    onError?: (err: any) => void
    /** Called with the encoded screenshot (the `toBlob` default, PNG) once a requested capture completes. */
    onScreenShotTaken?: (blob: Blob) => void
}

/** Shared scratch canvas used to rasterise decal images before they are blitted into textures. */
const decalCanvas = document.createElement('canvas')
/** 2D context of the scratch canvas; `null` if the browser cannot provide one. */
const decalCtx = decalCanvas.getContext('2d')

/**
 * Resizes the shared scratch canvas, clears it, and optionally draws an image into it.
 *
 * @param width - New canvas width in pixels.
 * @param height - New canvas height in pixels.
 * @param decal - Image stretched to fill the canvas, or `null` to leave it fully transparent.
 * @returns The shared scratch canvas. The same instance is returned on every call, so
 *          consume it before calling this function again.
 */
function prepareDecalCanvas(width: number, height: number, decal: CanvasImageSource | null): HTMLCanvasElement {
    const target = decalCanvas
    target.width = width
    target.height = height

    const ctx = decalCtx
    if (ctx !== null) {
        ctx.clearRect(0, 0, width, height)
        if (decal) ctx.drawImage(decal, 0, 0, width, height)
    }

    return target
}

/**
 * Formats a background-colour prop for use as a CSS / canvas colour string.
 *
 * The value is treated as a hex colour. A leading `#` is added when missing and kept
 * when already present, so the on-screen canvas style and the screenshot compositor
 * always agree.
 */
function toCssColor(color: string): string {
    return color.startsWith('#') ? color : `#${color}`
}

/**
 * Renders a Live2D model into a full-viewport WebGL canvas and lets the user pan, zoom
 * and pose it with the mouse.
 *
 * Props other than `model` are read inside the frame loop through `useLatest` /
 * `useLatestCallback` accessors rather than captured in closures. Assuming those hooks
 * return stable accessors, changing such props does not rebuild WebGL state.
 * `model` is a dependency of the canvas ref callback, so a new model re-runs setup.
 */
export const Live2D = memo(function Live2D({ model, position, scale, multiplyColors, screenColors, parameterValues, parameterSmoothing, decals, requestScreenShot, onError, onScreenShotTaken, backgroundImage, backgroundColor }: Live2DProps) {
    // WebGL/view state for the mounted canvas; `null` while no canvas is mounted.
    const stateRef = useRef<{ canvas: HTMLCanvasElement, gl: WebGLRenderingContext, shaderProgram: WebGLProgram, frameBuffer: WebGLFramebuffer, viewMatrix: CubismViewMatrix, deviceToScreenMatrix: CubismMatrix44, model: RenderedLive2DModel | null, decalCache: Map<string, Live2DDecal> } | null>(null)
    // requestAnimationFrame timestamp (ms) of the previous frame; 0 means "no previous frame yet".
    const lastUpdateRef = useRef(0)
    // Seconds of elapsed time not yet consumed by fixed-step model updates.
    const cumulativeUpdateTimeRef = useRef(0)
    // True from a mousedown on the canvas until the next mouseup anywhere in the document.
    const isDraggingRef = useRef(false)
    const latestOnError = useLatestCallback(onError)
    const latestOnScreenShotTaken = useLatestCallback(onScreenShotTaken)
    const latestPosition = useLatest(position)
    const latestScale = useLatest(scale)
    const latestMultiplyColors = useLatest(multiplyColors)
    const latestScreenColors = useLatest(screenColors)
    const latestParameterValues = useLatest(parameterValues)
    const latestParameterSmoothing = useLatest(parameterSmoothing)
    const latestDecals = useLatest(decals)
    const latestRequestScreenShot = useLatest(requestScreenShot)
    const latestBackgroundImage = useLatest(backgroundImage)
    const latestBackgroundColor = useLatest(backgroundColor)

    // Advertise the mouse bindings implemented by onMouseDrag / onWheelScroll below.
    useKeyboardShortcuts('Pose Model', useMemo(() => ({
        zoom: { name: 'Zoom', key: 'Scroll Wheel' },
        pan: { name: 'Pan', key: 'Left Mouse Button' },
        pan2: { name: 'Pan', key: 'Middle Mouse Button' },
        poseAll: { name: 'Pose All', key: 'Right Mouse Button' },
        poseHead: { name: 'Pose Head', key: 'Left Mouse Button', shift: true },
        poseEyes: { name: 'Pose Eyes', key: 'Right Mouse Button', shift: true },
        poseBody: { name: 'Pose Body', key: 'Right Mouse Button', ctrl: true },
        poseBody2: { name: 'Pose Body 2', key: 'Left Mouse Button', ctrl: true },
    }), []))

    /**
     * Runs `callback` on the next animation frame, but only if the rendering state is
     * still the one that was current when it was scheduled. A torn-down or rebuilt
     * canvas therefore stops the previous frame loop.
     */
    const scheduleUpdateCallback = useCallback((callback: (time: number) => void) => {
        const currentState = stateRef.current
        requestAnimationFrame((time) => {
            if (stateRef.current === currentState) callback(time)
        })
    }, [])

    /**
     * Sizes the canvas and GL viewport to the document's client area.
     *
     * Rebuilds the view matrix (screen rect [-aspect, aspect] × [-1, 1]) and the
     * device→screen matrix. The latter maps canvas pixel coordinates (origin top-left,
     * Y down) to centred screen coordinates with Y up.
     */
    const onResize = useCallback(() => {
        if (stateRef.current) {
            const { canvas, gl, viewMatrix, deviceToScreenMatrix } = stateRef.current
            const width = document.documentElement.clientWidth
            const height = document.documentElement.clientHeight
            const ratio = width / height
            const left = -ratio
            const right = ratio
            const bottom = -1
            const top = 1

            canvas.width = width
            canvas.height = height
            gl.viewport(0, 0, width, height)
            viewMatrix.setScreenRect(left, right, bottom, top)
            // onUpdate re-applies translation and scale to viewMatrix on every frame.
            viewMatrix.scale(1.5, 1.5)
            viewMatrix.setMaxScale(2.0)
            viewMatrix.setMinScale(0.5)
            viewMatrix.setMaxScreenRect(-2, 2, -2, 2)
            deviceToScreenMatrix.loadIdentity()
            if (width > height) {
                const screenW = Math.abs(right - left)
                deviceToScreenMatrix.scaleRelative(screenW / width, -screenW / width)
            } else {
                const screenH = Math.abs(top - bottom)
                deviceToScreenMatrix.scaleRelative(screenH / height, -screenH / height)
            }
            deviceToScreenMatrix.translateRelative(-width * 0.5, -height * 0.5)
        }
    }, [])

    /**
     * Handles mousedown/mousemove on the canvas.
     *
     * - Middle button, or plain left button: pan.
     * - Plain right button: point eyes, head, body and drag target at the cursor.
     * - Shift+left: head. Shift+right: eyes.
     * - Ctrl/Cmd+left: body. Ctrl/Cmd+right: drag target.
     */
    const onMouseDrag = useCallback((e: MouseEvent) => {
        if (!stateRef.current || !stateRef.current.model) return
        e.preventDefault()

        const { canvas, deviceToScreenMatrix, viewMatrix, model } = stateRef.current
        const { panPoint, eyePoint, headPoint, bodyPoint, dragManager } = model.userModel
        const rect = canvas.getBoundingClientRect()
        const deviceX = e.clientX - rect.left
        const deviceY = e.clientY - rect.top
        const screenX = deviceToScreenMatrix.transformX(deviceX)
        const screenY = deviceToScreenMatrix.transformY(deviceY)
        const viewX = viewMatrix.invertTransformX(screenX)
        const viewY = viewMatrix.invertTransformY(screenY)

        if (!e.buttons) return
        // Only drags that started on the canvas are honoured; see onMouseUp.
        if (e.type === 'mousedown') {
            isDraggingRef.current = true
        }
        if (!isDraggingRef.current) return
        const ctrl = e.ctrlKey || e.metaKey
        const alt = e.altKey
        const shift = e.shiftKey
        const lmb = !!(e.buttons & 0x1)
        const rmb = !!(e.buttons & 0x2)
        const mmb = !!(e.buttons & 0x4)
        if (mmb || (!shift && !ctrl && !alt && lmb)) {
            // Pan by the screen-space distance the cursor moved since the last event.
            const deviceX1 = deviceX + e.movementX
            const deviceY1 = deviceY + e.movementY
            const screenX1 = deviceToScreenMatrix.transformX(deviceX1)
            const screenY1 = deviceToScreenMatrix.transformY(deviceY1)
            const panX = panPoint.getX()
            const panY = panPoint.getY()
            const finalX = panX + (screenX1 - screenX)
            const finalY = panY + (screenY1 - screenY)
            panPoint.set(finalX, finalY)
        } else {
            const targets: CubismTargetPoint[] = []
            if (!shift && !ctrl && !alt && rmb) targets.push(eyePoint, headPoint, bodyPoint, dragManager)
            if (shift && !ctrl && !alt && lmb) targets.push(headPoint)
            if (shift && !ctrl && !alt && rmb) targets.push(eyePoint)
            if (!shift && ctrl && !alt && lmb) targets.push(bodyPoint)
            if (!shift && ctrl && !alt && rmb) targets.push(dragManager)
            if (!targets.length) return
            for (const t of targets) t.set(viewX, viewY)
        }
    }, [])

    /** Ends the current drag; registered on `document` so releases outside the canvas count. */
    const onMouseUp = useCallback(() => {
        isDraggingRef.current = false
    }, [])

    /**
     * Zooms in/out with the wheel, adjusting the pan so the point under the cursor
     * stays roughly in place. Pixel, line and page delta modes are normalised to a
     * common scale first.
     */
    const onWheelScroll = useCallback((e: WheelEvent) => {
        if (!stateRef.current) return
        e.preventDefault()
        let amount = -e.deltaY
        switch (e.deltaMode) {
            case WheelEvent.DOM_DELTA_PIXEL:
                amount *= 1
                break
            case WheelEvent.DOM_DELTA_LINE:
                amount *= 4
                break
            case WheelEvent.DOM_DELTA_PAGE:
                amount *= 80
                break
        }
        const { model, canvas, deviceToScreenMatrix, viewMatrix } = stateRef.current
        if (!model) return
        const current = model.userModel.scalePoint.getX()
        const target = Math.max(0.01, current + current * amount * 0.001)
        model.userModel.scalePoint.set(target, target)

        // 1.5 matches the base view scale applied in onUpdate.
        const diff = (target - current) * 1.5

        const rect = canvas.getBoundingClientRect()
        const deviceX = e.clientX - rect.left
        const deviceY = e.clientY - rect.top
        const screenX = deviceToScreenMatrix.transformX(deviceX)
        const screenY = deviceToScreenMatrix.transformY(deviceY)
        const viewX = viewMatrix.invertTransformX(screenX)
        const viewY = viewMatrix.invertTransformY(screenY)
        const panX = model.userModel.panPoint.getX()
        const panY = model.userModel.panPoint.getY()
        model.userModel.panPoint.set(panX - viewX * diff, panY - viewY * diff)
    }, [])

    /** Suppresses the browser context menu so right-drag can be used for posing. */
    const onContextMenu = useCallback((e: MouseEvent) => {
        e.preventDefault()
    }, [])

    /**
     * Per-frame loop: advances the model in fixed steps, syncs decals, draws the model,
     * and, when requested, captures a screenshot. Then it schedules itself again.
     */
    const onUpdate = useCallback((time: number) => {
        if (!stateRef.current) return
        if (lastUpdateRef.current === 0) lastUpdateRef.current = time
        const deltaTime = (time - lastUpdateRef.current) / 1000
        lastUpdateRef.current = time
        cumulativeUpdateTimeRef.current += deltaTime

        // Limit the number of frames to catch up by
        cumulativeUpdateTimeRef.current = Math.min(cumulativeUpdateTimeRef.current, SECONDS_PER_FRAME * 4)

        const { canvas, gl, shaderProgram, frameBuffer, viewMatrix, decalCache, model } = stateRef.current

        // Read the request once so the resize here and the capture below agree.
        const screenShotRequest = latestRequestScreenShot()
        const takingScreenShot = !!screenShotRequest?.current
        if (takingScreenShot) {
            canvas.width = SCREENSHOT_WIDTH
            canvas.height = SCREENSHOT_HEIGHT
            // Resizing a canvas does not update the WebGL viewport. Without this, the
            // model would be drawn into the previous (window-sized) region of the buffer.
            gl.viewport(0, 0, SCREENSHOT_WIDTH, SCREENSHOT_HEIGHT)
        }

        gl.clearColor(0, 0, 0, 0)
        gl.clearDepth(1.0)
        gl.enable(gl.DEPTH_TEST)
        gl.depthFunc(gl.LEQUAL)
        gl.clear(gl.COLOR_BUFFER_BIT | gl.DEPTH_BUFFER_BIT)
        gl.enable(gl.BLEND)
        gl.blendFunc(gl.SRC_ALPHA, gl.ONE_MINUS_SRC_ALPHA)

        gl.useProgram(shaderProgram)

        // Draw sprites/backgrounds here if applicable

        const mvpMatrix = new CubismMatrix44()

        // Keep the model's aspect ratio regardless of canvas shape.
        const projectionMatrix = new CubismMatrix44()
        if (canvas.width < canvas.height)
            projectionMatrix.scale(1.0, canvas.width / canvas.height)
        else
            projectionMatrix.scale(canvas.height / canvas.width, 1.0)
        mvpMatrix.multiplyByMatrix(projectionMatrix)

        const panX = (latestPosition()?.x ?? 0) + (model?.userModel.panPoint.getX() ?? 0)
        const panY = (latestPosition()?.y ?? 0) + (model?.userModel.panPoint.getY() ?? 0)
        viewMatrix.translateX(panX)
        viewMatrix.translateY(panY)
        const scaleX = 1.5 * (latestScale() ?? 1) * (model?.userModel.scalePoint.getX() ?? 1)
        const scaleY = 1.5 * (latestScale() ?? 1) * (model?.userModel.scalePoint.getY() ?? 1)
        viewMatrix.scale(scaleX, scaleY)

        mvpMatrix.multiplyByMatrix(viewMatrix)

        const modelMatrix = new CubismMatrix44()
        mvpMatrix.multiplyByMatrix(modelMatrix)

        if (model) {
            const decals = latestDecals()
            if (decals) {
                for (const key in decals) {
                    const decal = decals[key]
                    const applied = decalCache.get(key)
                    if (applied?.data === decal?.data) continue
                    // Wait until the new image has loaded; the old decal stays until then.
                    if (!decal || (decal.data instanceof HTMLImageElement && !decal.data.complete)) continue
                    if (applied) {
                        // Erase the previously applied decal first. The replacement may sit elsewhere
                        // or be smaller, and would otherwise leave parts of the old one behind.
                        const resetDecal = prepareDecalCanvas(applied.data.width, applied.data.height, null)
                        model.textures[applied.textureIndex].blit(applied.x, applied.y, resetDecal)
                    }
                    const decalToApply = prepareDecalCanvas(decal.data.width, decal.data.height, decal.data)
                    model.textures[decal.textureIndex].blit(decal.x, decal.y, decalToApply)
                    decalCache.set(key, decal)
                }
            }
            // Erase decals whose key was removed or whose value is now undefined.
            const decalCacheKeysToDelete: string[] = []
            for (const key of decalCache.keys()) {
                if (!decals?.[key]) {
                    decalCacheKeysToDelete.push(key)
                }
            }
            for (const key of decalCacheKeysToDelete) {
                const decal = decalCache.get(key)!
                const resetDecal = prepareDecalCanvas(decal.data.width, decal.data.height, null)
                model.textures[decal.textureIndex].blit(decal.x, decal.y, resetDecal)
                decalCache.delete(key)
            }

            // Fixed-step simulation: consume accumulated time in SECONDS_PER_FRAME slices.
            while (cumulativeUpdateTimeRef.current > SECONDS_PER_FRAME) {
                cumulativeUpdateTimeRef.current -= SECONDS_PER_FRAME

                model.userModel.update(SECONDS_PER_FRAME, { multiplyColors: latestMultiplyColors(), screenColors: latestScreenColors(), parameterValues: latestParameterValues(), parameterSmoothing: latestParameterSmoothing() })
            }

            model.userModel.getRenderer().setMvpMatrix(mvpMatrix)
            model.userModel.getRenderer().setRenderState(frameBuffer, [0, 0, canvas.width, canvas.height])
            try {
                model.userModel.getRenderer().drawModel()
            } catch (e) {
                // Best-effort recovery: rebuild the renderer and try again next frame.
                model.userModel.resetRenderer(gl)
            }
        }

        if (screenShotRequest && takingScreenShot) {
            screenShotRequest.current = false

            const cvs = document.createElement('canvas')
            cvs.width = SCREENSHOT_WIDTH
            cvs.height = SCREENSHOT_HEIGHT

            const ctx = cvs.getContext('2d')!

            const bgColor = latestBackgroundColor()
            if (bgColor) {
                ctx.fillStyle = toCssColor(bgColor)
                ctx.fillRect(0, 0, cvs.width, cvs.height)
            }
            const bgImage = latestBackgroundImage()
            if (bgImage) {
                // Scale to cover the whole output and centre, cropping the overflow (like CSS `cover`).
                const imageRatio = bgImage.width / bgImage.height
                const canvasRatio = cvs.width / cvs.height
                let w = cvs.width
                let h = cvs.height
                if (imageRatio < canvasRatio) {
                    h = w / imageRatio
                } else {
                    w = h * imageRatio
                }
                const x = -(w - cvs.width) / 2
                const y = -(h - cvs.height) / 2
                ctx.drawImage(bgImage, x, y, w, h)
            }

            // Read the WebGL canvas in the same task it was drawn in, before the buffer is presented.
            ctx.drawImage(canvas, 0, 0)
            cvs.toBlob(blob => !!blob && latestOnScreenShotTaken(blob))
            // Restore the on-screen size and viewport.
            onResize()
        }

        scheduleUpdateCallback(onUpdate)
    }, [latestRequestScreenShot, latestPosition, latestScale, scheduleUpdateCallback, latestDecals, latestMultiplyColors, latestScreenColors, latestParameterValues, latestParameterSmoothing, latestBackgroundColor, latestBackgroundImage, onResize, latestOnScreenShotTaken])

    /**
     * Canvas ref callback. On mount, it starts the Cubism framework if needed, creates
     * the WebGL context and a placeholder shader program, prepares the model, registers
     * input listeners and starts the frame loop. On unmount (`canvas === null`), it
     * removes the listeners and drops the state, which stops the loop.
     * Errors are reported through `onError`.
     */
    const onCanvasMount = useCallback((canvas: HTMLCanvasElement | null) => {
        try {
            if (canvas && !stateRef.current) {
                if (!CubismFramework.isStarted()) {
                    CubismFramework.startUp({
                        logFunction: console.log,
                        loggingLevel: LogLevel.LogLevel_Verbose,
                    })
                }
                if (!CubismFramework.isInitialized()) {
                    CubismFramework.initialize()
                }

                const gl = canvas.getContext('webgl')
                if (!gl) throw new Error('Failed to initialize WebGL')

                const vertexShaderId = gl.createShader(gl.VERTEX_SHADER)
                if (vertexShaderId === null) throw new Error("Failed to initialize vertex shader")
                const vertexShader = `
                    precision mediump float;
                    attribute vec3 position;
                    attribute vec2 uv;
                    varying vec2 vuv;
                    void main(void) {
                        gl_Position = vec4(position, 1.0);
                        vuv = uv;
                    }
                `.split('\n').map(s => s.trim()).join('\n')
                gl.shaderSource(vertexShaderId, vertexShader)
                gl.compileShader(vertexShaderId)

                const fragmentShaderId = gl.createShader(gl.FRAGMENT_SHADER)
                if (fragmentShaderId === null) throw new Error('Failed to initialize fragment shader')
                const fragmentShader = `
                    precision mediump float;
                    varying vec2 vuv;
                    uniform sampler2D texture;
                    void main(void)
                    {
                        gl_FragColor = texture2D(texture, vuv);
                    }
                `.split('\n').map(s => s.trim()).join('\n')
                gl.shaderSource(fragmentShaderId, fragmentShader)
                gl.compileShader(fragmentShaderId)

                const shaderProgram = gl.createProgram()
                if (shaderProgram === null) throw new Error('Failed to initialize shader program')
                gl.attachShader(shaderProgram, vertexShaderId)
                gl.attachShader(shaderProgram, fragmentShaderId)
                gl.deleteShader(vertexShaderId)
                gl.deleteShader(fragmentShaderId)
                gl.linkProgram(shaderProgram)
                gl.useProgram(shaderProgram)

                // The framebuffer bound right after context creation; the model renders into it.
                const frameBuffer = gl.getParameter(gl.FRAMEBUFFER_BINDING)

                const viewMatrix = new CubismViewMatrix()
                const deviceToScreenMatrix = new CubismMatrix44()

                const decalCache = new Map<string, Live2DDecal>()

                stateRef.current = { canvas, gl, shaderProgram, frameBuffer, viewMatrix, deviceToScreenMatrix, decalCache, model: prepareLive2DModel(model, gl) }

                onResize()

                window.addEventListener('resize', onResize)
                document.addEventListener('mouseup', onMouseUp)
                stateRef.current.canvas.addEventListener('mousedown', onMouseDrag)
                stateRef.current.canvas.addEventListener('mousemove', onMouseDrag)
                // Non-passive so onWheelScroll can preventDefault page scrolling.
                stateRef.current.canvas.addEventListener('wheel', onWheelScroll, { passive: false })
                stateRef.current.canvas.addEventListener('contextmenu', onContextMenu)

                scheduleUpdateCallback(onUpdate)
            } else if (!canvas && stateRef.current) {
                window.removeEventListener('resize', onResize)
                document.removeEventListener('mouseup', onMouseUp)
                stateRef.current.canvas.removeEventListener('mousedown', onMouseDrag)
                stateRef.current.canvas.removeEventListener('mousemove', onMouseDrag)
                stateRef.current.canvas.removeEventListener('wheel', onWheelScroll)
                stateRef.current.canvas.removeEventListener('contextmenu', onContextMenu)

                stateRef.current = null
                // Start the next mount from a clean timing/drag state.
                lastUpdateRef.current = 0
                cumulativeUpdateTimeRef.current = 0
                isDraggingRef.current = false
            }
        } catch (err) {
            latestOnError(err)
        }
    }, [model, onResize, onMouseUp, onMouseDrag, onWheelScroll, onContextMenu, scheduleUpdateCallback, onUpdate, latestOnError])

    // Re-initialise the existing canvas whenever onCanvasMount changes (e.g. a new model).
    // NOTE(review): React already calls the old ref callback with null and the new one with
    // the node when the callback identity changes. This effect therefore appears to run
    // setup a second time on mount and on model change; confirm before removing it.
    useEffect(() => {
        if (stateRef.current) {
            const canvas = stateRef.current.canvas
            stateRef.current = null
            onCanvasMount(canvas)
        }
    }, [onCanvasMount])

    return <div style={{ position: 'sticky', left: '0', top: '0', width: '0', height: '0', overflow: 'visible' }}>
        <canvas ref={onCanvasMount} style={{ position: 'absolute', top: '0', width: '100vw', height: '100vh', backgroundColor: backgroundColor ? toCssColor(backgroundColor) : undefined, backgroundImage: backgroundImage ? `url("${backgroundImage.src}")` : undefined, backgroundSize: 'cover', backgroundPosition: 'center' }} />
    </div>
})
