import {
  Vector3,
  DataTexture,
  RGBAFormat,
  FloatType,
  ShaderMaterial,
  Texture,
} from "three";
import { extend, MaterialNode } from "@react-three/fiber";

// glslify is used as the `glsl` template tag on the fragment shader below,
// where it resolves the `#pragma glslify: ... = require(...)` imports.
// NOTE(review): this is presumably rewritten by a build-time glslify transform
// that looks for this exact require() call — confirm before converting it to
// an ES import.
const glsl = require("glslify");

// Side length of the square position texture; it holds size * size RGBA
// texels, one per particle.
const size = 156;

/**
 * Builds the initial contents of the position texture.
 *
 * Produces `count` RGBA float entries; entry k holds
 * (4k / size, 4k / size, 0) in rgb and 0 in alpha (the buffer is
 * zero-initialised and alpha is never written).
 *
 * NOTE(review): for count = size * size these coordinates reach far outside
 * the [-1, 1] box the fragment shader keeps particles in, so every particle
 * is respawned on the first simulated step — presumably intentional.
 *
 * @param count Number of RGBA entries (particles) to generate.
 * @param p Scratch vector used to write each entry; after the call it holds
 *   the xyz of the last entry written (unchanged when count is 0).
 * @returns A Float32Array of length `count * 4`.
 */
function initPos(count: number, p = new Vector3()) {
  const data = new Float32Array(count * 4);

  for (let particle = 0; particle < count; particle++) {
    const offset = particle * 4;
    const coord = offset / size;
    // Vector3.set returns the vector itself, so the xyz write can be chained.
    p.set(coord, coord, 0).toArray(data, offset);
  }

  return data;
}

// Seed position texture for the simulation: a size x size RGBA float texture
// filled by initPos, one texel per particle.
const initialPositions = initPos(size * size);
const dataTexture = new DataTexture(
  initialPositions,
  size,
  size,
  RGBAFormat,
  FloatType
);

// Mark the texture dirty so three.js uploads the CPU-side buffer.
dataTexture.needsUpdate = true;

/**
 * Simulation material for a particle system whose state lives in a texture.
 *
 * Each texel of the `positions` texture is one particle: rgb is its position
 * and alpha is an extra per-particle value (`init_info`) that is carried
 * through unchanged and reset whenever the particle is respawned. The vertex
 * shader only forwards `uv`; the fragment shader samples `positions` at that
 * uv, advances the particle and writes the result to `gl_FragColor`.
 * NOTE(review): presumably drawn full-screen into a render target whose
 * texture is fed back through `setPositions` each frame — confirm with caller.
 *
 * NOTE: despite its name, the simulation step only runs while the `stop`
 * uniform is true; when it is false positions are copied through unchanged.
 *
 * The `set*` accessors only write the matching uniform; they exist so the
 * values can be assigned as props on `<simulationMaterial />`.
 */
class SimulationMaterial extends ShaderMaterial {
  constructor() {
    super({
      vertexShader: `varying vec2 vUv;
      void main() {
        vUv = uv;
        gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
      }`,
      fragmentShader: glsl`
      #define M_PI 3.1415926535897932384626433832795

      uniform sampler2D positions;
      uniform float speed;
      uniform float uTime;
      uniform float noiseFreq;
      uniform float noiseSize;
      uniform float noiseSpeed;
      uniform float noiseAdd;
      uniform sampler2D prevPositions;
      uniform bool stop;

      varying vec2 vUv;
      #pragma glslify: curl = require(glsl-curl-noise2)
      #pragma glslify: snoise3 = require(glsl-noise/simplex/3d)
      #pragma glslify: snoise2 = require(glsl-noise/simplex/2d)
      #pragma glslify: cnoise3 = require(glsl-noise/classic/3d)

      // Pseudo-random hash of a 2D coordinate, remapped from [0, 1) to [-1, 1).
      float random (vec2 st) {
        float r = fract(sin(dot(st.xy, vec2(12.9898,78.233))) * 43758.5453123);
        return (r-0.5)*2.;
      }

      // Half-extent of the simulation box along x.
      const float range = 1.;
      // Per-axis half-extents of the box; particles outside it are respawned.
      const vec3 bounds = vec3(range, 1., 1.);

      void main() {
        vec4 tex = texture2D(positions, vUv);
        vec3 pos = tex.rgb;
        float init_info = tex.a;

        // The simulation only advances while 'stop' is true; otherwise the
        // texel is written back unchanged.
        if (stop) {
          // GLSL locals are not zero-initialised, so start from rest explicitly.
          vec3 vel = vec3(0.);

          // Frequency of the sin/cos curve particles respawn on and are pulled towards.
          float freq = 5.;

          // Respawn particles that left the box on any axis, upper or lower bound.
          if (any(greaterThan(abs(pos), bounds))) {
            float r_x = random(pos.xy)*range;
            float r_y = (sin(r_x * freq))*0.5 + random(pos.yz)*0.02;
            float r_z = (cos(r_x * freq))*0.5 + random(pos.xz)*0.02;
            pos = vec3(r_x, r_y, r_z);

            init_info = (r_y * r_z) * 1.2;
          }

          // Target point 0.1 further along x on the curve; steer towards it.
          float dist_x = pos.x + 0.1;
          float dist_y = (sin(dist_x * freq))*0.5;
          float dist_z = (cos(dist_x * freq))*0.5;

          vec3 mag = vec3(dist_x, dist_y, dist_z) - pos;

          vel.x += 0.001;
          vel += mag*0.005;
          vel += curl(pos*noiseFreq+uTime*noiseSpeed)*noiseSize;
          vel += curl(pos*noiseFreq*0.1*noiseSpeed+uTime)*noiseAdd*0.001;

          pos += vel * speed;
        }

        gl_FragColor = vec4(pos, init_info);
      }`,
      uniforms: {
        positions: {
          value: dataTexture,
        },
        speed: { value: 1 },
        uTime: { value: 0 },
        noiseFreq: { value: 3 },
        noiseSize: { value: 0.001 },
        noiseSpeed: { value: 1 },
        noiseAdd: { value: 0 },
        stop: { value: false },
      },
    });
  }

  /** Texture sampled as the current particle state (rgb = position, a = init_info). */
  set setPositions(v: Texture | null) {
    this.uniforms.positions.value = v;
  }

  /** Multiplier applied to the accumulated velocity before it is added to the position. */
  set setSpeed(v: number) {
    this.uniforms.speed.value = v;
  }

  /** Spatial frequency: positions are scaled by this before being fed to the curl noise. */
  set setNoiseFreq(v: number) {
    this.uniforms.noiseFreq.value = v;
  }

  /** Amplitude of the primary curl-noise velocity term. */
  set setNoiseSize(v: number) {
    this.uniforms.noiseSize.value = v;
  }

  /** Rate at which the primary noise field moves with `uTime` (also scales the secondary noise input). */
  set setNoiseSpeed(v: number) {
    this.uniforms.noiseSpeed.value = v;
  }

  /** Amplitude (scaled by 0.001 in the shader) of the secondary, low-frequency curl-noise term. */
  set setNoiseAdd(v: number) {
    this.uniforms.noiseAdd.value = v;
  }

  /** Time value fed into the curl-noise lookups. */
  set setUTime(v: number) {
    this.uniforms.uTime.value = v;
  }

  /** When true the simulation step runs; when false positions pass through unchanged. */
  set setStop(v: boolean) {
    this.uniforms.stop.value = v;
  }
}

// Register SimulationMaterial with react-three-fiber so it is available as the
// lower-camel-case intrinsic element <simulationMaterial />.
extend({ SimulationMaterial });

declare global {
  namespace JSX {
    interface IntrinsicElements {
      /** Typing for the <simulationMaterial /> element registered via extend(). */
      simulationMaterial: MaterialNode<
        SimulationMaterial,
        typeof SimulationMaterial
      >;
    }
  }
}
