import numpy as np
from pywarpx import picmi
from space_analysis.simulation.warpx import HybridSimulation

# Short alias for the PICMI physical-constants namespace; the class below
# reads q_e, m_e, c, mu0 and ep0 from it.
constants = picmi.constants


class EMModes(HybridSimulation):
    """Uniform-plasma electromagnetic (EM) modes test.

    Runs a simulation of a uniform plasma at a set temperature (Te = Ti)
    with an external magnetic field applied in the z-direction (parallel to
    the domain) or in the x- or y-direction (perpendicular to the domain).
    The analysis script (in this same directory) analyzes the output field
    data for EM modes. This input is based on the EM modes tests described
    by Munoz et al. (2018) and tests done by Scott Nicks at TAE Technologies.

    The two-element list defaults (``betas``, ``m_ion_norms``,
    ``vA_over_cs``, ``dt_norms``) are indexed by field direction: index 0
    is used when ``B_dir == "z"`` and index 1 when ``B_dir`` is "x" or "y"
    (see ``get_simulation_parameters``).
    """

    # Applied field parameters
    dim: int = 2  # Number of spatial dimensions
    B_dir: str = "z"  # Direction of the initial magnetic field: "x", "y" or "z"
    B0: float = 0.25  # Initial magnetic field strength (T)
    betas: list[float] = [0.01, 0.1]  # Plasma beta, used to calculate temperature

    # Plasma species parameters
    m_ion_norms: list[float] = [100.0, 400.0]  # Ion mass (electron masses)
    vA_over_cs: list[float] = [
        1e-4,
        1e-3,
    ]  # ratio of Alfven speed and the speed of light

    # Spatial domain
    nz: int = 1024  # number of cells in z direction
    nx: int = 8  # number of cells in x (and y) direction for >1 dimensions

    # Temporal domain (if not run as a CI test)
    time_norm: float = 300.0  # Simulation temporal length (ion cyclotron periods)

    # Numerical parameters
    nppc: int = 256  # Seed number of particles per cell
    dz_norm: float = 1.0 / 10.0  # Cell size (ion skin depths)
    dt_norms: list[float] = [5e-3, 4e-3]  # Time step (ion cyclotron periods)

    # Plasma resistivity - used to dampen the mode excitation
    eta: float = 1e-7
    # Number of substeps used to update B
    substeps: int = 20

    def model_post_init(self, __context):
        """Derive case-specific parameters, then set up and write the run.

        The per-direction parameters and derived plasma quantities are
        computed before the parent ``model_post_init`` runs (presumably so
        the parent can use attributes such as ``n0``, ``Te`` and
        ``dt_norm`` -- confirm in HybridSimulation). ``setup_run`` is called
        last.

        Raises:
            ValueError: if ``B_dir`` is not "x", "y" or "z".
        """
        # Pick field components and per-direction defaults from B_dir.
        self.get_simulation_parameters()

        # Derive density, temperature, frequencies, etc. from the inputs.
        self.get_plasma_quantities()

        # Output diagnostics ~20 times per ion cyclotron period (dt_norm is
        # in cyclotron periods). Clamp to >= 1 so a time step longer than
        # 1/20 of a period cannot produce a zero diagnostic interval.
        self.diag_steps = max(1, int(1.0 / 20 / self.dt_norm))

        super().model_post_init(__context)

        self.setup_run()

    def get_simulation_parameters(self):
        """Select the applied field and per-direction defaults from ``B_dir``.

        Sets ``Bx``/``By``/``Bz`` so that a field of magnitude ``B0`` points
        along ``B_dir``, then picks ``m_ion_norm``, ``beta``, ``vA_over_c``
        and ``dt_norm`` from the corresponding two-element default lists
        (index 0 for "z", index 1 for "x" or "y").

        Raises:
            ValueError: if ``B_dir`` is not one of "x", "y" or "z".
        """
        if self.B_dir == "z":
            idx = 0
            self.Bx = 0.0
            self.By = 0.0
            self.Bz = self.B0
        elif self.B_dir == "y":
            idx = 1
            self.Bx = 0.0
            self.By = self.B0
            self.Bz = 0.0
        elif self.B_dir == "x":
            idx = 1
            self.Bx = self.B0
            self.By = 0.0
            self.Bz = 0.0
        else:
            # Reject typos (e.g. "Z") instead of silently running an
            # x-directed field configuration.
            raise ValueError(
                f"B_dir must be 'x', 'y' or 'z', got {self.B_dir!r}"
            )

        self.m_ion_norm = self.m_ion_norms[idx]
        self.beta = self.betas[idx]
        self.vA_over_c = self.vA_over_cs[idx]
        self.dt_norm = self.dt_norms[idx]

    def get_plasma_quantities(self):
        """Calculate derived plasma parameters from the simulation input.

        Requires ``get_simulation_parameters`` to have set ``m_ion_norm``,
        ``vA_over_c`` and ``beta``. Sets ``m_ion``, ``w_ci``, ``t_ci``,
        ``vA``, ``n_plasma``, ``w_pi``, ``d_i``, ``v_ti``, ``T_plasma``,
        ``n0`` and ``Te`` (SI units unless noted).
        """
        # Ion mass (kg)
        self.m_ion = self.m_ion_norm * constants.m_e

        # Ion cyclotron angular frequency (rad/s) and period (s)
        self.w_ci = constants.q_e * abs(self.B0) / self.m_ion
        self.t_ci = 2.0 * np.pi / self.w_ci

        # Alfven speed (m/s): vA = B / sqrt(mu0 * n * (M + m)) = c * omega_ci / w_pi
        # The density (m^-3) is obtained by inverting this relation.
        self.vA = self.vA_over_c * constants.c
        self.n_plasma = (self.B0 / self.vA) ** 2 / (
            constants.mu0 * (self.m_ion + constants.m_e)
        )

        # Ion plasma angular frequency (rad/s)
        self.w_pi = np.sqrt(
            constants.q_e**2 * self.n_plasma / (self.m_ion * constants.ep0)
        )

        # Ion skin depth (m)
        self.d_i = constants.c / self.w_pi

        # Ion thermal velocity (m/s) from beta = 2 * (v_ti / vA)**2
        self.v_ti = np.sqrt(self.beta / 2.0) * self.vA

        # Temperature (eV) from thermal speed: v_ti = sqrt(kT / M)
        self.T_plasma = self.v_ti**2 * self.m_ion / constants.q_e  # eV

        # Aliases under the names presumably read by HybridSimulation
        # (density and electron temperature, Te = Ti) -- confirm there.
        self.n0 = self.n_plasma
        self.Te = self.T_plasma

    def setup_field(self):
        """Add the constant external magnetic field (Bx, By, Bz) to the simulation."""
        B_ext = picmi.AnalyticInitialField(
            Bx_expression=self.Bx, By_expression=self.By, Bz_expression=self.Bz
        )
        self._sim.add_applied_field(B_ext)

    def setup_particle(self):
        """Create the ion species and add it to the simulation.

        Ions are loaded with a uniform density ``n_plasma`` and an isotropic
        rms velocity ``v_ti``, using a pseudo-random layout with ``nppc``
        macroparticles per cell on ``self._grid``.

        Returns:
            self, to allow chaining.
        """
        self.ions = picmi.Species(
            name="ions",
            charge="q_e",
            mass=self.m_ion,
            initial_distribution=picmi.UniformDistribution(
                density=self.n_plasma,
                rms_velocity=[self.v_ti] * 3,
            ),
        )
        self._sim.add_species(
            self.ions,
            layout=picmi.PseudoRandomLayout(
                grid=self._grid, n_macroparticles_per_cell=self.nppc
            ),
        )
        return self

    def setup_run(self):
        """Set up the simulation via the parent class, then write the input file."""
        super().setup_run()

        # Write the generated WarpX input deck to "inputs_em_modes".
        self._sim.write_input_file("inputs_em_modes")


# Constructing the model runs EMModes.model_post_init (presumably invoked
# by the pydantic-style base class), which derives the plasma parameters,
# calls setup_run() and writes the "inputs_em_modes" input file.
sim = EMModes()
Numerical parameters:
    dt = 7.1e-11 s
    total steps = 60000

Initializing simulation with input parameters:
    Te = 0.003 eV
    n = 6.0e+17 cm^-3
    B0 = 250000000.00 nT
    M/m = 100

Plasma parameters:
    d_i = 6.9e-05 m
    t_ci = 1.4e-08 s
    v_ti = 2.1e+03 m/s
    vA = 3.0e+04 m/s
    vA/c = 0.0001
# Drive the wrapped simulation object (private `_sim`, presumably a
# picmi.Simulation) directly. step() is called with no argument; check the
# pywarpx docs for how many steps that runs.
sim._sim.step()
<pywarpx.picmi.Cartesian2DGrid at 0x10a51df50>