import { useRef, useMemo, useEffect } from 'react'
import { useFrame } from '@react-three/fiber'
import * as THREE from 'three'

// SparkleParticles: a cloud of twinkling, star-shaped points drawn with a
// custom THREE.ShaderMaterial using additive blending.

export default function SparkleParticles({ 
    count = 500,
    position = [0, 0, 0],     // [x, y, z] position in scene
    spread = [2, 2, 2],       // How far particles spread from center [x, y, z]
    rotationSpeed = 1,        // Speed of overall rotation (default: 1)
    twinkleSpeed = 1,         // Speed of twinkling animation (default: 1)
    baseOpacity = 0.6,        // Base opacity of particles (0-1)
    opacityRange = 0.8,       // How much opacity varies (0-1)
    noiseSpeed = 1,           // Speed of noise animation (default: 1)
    noiseIntensity = 0.5,     // Intensity of noise effect (0-1)
    starScale = 0.15,         // Base size of star points (default: 0.15)
    pulseSpeed = 1,           // Speed of star size animation (default: 1)
    pointCount = 12,          // Number of star points (default: 12)
    pulseMin = 0.1,           // Minimum pulse size (default: 0.1)
    pulseMax = 2.0,           // Maximum pulse size (default: 2.0)
    pointLength = 0.7,        // Length of star points (default: 0.7)
    disabled = false          // Option to pause animation
}) {
    // Previous setup code remains the same...
    const points = useRef()
    
    const particlesPosition = useMemo(() => {
        const positions = new Float32Array(count * 3)
        for (let i = 0; i < count; i++) {
            positions[i * 3] = (Math.random() - 0.5) * spread[0]
            positions[i * 3 + 1] = (Math.random() - 0.5) * spread[1]
            positions[i * 3 + 2] = (Math.random() - 0.5) * spread[2]
        }
        return positions
    }, [count, spread])

    const phases = useMemo(() => {
        return new Float32Array(count).map(() => Math.random() * Math.PI * 2)
    }, [count])

    const speeds = useMemo(() => {
        return new Float32Array(count).map(() => Math.random() * 0.5 + 0.75)
    }, [count])

    const shaderMaterial = useMemo(() => {
        return new THREE.ShaderMaterial({
            vertexShader: `
                attribute float speed;
                attribute float phase;
                varying float vSpeed;
                varying float vPhase;
                varying vec3 vPosition;
                uniform float uTwinkleSpeed;
                
                void main() {
                    vPosition = position;
                    vSpeed = speed * uTwinkleSpeed;
                    vPhase = phase;
                    vec4 mvPosition = modelViewMatrix * vec4(position, 1.0);
                    gl_PointSize = 12.0 * (300.0 / -mvPosition.z);
                    gl_Position = projectionMatrix * mvPosition;
                }
            `,
            fragmentShader: `
                varying float vSpeed;
                varying float vPhase;
                varying vec3 vPosition;
                uniform float uTime;
                uniform float uBaseOpacity;
                uniform float uOpacityRange;
                uniform float uNoiseSpeed;
                uniform float uNoiseIntensity;
                uniform float uStarScale;
                uniform float uPulseSpeed;
                uniform float uPointCount;
                uniform float uPulseMin;
                uniform float uPulseMax;
                uniform float uPointLength;

                // Noise functions
                vec3 mod289(vec3 x) {
                    return x - floor(x * (1.0 / 289.0)) * 289.0;
                }
                
                vec4 mod289(vec4 x) {
                    return x - floor(x * (1.0 / 289.0)) * 289.0;
                }
                
                vec4 permute(vec4 x) {
                    return mod289(((x * 34.0) + 1.0) * x);
                }
                
                vec4 taylorInvSqrt(vec4 r) {
                    return 1.79284291400159 - 0.85373472095314 * r;
                }
                
                vec3 fade(vec3 t) {
                    return t * t * t * (t * (t * 6.0 - 15.0) + 10.0);
                }
                
                float noise(vec3 P) {
                    vec3 Pi0 = floor(P);
                    vec3 Pi1 = Pi0 + vec3(1.0);
                    Pi0 = mod289(Pi0);
                    Pi1 = mod289(Pi1);
                    vec3 Pf0 = fract(P);
                    vec3 Pf1 = Pf0 - vec3(1.0);
                    vec4 ix = vec4(Pi0.x, Pi1.x, Pi0.x, Pi1.x);
                    vec4 iy = vec4(Pi0.yy, Pi1.yy);
                    vec4 iz0 = Pi0.zzzz;
                    vec4 iz1 = Pi1.zzzz;
                
                    vec4 ixy = permute(permute(ix) + iy);
                    vec4 ixy0 = permute(ixy + iz0);
                    vec4 ixy1 = permute(ixy + iz1);
                
                    vec4 gx0 = ixy0 * (1.0 / 7.0);
                    vec4 gy0 = fract(floor(gx0) * (1.0 / 7.0)) - 0.5;
                    gx0 = fract(gx0);
                    vec4 gz0 = vec4(0.5) - abs(gx0) - abs(gy0);
                    vec4 sz0 = step(gz0, vec4(0.0));
                    gx0 -= sz0 * (step(0.0, gx0) - 0.5);
                    gy0 -= sz0 * (step(0.0, gy0) - 0.5);
                
                    vec4 gx1 = ixy1 * (1.0 / 7.0);
                    vec4 gy1 = fract(floor(gx1) * (1.0 / 7.0)) - 0.5;
                    gx1 = fract(gx1);
                    vec4 gz1 = vec4(0.5) - abs(gx1) - abs(gy1);
                    vec4 sz1 = step(gz1, vec4(0.0));
                    gx1 -= sz1 * (step(0.0, gx1) - 0.5);
                    gy1 -= sz1 * (step(0.0, gy1) - 0.5);
                
                    vec3 g000 = vec3(gx0.x, gy0.x, gz0.x);
                    vec3 g100 = vec3(gx0.y, gy0.y, gz0.y);
                    vec3 g010 = vec3(gx0.z, gy0.z, gz0.z);
                    vec3 g110 = vec3(gx0.w, gy0.w, gz0.w);
                    vec3 g001 = vec3(gx1.x, gy1.x, gz1.x);
                    vec3 g101 = vec3(gx1.y, gy1.y, gz1.y);
                    vec3 g011 = vec3(gx1.z, gy1.z, gz1.z);
                    vec3 g111 = vec3(gx1.w, gy1.w, gz1.w);
                
                    vec4 norm0 = taylorInvSqrt(vec4(dot(g000, g000), dot(g010, g010), dot(g100, g100), dot(g110, g110)));
                    g000 *= norm0.x;
                    g010 *= norm0.y;
                    g100 *= norm0.z;
                    g110 *= norm0.w;
                    vec4 norm1 = taylorInvSqrt(vec4(dot(g001, g001), dot(g011, g011), dot(g101, g101), dot(g111, g111)));
                    g001 *= norm1.x;
                    g011 *= norm1.y;
                    g101 *= norm1.z;
                    g111 *= norm1.w;
                
                    float n000 = dot(g000, Pf0);
                    float n100 = dot(g100, vec3(Pf1.x, Pf0.yz));
                    float n010 = dot(g010, vec3(Pf0.x, Pf1.y, Pf0.z));
                    float n110 = dot(g110, vec3(Pf1.xy, Pf0.z));
                    float n001 = dot(g001, vec3(Pf0.xy, Pf1.z));
                    float n101 = dot(g101, vec3(Pf1.x, Pf0.y, Pf1.z));
                    float n011 = dot(g011, vec3(Pf0.x, Pf1.yz));
                    float n111 = dot(g111, Pf1);
                
                    vec3 fade_xyz = fade(Pf0);
                    vec4 n_z = mix(vec4(n000, n100, n010, n110), vec4(n001, n101, n011, n111), fade_xyz.z);
                    vec2 n_yz = mix(n_z.xy, n_z.zw, fade_xyz.y);
                    float n_xyz = mix(n_yz.x, n_yz.y, fade_xyz.x);
                    return 2.2 * n_xyz;
                }

                // Enhanced star shape function for longer, thinner points
                float starShape(vec2 p, float size) {
                    float angle = atan(p.y, p.x);
                    float r = length(p);
                    
                    // Create very thin primary points
                    float primaryPoints = abs(cos(angle * uPointCount * 0.5));
                    primaryPoints = pow(primaryPoints, 4.0); // Much sharper falloff
                    
                    // Create thin secondary points
                    float secondaryPoints = abs(cos(angle * uPointCount * 0.5 + 3.14159 / uPointCount));
                    secondaryPoints = pow(secondaryPoints, 4.0);
                    
                    // Use max instead of mix for sharper transitions
                    float star = max(primaryPoints, secondaryPoints * 0.8);
                    
                    // Apply length control with sharp falloff
                    star = pow(star, uPointLength);
                    
                    return star * size;
                }

                // Core glow function
                float coreGlow(float dist, float intensity) {
                    // Very tight center glow
                    return pow(1.0 - smoothstep(0.0, 0.1, dist), 1.0) * intensity;
                }
                
                void main() {
                    // Enhanced pulse effect
                    float starTime = uTime * uPulseSpeed + vPhase;
                    float pulse = sin(starTime) * 0.5 + 0.5;
                    float sizePulse = mix(uPulseMin, uPulseMax, pulse);
                    float opacityPulse = pulse * pulse;
                    
                    // Base twinkle effect
                    float twinkle = sin(vSpeed * uTime + vPhase) * 0.5 + 0.5;
                    twinkle = pow(twinkle, 2.0);
                    
                    vec2 center = gl_PointCoord - vec2(0.5);
                    float dist = length(center);
                    
                    // Create star shape with minimal glow
                    float starSize = uStarScale * sizePulse;
                    float star = starShape(center, starSize);
                    
                    // Combine star shape with tiny core glow
                    float sparkle = 1.0 - smoothstep(0.0, 0.05, dist - star * 0.15); // Tighter core
                    sparkle = pow(sparkle, 3.0); // Much sharper falloff
                    
                    // Add very small center glow
                    float core = coreGlow(dist, 0.3) * opacityPulse;
                    
                    // Add minimal noise only to brightness, not shape
                    vec3 noiseCoord = vec3(
                        gl_PointCoord * 8.0,
                        uTime * uNoiseSpeed
                    );
                    float noiseVal = noise(noiseCoord);
                    noiseVal = pow(noiseVal * 0.5 + 0.5, 2.0);
                    
                    // Combine effects with minimal glow
                    float combinedEffect = max(sparkle, core) * (twinkle + noiseVal * uNoiseIntensity * 0.3);
                    combinedEffect *= opacityPulse;
                    
                    // Calculate final opacity with very sharp falloff
                    float opacity = uBaseOpacity + (combinedEffect * uOpacityRange);
                    opacity *= opacityPulse;
                    
                    // Color with minimal spread
                    vec3 color = mix(
                        vec3(0.8, 0.6, 0.3),  // warm gold
                        vec3(0.8, 0.8, 0.4),   // bright yellow
                        combinedEffect
                    );
                    
                    // Add sharp highlights at point tips only
                    float tipHighlight = pow(star, 8.0) * sparkle;
                    color += vec3(0.3) * tipHighlight;
                    
                    gl_FragColor = vec4(color, sparkle * opacity);
                }
            `.replace('${previousNoiseCode}', /* Previous noise functions */),
            transparent: true,
            blending: THREE.AdditiveBlending,
            depthWrite: false,
            uniforms: {
                uTime: { value: 0 },
                uTwinkleSpeed: { value: twinkleSpeed },
                uBaseOpacity: { value: baseOpacity },
                uOpacityRange: { value: opacityRange },
                uNoiseSpeed: { value: noiseSpeed },
                uNoiseIntensity: { value: noiseIntensity },
                uStarScale: { value: starScale },
                uPulseSpeed: { value: pulseSpeed },
                uPointCount: { value: pointCount },
                uPulseMin: { value: pulseMin },
                uPulseMax: { value: pulseMax },
                uPointLength: { value: pointLength }
            }
        })
    }, [twinkleSpeed, baseOpacity, opacityRange, noiseSpeed, noiseIntensity, 
        starScale, pulseSpeed, pointCount, pulseMin, pulseMax, pointLength])

    useFrame((state, delta) => {
        if (!disabled) {
            points.current.rotation.x += delta * 0.05 * rotationSpeed
            points.current.rotation.y += delta * 0.075 * rotationSpeed
            shaderMaterial.uniforms.uTime.value = state.clock.elapsedTime
        }
    })

    return (
        <group position={position}>
            <points ref={points}>
                <bufferGeometry>
                    <bufferAttribute
                        attach="attributes-position"
                        count={count}
                        array={particlesPosition}
                        itemSize={3}
                    />
                    <bufferAttribute
                        attach="attributes-speed"
                        count={count}
                        array={speeds}
                        itemSize={1}
                    />
                    <bufferAttribute
                        attach="attributes-phase"
                        count={count}
                        array={phases}
                        itemSize={1}
                    />
                </bufferGeometry>
                <primitive object={shaderMaterial} attach="material" />
            </points>
        </group>
    )
}
