import * as THREE from 'three';

/**
 * Module augmentation for three.js: lets a `Mesh` carry an optional
 * `triangleCategorizer`, which the patched `Mesh.prototype.raycast` in this
 * file checks for on every raycast.
 */
declare module 'three' {
    interface Mesh {
        /** When set, raycasts against this mesh are delegated to this categorizer. */
        triangleCategorizer?: TriangleCategorizer;
    }
    // NOTE(review): presumably a declaration-merging workaround so the `Mesh`
    // interface above is accepted alongside three's own `Mesh` class — confirm
    // it is still required with the three.js typings in use.
    namespace Mesh {
        export class Mesh { }
    }
}

// Keep a handle on three.js' stock implementation so meshes without a
// categorizer keep their normal raycasting behaviour.
const protoRaycast = THREE.Mesh.prototype.raycast;

/**
 * Patched `Mesh.raycast`: if the mesh has a `triangleCategorizer`, the
 * raycast is delegated to it; otherwise the original three.js implementation
 * is invoked with the same arguments.
 */
THREE.Mesh.prototype.raycast = function (this: THREE.Mesh, raycaster, intersects) {
    const categorizer = this.triangleCategorizer;
    return categorizer
        ? categorizer.raycast(raycaster, intersects)
        : protoRaycast.call(this, raycaster, intersects);
};

/**
 * Uniform-grid spatial index over the triangles of a mesh, used to speed up
 * raycasting.
 *
 * The mesh's world-space bounding box is split into `xGrid * yGrid * zGrid`
 * cells and every triangle is registered in each cell its world-space
 * bounding box touches. A raycast then only tests triangles stored in cells
 * the ray passes through.
 *
 * The grid is computed from the mesh's world matrix at the time `build()`
 * runs (the constructor calls it once). Call `build()` again after the mesh's
 * transform or geometry changes, otherwise the cells no longer match the
 * triangles. Geometry groups and draw ranges are not taken into account.
 *
 * To activate it for a mesh, assign the instance to `mesh.triangleCategorizer`;
 * the patched `Mesh.prototype.raycast` in this file will then delegate to it.
 */
export class TriangleCategorizer {
    public yGrid: number;
    public xGrid: number;
    public zGrid: number;
    public mesh: THREE.Mesh;
    public geometry: THREE.BufferGeometry;
    /** World-space bounding box of the mesh, as of the last `build()`. */
    public rootBox: THREE.Box3 = new THREE.Box3();
    private gridBoxByName: { [boxName: string]: GridBox } = {};
    private gridBoxes: GridBox[] = [];

    /**
     * @param mesh  Mesh whose triangles are indexed.
     * @param xGrid Number of cells along the world X axis.
     * @param yGrid Number of cells along the world Y axis.
     * @param zGrid Number of cells along the world Z axis.
     * @throws Error when any grid count is smaller than 1.
     */
    constructor(mesh: THREE.Mesh, xGrid = 10, yGrid = 10, zGrid = 10) {
        if (xGrid < 1 || yGrid < 1 || zGrid < 1) throw new Error('Grid must be at least 1');
        this.xGrid = xGrid;
        this.yGrid = yGrid;
        this.zGrid = zGrid;
        this.mesh = mesh;
        this.geometry = mesh.geometry;
        this.build();
    }

    /**
     * (Re)builds the grid from the mesh's current world matrix and geometry.
     * Any previously built cells are discarded first, so calling this
     * repeatedly is safe.
     */
    public build(): void {
        this.mesh.updateMatrixWorld(true);
        // Start from a clean slate; otherwise a rebuild would leave stale and
        // duplicated cells behind.
        this.gridBoxByName = {};
        this.gridBoxes = [];
        this.generateRootBox();
        this.generateGridBoxes();
        this.insertIndexToGridBoxes();
    }

    /**
     * Finds the closest triangle hit by the raycaster's ray and, if there is
     * one strictly between `raycaster.near` and `raycaster.far`, appends a
     * single intersection to `intersects`. Existing entries of `intersects`
     * are left untouched.
     *
     * The reported `face.normal` is computed from the world-space triangle.
     *
     * @param raycaster  Raycaster providing the world-space ray and near/far limits.
     * @param intersects Output array the hit (if any) is pushed onto.
     */
    public raycast(raycaster: THREE.Raycaster, intersects: THREE.Intersection[]): void {
        const ray = raycaster.ray;
        const candidates = this.getIntersectGridBox(raycaster);
        // Scratch objects reused for every triangle test.
        const hitPoint = new THREE.Vector3();
        const triangle = new THREE.Triangle(new THREE.Vector3(), new THREE.Vector3(), new THREE.Vector3());

        let minDistance = Infinity;
        let best:
            | {
                  point: THREE.Vector3;
                  triangle: THREE.Triangle;
                  verticeIndexes: [number, number, number];
                  triangleIndex: number;
              }
            | undefined;

        for (const gridBox of candidates) {
            // A cell further away than the best hit so far cannot contain a closer one.
            if (gridBox.box.distanceToPoint(ray.origin) > minDistance) continue;

            for (const triangleIndex of gridBox.triangleIndexes) {
                const verticeIndexes = this.getVerticeIndexes(triangleIndex);
                this.setTriangleByVerticeIndexes(triangle, ...verticeIndexes);
                // Back-face culling is off: both sides of a triangle can be hit.
                if (!ray.intersectTriangle(triangle.a, triangle.b, triangle.c, false, hitPoint)) continue;

                const distance = hitPoint.distanceTo(ray.origin);
                if (distance > raycaster.near && distance < raycaster.far && distance < minDistance) {
                    minDistance = distance;
                    best = {
                        point: hitPoint.clone(),
                        triangle: triangle.clone(),
                        verticeIndexes,
                        triangleIndex,
                    };
                }
            }
        }

        if (!best) return;

        const normal = new THREE.Vector3();
        best.triangle.getNormal(normal);
        intersects.push({
            distance: minDistance,
            point: best.point,
            object: this.mesh,
            face: {
                a: best.verticeIndexes[0],
                b: best.verticeIndexes[1],
                c: best.verticeIndexes[2],
                normal,
                materialIndex: 0, // todo: check 0 is correct
            },
            // Triangle offsets advance in steps of 3, so this is the triangle's ordinal.
            faceIndex: Math.floor(best.triangleIndex / 3),
        });
    }

    /** Creates one `GridBox` per cell, evenly subdividing `rootBox`. */
    private generateGridBoxes(): void {
        // Cell dimensions are the same for every cell, so compute them once.
        const size = this.rootBox.getSize(new THREE.Vector3());
        const xGridSize = size.x / this.xGrid;
        const yGridSize = size.y / this.yGrid;
        const zGridSize = size.z / this.zGrid;
        const origin = this.rootBox.min;

        for (let x = 0; x < this.xGrid; x++) {
            for (let y = 0; y < this.yGrid; y++) {
                for (let z = 0; z < this.zGrid; z++) {
                    const boxName = `x${x}y${y}z${z}`;
                    const gridBox = new GridBox(boxName);
                    gridBox.box.min.set(
                        origin.x + x * xGridSize,
                        origin.y + y * yGridSize,
                        origin.z + z * zGridSize,
                    );
                    gridBox.box.max.set(
                        origin.x + (x + 1) * xGridSize,
                        origin.y + (y + 1) * yGridSize,
                        origin.z + (z + 1) * zGridSize,
                    );
                    this.gridBoxByName[boxName] = gridBox;
                    this.gridBoxes.push(gridBox);
                }
            }
        }
    }

    /** Computes `rootBox`: the geometry's bounding box transformed into world space. */
    private generateRootBox(): void {
        if (!this.geometry.boundingBox) this.geometry.computeBoundingBox();
        const localBox = this.geometry.boundingBox;
        // Box3.applyMatrix4 works in place, so transform a copy; the geometry's
        // own (local-space) bounding box must stay intact.
        this.rootBox = localBox ? localBox.clone().applyMatrix4(this.mesh.matrixWorld) : new THREE.Box3();
    }

    /**
     * Registers every triangle in each grid cell its world-space bounding box
     * touches. A triangle is identified by the offset of its first vertex in
     * the index buffer (or in the position buffer for non-indexed geometry).
     */
    private insertIndexToGridBoxes(): void {
        const positionAttribute = this.geometry.getAttribute('position');
        if (!positionAttribute) return; // nothing to index
        const indexAttribute = this.geometry.getIndex();
        const vertexCount = indexAttribute ? indexAttribute.count : positionAttribute.count;

        const matrixWorld = this.mesh.matrixWorld;
        const rootMin = this.rootBox.min;
        const rootSize = this.rootBox.getSize(new THREE.Vector3());
        const a = new THREE.Vector3();
        const b = new THREE.Vector3();
        const c = new THREE.Vector3();
        const triangleBox = new THREE.Box3();

        // `i + 2 < vertexCount` skips a trailing, incomplete triangle.
        for (let i = 0; i + 2 < vertexCount; i += 3) {
            const [ia, ib, ic] = this.getVerticeIndexes(i);
            a.fromBufferAttribute(positionAttribute, ia).applyMatrix4(matrixWorld);
            b.fromBufferAttribute(positionAttribute, ib).applyMatrix4(matrixWorld);
            c.fromBufferAttribute(positionAttribute, ic).applyMatrix4(matrixWorld);
            triangleBox.makeEmpty();
            triangleBox.expandByPoint(a);
            triangleBox.expandByPoint(b);
            triangleBox.expandByPoint(c);

            const [xStart, xEnd] = this.getCellRange(
                triangleBox.min.x, triangleBox.max.x, rootMin.x, rootSize.x, this.xGrid,
            );
            const [yStart, yEnd] = this.getCellRange(
                triangleBox.min.y, triangleBox.max.y, rootMin.y, rootSize.y, this.yGrid,
            );
            const [zStart, zEnd] = this.getCellRange(
                triangleBox.min.z, triangleBox.max.z, rootMin.z, rootSize.z, this.zGrid,
            );

            for (let x = xStart; x <= xEnd; x++) {
                for (let y = yStart; y <= yEnd; y++) {
                    for (let z = zStart; z <= zEnd; z++) {
                        const gridBox = this.gridBoxByName[`x${x}y${y}z${z}`];
                        if (gridBox) gridBox.triangleIndexes.push(i);
                    }
                }
            }
        }
    }

    /**
     * Returns the inclusive range of cell indices along one axis whose extent
     * touches the interval `[min, max]`.
     *
     * Cells that are only touched at their boundary are included, so a
     * triangle lying flat on a cell boundary or on an outer face of the root
     * box is still registered. Indices are clamped to the grid so rounding
     * error can never address a non-existent cell.
     *
     * @param min    Lower bound of the interval (world units).
     * @param max    Upper bound of the interval (world units).
     * @param origin Lower bound of the root box on this axis.
     * @param extent Size of the root box on this axis.
     * @param cells  Number of cells along this axis.
     */
    private getCellRange(min: number, max: number, origin: number, extent: number, cells: number): [number, number] {
        // All cells on a zero-extent axis are identical; the first one suffices
        // and this avoids a 0/0 division.
        if (!(extent > 0)) return [0, 0];
        const last = Math.ceil(cells) - 1;
        const clamp = (value: number) => Math.min(Math.max(value, 0), last);
        const start = clamp(Math.ceil(((min - origin) / extent) * cells) - 1);
        const end = clamp(Math.floor(((max - origin) / extent) * cells));
        return [start, end];
    }

    /**
     * Returns the non-empty grid cells hit by the raycaster's ray, ordered by
     * the distance from the ray origin to the point where the ray meets the cell.
     */
    private getIntersectGridBox(raycaster: THREE.Raycaster): GridBox[] {
        const ray = raycaster.ray;
        const entryPoint = new THREE.Vector3();
        const hits: { gridBox: GridBox; distance: number }[] = [];

        for (const gridBox of this.gridBoxes) {
            // Cells without triangles can never produce a hit.
            if (gridBox.triangleIndexes.length === 0) continue;
            if (!ray.intersectBox(gridBox.box, entryPoint)) continue;
            hits.push({ gridBox, distance: entryPoint.distanceTo(ray.origin) });
        }

        hits.sort((first, second) => first.distance - second.distance);
        return hits.map((hit) => hit.gridBox);
    }

    /**
     * Resolves the three vertex indices of the triangle starting at `index`,
     * going through the index buffer when the geometry has one.
     */
    private getVerticeIndexes(index: number): [number, number, number] {
        const indexAttribute = this.geometry.getIndex();
        return indexAttribute
            ? [indexAttribute.getX(index), indexAttribute.getX(index + 1), indexAttribute.getX(index + 2)]
            : [index, index + 1, index + 2];
    }

    /**
     * Loads the given vertices into `triangle` and transforms them to world
     * space with the mesh's current world matrix.
     */
    private setTriangleByVerticeIndexes(triangle: THREE.Triangle, ...verticeIndex: [number, number, number]): void {
        const positionAttribute = this.geometry.getAttribute('position');
        const matrixWorld = this.mesh.matrixWorld;
        triangle.a.fromBufferAttribute(positionAttribute, verticeIndex[0]).applyMatrix4(matrixWorld);
        triangle.b.fromBufferAttribute(positionAttribute, verticeIndex[1]).applyMatrix4(matrixWorld);
        triangle.c.fromBufferAttribute(positionAttribute, verticeIndex[2]).applyMatrix4(matrixWorld);
    }
}

/**
 * One cell of the grid: its bounding box plus the triangles registered in it.
 * Each entry of `triangleIndexes` is the offset of a triangle's first vertex,
 * as pushed by `TriangleCategorizer`.
 */
class GridBox {
    public triangleIndexes: number[] = [];
    public box: THREE.Box3 = new THREE.Box3();

    /** @param name Identifier of the cell (the categorizer uses `x{i}y{j}z{k}`). */
    constructor(public name: string) {}
}
