import * as THREE from "three";
import { STLLoader } from "three/examples/jsm/loaders/STLLoader";
import { VTKLoader } from "three/examples/jsm/loaders/VTKLoader";
import {
    AVENDA_DEFAULT_ABLATION_PROFILE_USER_UUID,
    AVENDA_DEFAULT_ABLATION_PROFILE_UUID,
    sceneOrientations,
    volumeNames,
} from "../../../constants";
import { PLANE } from "./consts";
import { getPHISignedURLS3 } from "../../../helpers/backend_api";
import {
    getAblationProfileWithVolumes,
    getAblationVolumeMeshURI,
} from "../../CreateAblationProfile/helpers";
import { getSignedUrl, toArray } from "../../../helpers/helpers";
import {
    computeBoundsTree,
    disposeBoundsTree,
    acceleratedRaycast,
} from "three-mesh-bvh";

// Install three-mesh-bvh's BVH helpers on every BufferGeometry and route all
// Mesh raycasts through the BVH-accelerated implementation. These are global,
// prototype-level patches that affect every mesh using this THREE instance.
THREE.BufferGeometry.prototype.computeBoundsTree = computeBoundsTree;
THREE.BufferGeometry.prototype.disposeBoundsTree = disposeBoundsTree;
// Keep a handle on three.js's stock raycast so callers
// (see computeCancerLesionOverlap's useDefaultRaycast) can opt back into it.
const defaultRaycast = THREE.Mesh.prototype.raycast;
THREE.Mesh.prototype.raycast = acceleratedRaycast;

/**
 * Test whether a point lies strictly inside a finite cone.
 *
 * @param {THREE.Vector3} queryPt - point to test; it is cloned, never modified.
 * @param {{vertex: {x: number, y: number, z: number},
 *          direction: {x: number, y: number, z: number},
 *          height: number, radius: number}} boundingCone
 *     Cone tip (`vertex`), axis pointing from the tip toward the base
 *     (`direction`), axial `height` and base `radius`.
 * @returns {boolean} true when the point is strictly inside the cone; false
 *     otherwise, including when either argument is null/undefined.
 */
export function isPointInsideBounds(queryPt, boundingCone) {
    if (queryPt == null || boundingCone == null) {
        return false;
    }

    const { vertex, direction, height, radius } = boundingCone;
    const coneTip = new THREE.Vector3(vertex.x, vertex.y, vertex.z);
    // The axial projection below is only a distance when the axis has unit
    // length, so normalize defensively rather than trusting the caller.
    const coneAxis = new THREE.Vector3(
        direction.x,
        direction.y,
        direction.z
    ).normalize();

    // Offset of the query point from the cone tip.
    const offset = queryPt.clone().sub(coneTip);

    // Distance along the axis; points behind the tip or past the base are out.
    const axialDistance = offset.dot(coneAxis);
    if (axialDistance <= 0 || axialDistance >= height) {
        return false;
    }

    // The cone's radius grows linearly from 0 at the tip to `radius` at the base.
    const radiusAtDepth = (axialDistance / height) * radius;

    // Perpendicular distance of the point from the axis.
    const radialDistance = offset
        .sub(coneAxis.multiplyScalar(axialDistance))
        .length();

    return radialDistance < radiusAtDepth;
}

/**
 * Check whether a point falls inside the MRI model's bounding box.
 *
 * @param {THREE.Vector3} queryPt - point to test.
 * @param {{boundingBox: THREE.Box3}} mriModel - model exposing a `boundingBox`.
 * @returns {boolean} result of `boundingBox.containsPoint(queryPt)`.
 */
export function isPointInsideMRIVolume(queryPt, mriModel) {
    const { boundingBox } = mriModel;
    return boundingBox.containsPoint(queryPt);
}

/**
 * Build a vertex adjacency list from a flat triangle index array.
 *
 * Every consecutive triple of `faces` is one triangle; each of its three
 * vertices is linked to the other two.
 *
 * @param {ArrayLike<number>} faces - flat vertex indices, three per triangle.
 * @returns {Object<string, Set>} map from `parseInt(vertex)` to the set of
 *     neighbouring vertex indices (stored as they appear in `faces`).
 */
export function getAdjacencyListFromMesh(faces) {
    const neighbors = {};

    // Record that `vertex` touches `first` and `second`, creating its set lazily.
    const link = (vertex, first, second) => {
        const key = parseInt(vertex);
        if (!neighbors[key]) {
            neighbors[key] = new Set();
        }
        neighbors[key].add(first).add(second);
    };

    for (let base = 0; base < faces.length; base += 3) {
        const a = faces[base];
        const b = faces[base + 1];
        const c = faces[base + 2];

        link(a, b, c);
        link(b, c, a);
        link(c, a, b);
    }

    return neighbors;
}

/**
 * Intersect the raycaster's ray with a plane.
 *
 * @param {THREE.Raycaster} raycaster - source of the ray (`raycaster.ray`).
 * @param {THREE.Plane} clipPlane - plane to intersect.
 * @returns {THREE.Vector3} the target vector written by `ray.intersectPlane`.
 *     NOTE(review): the return value of `intersectPlane` is ignored, so a miss
 *     presumably yields an untouched (0, 0, 0) vector that callers cannot tell
 *     apart from a hit at the origin — confirm callers are fine with that.
 */
export function getRaycasterIntersectionWithPlane(raycaster, clipPlane) {
    const { ray } = raycaster;
    const target = new THREE.Vector3();

    ray.intersectPlane(clipPlane, target);

    return target;
}

/**
 * Raycast against each object individually and collect all hits.
 *
 * Each object's own `raycast(raycaster, hits)` is called (so only that object
 * is tested, not its children).
 *
 * @param {THREE.Raycaster} raycaster - configured raycaster.
 * @param {Iterable<THREE.Object3D>} objectArray - objects to test.
 * @returns {Array<object>} all intersections, nearest (smallest distance) first.
 */
export function getRaycasterIntersections(raycaster, objectArray) {
    const collected = [];

    for (const target of objectArray) {
        target.raycast(raycaster, collected);
    }

    // Nearest hit first; sorting an empty list is a no-op.
    return collected.sort((lhs, rhs) => lhs.distance - rhs.distance);
}

/**
 * Return the object owning the nearest raycast hit among `objectArray`.
 *
 * @param {THREE.Raycaster} raycaster - configured raycaster.
 * @param {Iterable<THREE.Object3D>} objectArray - objects to test individually.
 * @returns {THREE.Object3D|null} `object` of the closest intersection, or null
 *     when nothing was hit.
 */
export function getRaycasterFirstIntersection(raycaster, objectArray) {
    const hits = [];
    for (const candidate of objectArray) {
        candidate.raycast(raycaster, hits);
    }

    if (hits.length === 0) {
        return null;
    }

    hits.sort((a, b) => a.distance - b.distance);
    const [nearest] = hits;
    return nearest.object;
}

/**
 * Find a structure hit by the ray that lies on the back side of a clip plane.
 *
 * A hit qualifies when its point is on or behind the plane
 * (`distanceToPoint <= 0`) and the plane intersects the hit's
 * `StructureModel.boundingBox`. Hits are expected to carry a `StructureModel`
 * property — presumably added by custom raycast implementations; confirm.
 *
 * NOTE(review): hits are scanned from the FARTHEST to the nearest, so the
 * farthest qualifying structure is returned despite "First" in the name.
 * Confirm whether this is intended.
 *
 * @param {THREE.Raycaster} raycaster - configured raycaster.
 * @param {Iterable<THREE.Object3D>} objectArray - objects to test individually.
 * @param {THREE.Plane} clipPlane - clipping plane.
 * @returns {object|null} the qualifying hit's `StructureModel`, or null.
 */
export function getRaycasterFirstIntersectionBehindClipPlane(
    raycaster,
    objectArray,
    clipPlane
) {
    const hits = [];
    for (const object of objectArray) {
        object.raycast(raycaster, hits);
    }
    hits.sort((a, b) => a.distance - b.distance);

    for (let idx = hits.length - 1; idx >= 0; idx -= 1) {
        const hit = hits[idx];
        const isBehindPlane = clipPlane.distanceToPoint(hit.point) <= 0;

        if (
            isBehindPlane &&
            clipPlane.intersectsBox(hit.StructureModel.boundingBox)
        ) {
            return hit.StructureModel;
        }
    }

    return null;
}

/**
 * Project a point orthogonally onto a plane.
 *
 * @param {THREE.Vector3} point - point to project.
 * @param {THREE.Plane|null|undefined} plane - target plane.
 * @returns {THREE.Vector3} the projection, or a fresh (0, 0, 0) vector when no
 *     plane is given.
 */
export function getPointProjectedOntoPlane(point, plane) {
    const projection = new THREE.Vector3();
    if (!plane) {
        return projection;
    }

    plane.projectPoint(point, projection);
    return projection;
}

/**
 * Collect the points where a plane crosses the edges of a triangle mesh.
 *
 * Each consecutive triple of `faces` indexes one triangle; vertex `k` is read
 * from `vertices[3k .. 3k + 2]`. Every triangle edge (AB, BC, CA) is tested
 * against the plane via addPointOfIntersectionToList. Edges shared between
 * triangles are tested once per triangle, so shared crossing points may repeat.
 *
 * @param {THREE.Plane} mathPlane - cutting plane.
 * @param {ArrayLike<number>} faces - flat vertex indices, three per triangle.
 * @param {ArrayLike<number>} vertices - flat xyz vertex coordinates.
 * @returns {THREE.Vector3[]} crossing points, in edge-visit order.
 */
export function getIntersectionsBetweenPlaneAndMesh(
    mathPlane,
    faces,
    vertices
) {
    const crossings = [];

    // Build the vertex referenced by faces[faceSlot].
    const vertexAt = (faceSlot) => {
        const offset = Math.floor(faces[faceSlot] * 3);
        return new THREE.Vector3(
            vertices[offset],
            vertices[offset + 1],
            vertices[offset + 2]
        );
    };

    for (let tri = 0; tri < faces.length; tri += 3) {
        const a = vertexAt(tri);
        const b = vertexAt(tri + 1);
        const c = vertexAt(tri + 2);

        const edges = [
            new THREE.Line3(a, b),
            new THREE.Line3(b, c),
            new THREE.Line3(c, a),
        ];
        for (const edge of edges) {
            addPointOfIntersectionToList(edge, mathPlane, crossings);
        }
    }

    return crossings;
}

/**
 * Append the point where `line` crosses `plane` to `pointsOfIntersection`.
 *
 * A miss is detected from the return value of `plane.intersectLine`, which is
 * null/undefined when the segment does not reach the plane. The previous
 * approach compared the result against (0, 0, 0). That wrongly discarded
 * genuine intersections located exactly at the origin.
 *
 * @param {THREE.Line3} line - segment to test.
 * @param {THREE.Plane} plane - plane to intersect with.
 * @param {THREE.Vector3[]} pointsOfIntersection - output list; mutated in place.
 */
export function addPointOfIntersectionToList(
    line,
    plane,
    pointsOfIntersection
) {
    const target = new THREE.Vector3();
    const hit = plane.intersectLine(line, target);

    if (hit) {
        // Clone so the stored point never aliases a vector owned elsewhere.
        pointsOfIntersection.push(hit.clone());
    }
}

/**
 * Mean Euclidean distance from `queryPt` to the points of a point cloud.
 *
 * @param {THREE.Vector3[]|null|undefined} pointCloud - points exposing `distanceTo`.
 * @param {THREE.Vector3} queryPt - reference point.
 * @returns {number} the mean distance; 0 when the cloud is missing or empty
 *     (previously an empty array produced NaN from 0 / 0).
 */
export function getAverageDistanceBetweenQueryPointAndPointCloud(
    pointCloud,
    queryPt
) {
    if (!pointCloud || pointCloud.length === 0) {
        return 0;
    }

    const totalDistance = pointCloud.reduce(
        (sum, point) => sum + point.distanceTo(queryPt),
        0
    );

    return totalDistance / pointCloud.length;
}

/**
 * Create a cached coverage calculator for the viewer's lesion.
 *
 * The returned function takes these steps:
 * - Looks up the object named `volumeNames.MR_MARGIN_AI` in `viewer.scene`.
 * - Gathers every child of `viewer.RASframe` whose name starts with
 *   `volumeNames.MR_TARGETS`.
 * - Delegates to computeCancerLesionOverlap on a 40 x 40 ray grid.
 *
 * One cache object is shared by all invocations of the returned function, so
 * repeated calls only re-cast rays where tools changed.
 *
 * @param {object} viewer - exposes `scene` and `RASframe` three.js objects.
 * @returns {() => number} calculator; it returns 0 when the lesion is absent.
 */
export function getCancerLesionCoverage(viewer) {
    const cache = {};
    const subdivisions = 40;

    return () => {
        const cancerLesion = viewer.scene.getObjectByName(
            volumeNames.MR_MARGIN_AI
        );
        if (!cancerLesion) {
            return 0;
        }

        // Pad the lesion box by a small margin to avoid edge cases at its boundary.
        const paddedBounds = new THREE.Box3()
            .setFromObject(cancerLesion)
            .expandByScalar(0.1);

        const tools = viewer.RASframe.children.filter(
            (child) => child.name && child.name.startsWith(volumeNames.MR_TARGETS)
        );

        return computeCancerLesionOverlap(
            cancerLesion,
            paddedBounds,
            tools,
            subdivisions,
            cache
        );
    };
}

/**
 * Estimate the fraction of a cancer lesion covered by a set of tool volumes.
 *
 * How it works:
 * - A `subdivisions` x `subdivisions` grid of rays is cast along -z from the
 *   max-z face of `cancerLesionBounds`.
 * - Each ray measures how much of its path lies inside the lesion, and how
 *   much lies inside the lesion but outside every tool.
 * - Per-cell results are kept in `cache`. Later calls for the same lesion
 *   (same `uuid`) only re-cast the cells whose tools appeared, moved or
 *   disappeared.
 *
 * @param {THREE.Object3D} cancerLesion - lesion object; the cache is keyed on its uuid.
 * @param {THREE.Box3} cancerLesionBounds - bounds that define the ray grid.
 * @param {THREE.Object3D[]} tools - tool/target objects covering the lesion.
 * @param {number} subdivisions - grid cells per axis.
 * @param {object} [cache={}] - mutable cache reused across calls (see initializeCache).
 * @param {boolean} [useDefaultRaycast=false] - use three.js's stock Mesh raycast
 *     instead of the BVH-accelerated one, for the duration of this call only.
 * @returns {number} 1 - (uncovered lesion ray length / total lesion ray length).
 *     Returns 0 when there are no tools or no ray crossed the lesion.
 */
export function computeCancerLesionOverlap(
    cancerLesion,
    cancerLesionBounds,
    tools,
    subdivisions,
    cache = {},
    useDefaultRaycast = false
) {
    if (tools.length === 0) return 0;

    // The raycast override is a global prototype patch. Restore it afterwards so
    // a single call with useDefaultRaycast=true does not silently disable BVH
    // acceleration for every later raycast in the application.
    const previousRaycast = THREE.Mesh.prototype.raycast;
    if (useDefaultRaycast) THREE.Mesh.prototype.raycast = defaultRaycast;

    try {
        const isCacheValid = cache["lesionUUID"] === cancerLesion.uuid;
        if (!isCacheValid) initializeCache(cache, cancerLesion.uuid);

        const emptyRaycasts = cache["emptyRaycasts"];
        const cellCoverages = cache["exclusivelyIntraLesionCellCoverages"];
        const intersectionScene = [cancerLesion, ...tools];

        // Either re-cast only cells touched by changed tools, or the whole grid.
        const testPoints = [];
        if (isCacheValid) {
            _pushDirtyPoints(
                tools,
                cache["bboxes"],
                cancerLesionBounds,
                subdivisions,
                emptyRaycasts,
                testPoints
            );
        } else {
            _pushAllPoints(subdivisions, cancerLesionBounds, testPoints);
        }

        // Raycaster.set copies origin and direction, so one instance serves all cells.
        const raycaster = new THREE.Raycaster();
        const downward = new THREE.Vector3(0, 0, -1);

        for (const point of testPoints) {
            raycaster.set(point, downward);
            const intersects = removeDuplicateIntersections(
                raycaster.intersectObjects(intersectionScene)
            );

            const [cellTotalLesionSize, cellExclusivelyIntraLesionSize] =
                _measureCoverageForCell(intersects, cancerLesion, tools);
            const pointKey = _vectorToString(point);

            // Total lesion size only needs to be computed once, on the full-grid pass.
            if (!isCacheValid) {
                cache["totalLesionSize"] += cellTotalLesionSize;
            }

            cellCoverages.set(pointKey, cellExclusivelyIntraLesionSize);

            // Remember rays that miss the lesion so they are never re-cast.
            if (cellTotalLesionSize === 0) {
                emptyRaycasts.add(pointKey);
            }
        }

        let exclusivelyIntraLesionSize = 0;
        for (const size of cellCoverages.values()) {
            exclusivelyIntraLesionSize += size;
        }

        const totalLesionSize = cache["totalLesionSize"];
        // No ray crossed the lesion: avoid returning NaN from 0 / 0.
        if (totalLesionSize === 0) return 0;

        return 1 - exclusivelyIntraLesionSize / totalLesionSize;
    } finally {
        THREE.Mesh.prototype.raycast = previousRaycast;
    }
}

/**
 * Append every grid ray origin (subdivisions x subdivisions) to `testPoints`.
 *
 * @param {number} subdivisions - grid cells per axis.
 * @param {THREE.Box3} cancerLesionBounds - bounds defining the grid.
 * @param {THREE.Vector3[]} testPoints - output list; mutated in place.
 */
function _pushAllPoints(subdivisions, cancerLesionBounds, testPoints) {
    for (let row = 0; row < subdivisions; row += 1) {
        for (let col = 0; col < subdivisions; col += 1) {
            testPoints.push(
                _getPoint(cancerLesionBounds, subdivisions, row, col)
            );
        }
    }
}

/**
 * Collect the grid ray origins that must be re-cast because the tool set changed.
 *
 * A grid point is dirty when it lies under the bounding box of a tool that:
 * - is new,
 * - moved or resized (both its old and new boxes count), or
 * - has been removed since the previous call.
 *
 * Points listed in `emptyRaycasts` (rays that never hit the lesion) are skipped.
 * `bboxes` is updated in place so that it mirrors exactly the current tool set.
 *
 * @param {THREE.Object3D[]} tools - current tools.
 * @param {Map<string, THREE.Box3>} bboxes - cached tool bounds keyed by tool uuid.
 * @param {THREE.Box3} cancerLesionBounds - bounds defining the ray grid.
 * @param {number} subdivisions - grid cells per axis.
 * @param {Set<string>} emptyRaycasts - stringified points to ignore.
 * @param {THREE.Vector3[]} testPoints - output list; mutated in place.
 */
function _pushDirtyPoints(
    tools,
    bboxes,
    cancerLesionBounds,
    subdivisions,
    emptyRaycasts,
    testPoints
) {
    const dirtyPointKeys = new Set();
    const liveToolIds = new Set();

    // Mark every grid point under `bounds` as needing a re-cast.
    const markDirty = (bounds) =>
        addTestPointsToSet(
            bounds,
            cancerLesionBounds,
            subdivisions,
            emptyRaycasts,
            dirtyPointKeys
        );

    for (const tool of tools) {
        liveToolIds.add(tool.uuid);
        const toolBounds = _getToolBounds(tool);

        if (!bboxes.has(tool.uuid)) {
            markDirty(toolBounds);
            bboxes.set(tool.uuid, toolBounds);
            continue;
        }

        const cachedToolBounds = bboxes.get(tool.uuid);
        if (
            _getBoundingBoxAsString(toolBounds) !==
            _getBoundingBoxAsString(cachedToolBounds)
        ) {
            // Cells under both the old and new footprint may have changed.
            markDirty(cachedToolBounds);
            markDirty(toolBounds);
            bboxes.set(tool.uuid, toolBounds);
        }
    }

    // Tools that disappeared: rescan their old footprint once, then forget
    // them. Keeping the stale entry would rescan it on every call. Worse, a tool
    // that reappeared with the same bounds would then be treated as unchanged,
    // and its cells would keep the stale "uncovered" result.
    // (Deleting from a Map while iterating it is well-defined in JS.)
    for (const [uuid, staleBounds] of bboxes) {
        if (!liveToolIds.has(uuid)) {
            markDirty(staleBounds);
            bboxes.delete(uuid);
        }
    }

    // Convert the stringified points back to three.js vectors.
    dirtyPointKeys.forEach((key) => testPoints.push(_stringToVector(key)));
}

/**
 * Reset `cache` in place for a new lesion.
 *
 * @param {object} cache - cache object shared across coverage computations.
 * @param {string} uuid - uuid of the lesion the cache now belongs to.
 */
function initializeCache(cache, uuid) {
    Object.assign(cache, {
        // Stringified grid points whose rays never hit the lesion.
        emptyRaycasts: new Set(),
        // Last seen tool bounding boxes, keyed by tool uuid.
        bboxes: new Map(),
        lesionUUID: uuid,
        // Per-cell length of lesion path not covered by any tool.
        exclusivelyIntraLesionCellCoverages: new Map(),
        totalLesionSize: 0,
    });
}

/**
 * Serialize a vector as its comma-joined components (e.g. "1,2.5,-3"); used
 * as a Set/Map key.
 *
 * @param {{toArray: () => number[]}} vector - vector to serialize.
 * @returns {string} comma-separated components.
 */
function _vectorToString(vector) {
    const components = vector.toArray();
    return components.join(",");
}

/**
 * Parse a "x,y,z" string produced by _vectorToString back into a Vector3.
 *
 * @param {string} string - comma-separated components.
 * @returns {THREE.Vector3} parsed vector (missing components parse to NaN).
 */
function _stringToVector(string) {
    const [xText, yText, zText] = string.split(",");
    return new THREE.Vector3(
        parseFloat(xText),
        parseFloat(yText),
        parseFloat(zText)
    );
}

/**
 * Ray origin for grid cell (i, j) on the top (max-z) face of the bounds.
 *
 * @param {THREE.Box3} cancerLesionBounds - bounds defining the grid.
 * @param {number} subdivisions - grid cells per axis.
 * @param {number} i - cell index along x.
 * @param {number} j - cell index along y.
 * @returns {THREE.Vector3} (min.x + i * stepX, min.y + j * stepY, max.z).
 */
function _getPoint(cancerLesionBounds, subdivisions, i, j) {
    const { min, max } = cancerLesionBounds;
    const stepX = (max.x - min.x) / subdivisions;
    const stepY = (max.y - min.y) / subdivisions;

    return new THREE.Vector3(min.x + stepX * i, min.y + stepY * j, max.z);
}

/**
 * Add the stringified grid points that fall strictly inside `bounds` (in x/y)
 * to `pointSet`, skipping any listed in `ignoreList`.
 *
 * @param {THREE.Box3} bounds - footprint whose grid points are wanted.
 * @param {THREE.Box3} outerBounds - bounds defining the whole grid.
 * @param {number} subdivisions - grid cells per axis.
 * @param {Set<string>} ignoreList - stringified points to skip.
 * @param {Set<string>} pointSet - output set; mutated in place.
 */
function addTestPointsToSet(
    bounds,
    outerBounds,
    subdivisions,
    ignoreList,
    pointSet
) {
    // Strict inequalities: grid points exactly on the footprint edge are excluded.
    const isStrictlyInside = (candidate) =>
        candidate.x > bounds.min.x &&
        candidate.x < bounds.max.x &&
        candidate.y > bounds.min.y &&
        candidate.y < bounds.max.y;

    for (let i = 0; i < subdivisions; i++) {
        for (let j = 0; j < subdivisions; j++) {
            const candidate = _getPoint(outerBounds, subdivisions, i, j);
            if (!isStrictlyInside(candidate)) {
                continue;
            }

            const key = _vectorToString(candidate);
            if (!ignoreList.has(key)) {
                pointSet.add(key);
            }
        }
    }
}

/**
 * World-space bounding box of the first direct Mesh child of a tool.
 *
 * Assumes the tool has at least one child of type "Mesh"; otherwise
 * `setFromObject(undefined)` is presumably going to throw — confirm callers
 * guarantee this.
 *
 * @param {THREE.Object3D} object - tool object.
 * @returns {THREE.Box3} bounds of its first Mesh child.
 */
function _getToolBounds(object) {
    const firstMesh = object.children.find((child) => child.type === "Mesh");
    return new THREE.Box3().setFromObject(firstMesh);
}

/**
 * Serialize a box as its stringified min corner followed directly by its max
 * corner (no separator between them); used to detect tool movement.
 *
 * @param {THREE.Box3} bbox - box to serialize.
 * @returns {string} e.g. "0,1,23,4,5" for min (0,1,2) and max (3,4,5).
 */
function _getBoundingBoxAsString(bbox) {
    const { min, max } = bbox;
    return [min, max].map((corner) => corner.toArray().toString()).join("");
}

/**
 * Measure, along one ray, the lesion path length and the part of it outside
 * every tool.
 *
 * Each hit toggles "inside" state for the object that owns the hit face (the
 * face mesh's parent). While inside the lesion and inside no tool, the gap to
 * the next hit counts as uncovered lesion. `intersects` is expected to be
 * ordered by increasing distance — TODO confirm for all callers.
 *
 * @param {Array<{distance: number, object: THREE.Object3D}>} intersects - ray hits.
 * @param {THREE.Object3D} cancerLesion - lesion (matched by uuid of the hit's parent).
 * @param {THREE.Object3D[]} tools - tools (matched by identity of the hit's parent).
 * @returns {[number, number]} [total lesion length, lesion length not inside any tool].
 */
function _measureCoverageForCell(intersects, cancerLesion, tools) {
    let lesionSize = 0;
    let exclusivelyIntraLesionSize = 0;

    let inLesion = false;
    let lesionEntryDistance;
    const openTargets = new Set();

    for (let index = 0; index < intersects.length; index++) {
        const hit = intersects[index];
        // Hits land on face meshes; the owning volume is their parent object.
        const owner = hit.object.parent;

        if (owner.uuid === cancerLesion.uuid) {
            inLesion = !inLesion;
            if (inLesion) {
                lesionEntryDistance = hit.distance;
            } else {
                lesionSize += Math.abs(hit.distance - lesionEntryDistance);
            }
        }

        if (tools.includes(owner)) {
            if (openTargets.has(owner)) {
                openTargets.delete(owner);
            } else {
                openTargets.add(owner);
            }
        }

        const next = intersects[index + 1];
        // A missing `next` means an unmatched hit (e.g. a duplicate missed
        // because of floating-point rounding); there is no segment to add.
        if (inLesion && openTargets.size === 0 && next) {
            exclusivelyIntraLesionSize += Math.abs(next.distance - hit.distance);
        }
    }

    return [lesionSize, exclusivelyIntraLesionSize];
}

/**
 * Drop repeated hits of the same mesh at the same point.
 *
 * A ray crossing the shared edge of two faces reports one hit per face. Hits
 * are kept in order; a later hit is dropped when its mesh (by `object.uuid`)
 * was already hit at an identical point (compared by stringified coordinates).
 *
 * @param {Array<{object: THREE.Object3D, point: THREE.Vector3}>} intersects - raw hits.
 * @returns {Array} a new array without the duplicates.
 */
export function removeDuplicateIntersections(intersects) {
    const seenPointsByMesh = {};

    return intersects.filter(({ object, point }) => {
        const pointKey = point.toArray().toString();
        let seenPoints = seenPointsByMesh[object.uuid];

        if (!seenPoints) {
            seenPoints = new Set();
            seenPointsByMesh[object.uuid] = seenPoints;
        } else if (seenPoints.has(pointKey)) {
            return false;
        }

        seenPoints.add(pointKey);
        return true;
    });
}

/**
 * Re-express the camera's sagittal, coronal and axial axes through a transform.
 *
 * The axes are direction vectors, so only the linear (upper-left 3x3) part of
 * `worldTransform` may act on them. `transformDirection` applies exactly that
 * and then normalizes. The previous `applyMatrix4` also added the matrix's
 * translation before normalizing, which skews the axes whenever the
 * transform is not at the origin.
 *
 * @param {THREE.Matrix4} worldTransform - transform to apply.
 * @param {{sagAxis: THREE.Vector3, corAxis: THREE.Vector3, axialAxis: THREE.Vector3}} camAxes
 *     input axes; not modified (each is cloned).
 * @returns {{sagAxis: THREE.Vector3, corAxis: THREE.Vector3, axialAxis: THREE.Vector3}}
 *     transformed unit axes.
 */
export function convertCamAxesFromWorldToLocal(worldTransform, camAxes) {
    const toLocal = (axis) => axis.clone().transformDirection(worldTransform);

    return {
        sagAxis: toLocal(camAxes.sagAxis),
        corAxis: toLocal(camAxes.corAxis),
        axialAxis: toLocal(camAxes.axialAxis),
    };
}

/**
 * Rotation taking the +Y axis onto the direction from the cone tip to `intersect`.
 *
 * @param {{x: number, y: number, z: number}} intersect - target point.
 * @param {{vertex: {x: number, y: number, z: number}}} boundingCone - cone whose
 *     `vertex` is the rotation's reference origin.
 * @returns {THREE.Quaternion} quaternion from setFromUnitVectors((0,1,0), dir).
 */
export function getQuaternionAlignedWithBounds(intersect, boundingCone) {
    const { vertex } = boundingCone;

    const targetDirection = new THREE.Vector3(
        intersect.x - vertex.x,
        intersect.y - vertex.y,
        intersect.z - vertex.z
    ).normalize();

    const up = new THREE.Vector3(0, 1, 0);
    return new THREE.Quaternion().setFromUnitVectors(up, targetDirection);
}

/**
 * Copy a position and shift one of its components according to the view
 * orientation and which in-plane direction is being adjusted.
 *
 * | orientation | VERTICAL plane | other plane |
 * |-------------|----------------|-------------|
 * | SAGITTAL    | z += d         | y -= d      |
 * | CORONAL     | z += d         | x -= d      |
 * | AXIAL       | y -= d         | x -= d      |
 *
 * Any other orientation returns an unchanged copy.
 *
 * @param {THREE.Vector3} currentPosition - position to start from (not modified).
 * @param {number} adjustmentDistance - amount `d` to shift by.
 * @param {*} orientation - one of `sceneOrientations`.
 * @param {*} plane - `PLANE.VERTICAL` or another PLANE value.
 * @returns {THREE.Vector3} the shifted copy.
 */
export function getPositionTranslatedInPlane(
    currentPosition,
    adjustmentDistance,
    orientation,
    plane
) {
    const updatedPosition = new THREE.Vector3();
    updatedPosition.copy(currentPosition);

    const isVertical = plane === PLANE.VERTICAL;
    const increase = (axis) => {
        updatedPosition[axis] = currentPosition[axis] + adjustmentDistance;
    };
    const decrease = (axis) => {
        updatedPosition[axis] = currentPosition[axis] - adjustmentDistance;
    };

    switch (orientation) {
        case sceneOrientations.SAGITTAL:
            if (isVertical) increase("z");
            else decrease("y");
            break;

        case sceneOrientations.CORONAL:
            if (isVertical) increase("z");
            else decrease("x");
            break;

        case sceneOrientations.AXIAL:
            if (isVertical) decrease("y");
            else decrease("x");
            break;

        default:
            break;
    }

    return updatedPosition;
}

/**
 * Sign an S3 URI through the backend, then download and parse the mesh.
 *
 * Previously the promise chain was neither returned nor awaited. The function
 * resolved to `undefined` immediately, the geometry was thrown away, and any
 * signing failure became an unhandled rejection. The chain is now awaited and
 * the geometry returned.
 *
 * @param {object} input - request fields forwarded to getPHISignedURLS3.
 * @param {string|string[]} uri - URI(s) to sign; the signed URL chosen by
 *     getSignedUrl is the one loaded.
 * @returns {Promise<*>} the geometry produced by the STL/VTK loader.
 */
export async function loadMeshFromURI(input, uri) {
    const response = await getPHISignedURLS3({
        ...input,
        uris: toArray(uri),
    });
    const json = await response.json();
    const url = getSignedUrl(json.payload.signedurls);

    return new Promise((resolve) => {
        fileLoadHelper(url, resolve);
    });
}

/**
 * Download and parse a mesh (STL or VTK, chosen by fileLoadHelper).
 *
 * The previous version awaited the geometry but never returned it, so
 * callers always received `undefined`.
 *
 * @param {string} url - mesh URL.
 * @returns {Promise<*>} the geometry produced by the loader.
 */
export async function loadMeshFromURL(url) {
    return new Promise((resolve) => {
        fileLoadHelper(url, resolve);
    });
}

/**
 * Load several meshes concurrently.
 *
 * Fixes relative to the counter-based version:
 * - An empty `urlObj` used to leave the promise pending forever, because the
 *   completion count could never reach 0. It now resolves to `{}`.
 * - A null/undefined `urlObj` used to throw before the guard; it also
 *   resolves to `{}`.
 *
 * @param {Object<string, string>|null|undefined} urlObj - id -> mesh URL.
 * @returns {Promise<Object<string, *>>} id -> loaded geometry, for every own
 *     enumerable key of `urlObj`.
 */
export async function multiUrlMeshLoad(urlObj) {
    const entries = urlObj ? Object.entries(urlObj) : [];

    const loaded = await Promise.all(
        entries.map(
            ([key, url]) =>
                new Promise((resolve) => {
                    fileLoadHelper(url, (geometry) => resolve([key, geometry]));
                })
        )
    );

    const geometries = {};
    for (const [key, geometry] of loaded) {
        geometries[key] = geometry;
    }
    return geometries;
}

/**
 * Load the mesh geometries of every volume in the default ablation profile.
 *
 * The per-volume URI lookups are independent. They now run concurrently, each
 * with its own request object, instead of being awaited one at a time while a
 * single shared `input` object was mutated between calls. Volumes whose lookup
 * returns a falsy response are skipped.
 *
 * @param {string} authToken - token forwarded to the backend helpers.
 * @returns {Promise<Object<string, *>>} volume uuid -> loaded geometry.
 */
export async function getDefaultProfileVolumes(authToken) {
    const input = {
        authToken: authToken,
        userUuid: AVENDA_DEFAULT_ABLATION_PROFILE_USER_UUID,
        ablationProfileId: AVENDA_DEFAULT_ABLATION_PROFILE_UUID,
    };

    const defaultIDList = await getAblationProfileWithVolumes(input);
    // Array.from keeps support for any iterable, as the original for...of did.
    const volumeIds = Array.from(defaultIDList.volumeUUIDs);

    const responses = await Promise.all(
        volumeIds.map((volID) =>
            getAblationVolumeMeshURI({ ...input, volumeUuid: volID })
        )
    );

    const urlObj = {};
    volumeIds.forEach((volID, index) => {
        const response = responses[index];
        if (response) {
            urlObj[String(volID)] = response.downloadUri;
        }
    });

    return multiUrlMeshLoad(urlObj);
}

/**
 * Load a mesh with STLLoader (".stl" files) or VTKLoader (everything else) and
 * hand the parsed geometry to `callbackFunc`.
 *
 * Fixes:
 * - The extension is checked on the URL path only. Signed URLs carry a query
 *   string (e.g. "...mesh.stl?X-Amz-Signature=..."), which made the old
 *   /\.stl$/ test fail and sent STL files to VTKLoader.
 * - Load failures are logged instead of vanishing silently.
 * - `id` is forwarded whenever it is provided, including falsy ids such as "".
 *
 * @param {string} url - mesh URL.
 * @param {Function} callbackFunc - called as (geometry) or (geometry, id).
 * @param {*} [id] - optional id passed back to the callback.
 */
async function fileLoadHelper(url, callbackFunc, id = undefined) {
    // Strip any query string or fragment before looking at the extension.
    const path = url.split(/[?#]/)[0];
    const fileLoader = /\.stl$/i.test(path)
        ? new STLLoader()
        : new VTKLoader();

    function loaded(loadedGeometry) {
        if (id !== undefined) {
            callbackFunc(loadedGeometry, id);
        } else {
            callbackFunc(loadedGeometry);
        }
    }

    function failed(error) {
        // Log the path only: the full signed URL would leak its credentials.
        console.error(`Failed to load mesh from ${path}`, error);
    }

    fileLoader.load(url, loaded, undefined, failed);
}

// Change-of-basis matrix from LPS to RAS coordinates. Its basis columns
// negate x and z and keep y.
// NOTE(review): a textbook LPS->RAS conversion negates x and y; this one
// negates x and z, presumably to match this scene's axis convention — confirm.
export const LPS2RAS = new THREE.Matrix4().makeBasis(
    new THREE.Vector3(-1, 0, 0),
    new THREE.Vector3(0, 1, 0),
    new THREE.Vector3(0, 0, -1)
);
// Matrix4.invert() works in place and returns `this`. Inverting LPS2RAS
// directly would make both exports the same object, so mutating one would
// silently change the other. Invert a clone instead.
export const RAS2LPS = LPS2RAS.clone().invert();
