import sys
import os
import numpy
import math

class data:
    """Class to read in data in the procar file

       The filename can be specified, the default is PROCAR
       The number of spin channels can be specified, the default is 1
       Whether a noncollinear calculation was performed can be specified, the default is False
    """

    def __init__(self,filename="PROCAR",nspin=1,noncollinear=False):
        """Reads PROCAR file, including kpoints, band energies and site-projections

        filename     : path of the PROCAR file to read
        nspin        : number of spin channels to read; reset to 1 when
                       noncollinear is true
        noncollinear : if true, the additional blocks that follow the
                       projections of every band are skipped

        On success the following attributes are available:
          itype        0, 1 or 2 depending on the keyword found in the first line
          nkpoints, nbands_up, nions, nproj
          kpoints      (nkpoints,3) k-point coordinates
          weights      (nkpoints) k-point weights
          kpointr      (nkpoints) cumulative distance between consecutive k-points
          energies     (nspin,nkpoints,nbands_up) band energies
          occupancies  (nspin,nkpoints,nbands_up) band occupancies
          projections  (nspin,nkpoints,nbands_up,nions,nproj) site projections
          phases       (nspin,nkpoints,nbands_up,nions,nproj,2), only filled
                       when the file carries phase information (itype 2)

        A missing, unrecognised or malformed file prints a message and ends the
        program through sys.exit with a non-zero status.
        """
        self.filename=filename
        self.nspin=nspin
        self.noncollinear=noncollinear
        if self.noncollinear:
            print( "NB: Separate x, y and z contributions are ignored for noncollinear calculations.")
            # Only one set of data if noncollinear calculation is performed
            self.nspin = 1
        if not os.path.isfile(self.filename):
            print( "No file with name %s was found." % filename)
            sys.exit(1)
        # The context manager closes the file on every exit path, including
        # the sys.exit calls below.
        with open(self.filename,"r") as procar_file:
            #Read first line and determine file type
            line=procar_file.readline()
            if "phase" in line:
                # LORBIT = 12
                self.itype = 2
                print( " File is lm decomposed with phase information")
            elif "lm decomposed" in line:
                # LORBIT = 11
                self.itype = 1
                print( " File is lm decomposed without phase information")
            elif "new format" in line:
                # LORBIT = 10
                self.itype = 0
                print( " File is l decomposed")
            else:
                print( " Could not recognise PROCAR file\n")
                sys.exit(1)
            #Read in number of k-points, bands and ions (fixed column positions)
            line=procar_file.readline()
            self.nkpoints=int(line[14:19])
            self.nbands_up=int(line[39:44])
            self.nions=int(line[63:68])
            print( "    Number of k-points : %d" % self.nkpoints)
            print( "       Number of bands : %d" % self.nbands_up)
            print( "        Number of ions : %d" % self.nions)
            #Set up some arrays. The builtin float is used as dtype because the
            #numpy.float alias is no longer provided by current numpy releases.
            self.kpoints=numpy.zeros((self.nkpoints,3),float)
            self.weights=numpy.zeros((self.nkpoints),float)
            self.kpointr=numpy.zeros((self.nkpoints),float)
            self.energies=numpy.zeros((self.nspin,self.nkpoints,self.nbands_up),float)
            self.occupancies=numpy.zeros((self.nspin,self.nkpoints,self.nbands_up),float)
            #Start loop over spins
            for ispin in range(self.nspin):
                if ispin == 1:
                    #The second spin channel starts with its own header line
                    line=procar_file.readline()
                    if line=="":
                        print( " File only contains data for one spin channel.")
                        sys.exit(1)
                    #Same columns as used for the first header line
                    self.nbands_down=int(line[39:44])
                    if self.nbands_down!=self.nbands_up:
                        print( " Only implemented for same number of bands in both spin channels")
                        sys.exit(1)
                for ikpt in range(self.nkpoints):
                    #Skip one line, the next one has to be the k-point line
                    procar_file.readline()
                    line=procar_file.readline()
                    #Check we are at the right place
                    if "k-point" not in line:
                        print( " Error reading in PROCAR file")
                        print( " k-point %d" % (ikpt+1))
                        sys.exit(1)
                    #Read k-point info
                    self.kpoints[ikpt,0]=float(line[19:30])
                    self.kpoints[ikpt,1]=float(line[30:41])
                    self.kpoints[ikpt,2]=float(line[41:52])
                    self.weights[ikpt]=float(line[65:76])
                    #Distance along k-point line
                    if ikpt>0:
                        step=self.kpoints[ikpt]-self.kpoints[ikpt-1]
                        self.kpointr[ikpt]=self.kpointr[ikpt-1]+numpy.sqrt(numpy.dot(step,step))
                    #Read in band energy and occupancy
                    for iband in range(self.nbands_up):
                        #Skip one line, the next one has to be the band line
                        procar_file.readline()
                        line=procar_file.readline()
                        #Check we are at the right place
                        if "band" not in line:
                            print( " Error reading in PROCAR file")
                            print( " k-point %d, band %d" % (ikpt+1,iband+1))
                            print( line)
                            sys.exit(1)
                        self.energies[ispin,ikpt,iband]=float(line[19:33])
                        self.occupancies[ispin,ikpt,iband]=float(line[40:52])
                        #Skip one line, the next one is the header of the projection table
                        procar_file.readline()
                        line=procar_file.readline()
                        #Work out number of projections: all header columns
                        #except the first and the last one
                        if ispin==0 and ikpt==0 and iband==0:
                            self.nproj=len(line.split())-2
                            self.projections=numpy.zeros((self.nspin,self.nkpoints,self.nbands_up,self.nions,self.nproj),float)
                            self.phases=numpy.zeros((self.nspin,self.nkpoints,self.nbands_up,self.nions,self.nproj,2),float)
                        #Read in projections, one line per ion, 7 columns per value
                        for iion in range(self.nions):
                            line=procar_file.readline()
                            for iproj in range(self.nproj):
                                start=7+iproj*7
                                self.projections[ispin,ikpt,iband,iion,iproj]=float(line[start:start+5])
                        #Skip totals line if more than one ion present
                        if self.nions>1:
                            procar_file.readline()
                        #If it is a noncollinear spin calculation we will skip the lines with separate contributions
                        if self.noncollinear:
                            for i in range(3*(self.nions+1)):
                                procar_file.readline()
                            # If f-functions are present this seems to be necessary
                            if self.nproj==16:
                                for i in range(3):
                                    procar_file.readline()
                        if self.itype==2:
                            #Skip the header of the phase table, then two lines per ion
                            procar_file.readline()
                            for iion in range(self.nions):
                                for ipart in range(2):
                                    line=procar_file.readline()
                                    for iproj in range(self.nproj):
                                        start=7+iproj*7
                                        self.phases[ispin,ikpt,iband,iion,iproj,ipart]=float(line[start:start+3])
                        # If f-functions are present this seems to be necessary
                        if self.nproj==16:
                            procar_file.readline()
                    #Skip the line that follows the last band of a k-point
                    procar_file.readline()
        print( " Successfully read file %s" % self.filename)

    def get_dos(self,emin,emax,nepoints,width,energy_shift=0.,occ=False,lineshape="Gaussian",spins=None):
        """Returns the total density of states

        Returns the density of states as a numpy array over a given range by broadening
        states with a specified line shape

        emin, emax   : first and last energy of the grid
        nepoints     : number of equally spaced energy points (at least 2)
        width        : full width at half maximum of the broadening function
        energy_shift : constant added to every band energy before broadening
        occ          : if true every state is additionally weighted by its
                       occupancy (applied to every line shape)
        lineshape    : case insensitive, one of
                         gaussian     - Gaussian function
                         erf          - Gaussian integrated over each energy bin
                         lorentzian   - Lorentzian function
                         gausslorentz - equal mix of Gaussian and Lorentzian
        spins        : spin channel indices to include; None or an empty
                       sequence selects all channels

        Returns an array of shape (nepoints,2) holding the energies in column 0
        and the density of states in column 1. For an unsupported line shape a
        message is printed and None is returned.
        """
        #Check lineshape is supported
        shape=lineshape.lower()
        if shape not in ("gaussian","erf","lorentzian","gausslorentz"):
            print( " ERROR: Line shape %s is not currently supported." % (lineshape))
            return
        if shape=="erf":
            #Import erf function from scipy
            from scipy.special import erf
        #Gaussian width parameter chosen so that full-width at half maximum is correct
        sigma=width/2./math.sqrt(2.*math.log(2.))
        #If spins is empty then run over all spin channels present
        if spins is None or numpy.size(spins)==0:
            spins=range(self.nspin)
        dos=numpy.zeros((nepoints,2),float)
        #Work out difference between energy points
        ediff=(emax-emin)/float(nepoints-1)
        print( " Started calculation of density of states."  )
        #Everything that does not depend on the grid energy is evaluated once.
        #The k-point weights are reshaped into a new view so that self.weights
        #keeps its original shape for later calls.
        shifted=self.energies[spins,...]+energy_shift
        state_weight=numpy.reshape(self.weights,(1,self.nkpoints,1))
        if occ:
            state_weight=state_weight*self.occupancies[spins,...]
        g_prefactor=1./sigma/math.sqrt(2.*math.pi)
        g_constant=1./2./sigma/sigma
        l_prefactor=width/math.pi/2.
        l_constant=(width/2.)**2
        erf_constant=1.0/(math.sqrt(2.0)*sigma)
        #Loop over points in energy
        for ie in range(nepoints):
            energy=emin+ie*ediff
            dos[ie,0]=energy
            if shape=="erf":
                #Gaussian integrated over the bin centred on this energy
                estart=energy-ediff/2.0
                eend=energy+ediff/2.0
                profile=erf(erf_constant*(shifted-estart))-erf(erf_constant*(shifted-eend))
                # factor 2 needed in transition to error function from gaussian
                profile=profile/2.0/ediff
            else:
                offset_sq=(shifted-energy)**2
                if shape=="gaussian":
                    profile=g_prefactor*numpy.exp(-g_constant*offset_sq)
                elif shape=="lorentzian":
                    profile=l_prefactor/(offset_sq+l_constant)
                else:
                    profile=0.5*g_prefactor*numpy.exp(-g_constant*offset_sq)+0.5*l_prefactor/(offset_sq+l_constant)
            #Sum the weighted contributions of all states at this energy
            dos[ie,1]=numpy.sum(state_weight*profile)
        print( " Completed calculation of density of states."  )
        return dos

    def get_integrated_dos(self,emin,emax,nepoints,width,energy_shift=0.0,lineshape="Gaussian",spins=None):
        """Returns the integrated total density of states

        Computes the density of states with get_dos on the given energy grid
        and accumulates it from emin upwards.

        emin, emax   : first and last energy of the grid
        nepoints     : number of equally spaced energy points (at least 2)
        width        : broadening width handed on to get_dos
        energy_shift : constant added to every band energy before broadening
        lineshape    : line shape handed on to get_dos
        spins        : spin channel indices to include; None or an empty
                       sequence selects all channels

        Returns an array of shape (nepoints,2) holding the energies in column 0
        and the running integral in column 1, or None when get_dos rejects the
        line shape.
        """
        #If spins is empty then run over all spin channels present
        if spins is None or numpy.size(spins)==0:
            spins=range(self.nspin)
        #Keywords are required here: get_dos has the extra argument occ between
        #energy_shift and lineshape, so a positional call would shift the arguments.
        int_dos=self.get_dos(emin,emax,nepoints,width,energy_shift=energy_shift,lineshape=lineshape,spins=spins)
        if int_dos is None:
            #get_dos has already reported the unsupported line shape
            return None
        ediff=(emax-emin)/float(nepoints-1)
        #Set dos at first energy to zero to initialise integration
        int_dos[0,1]=0.
        #Integrate: running sum multiplied by the grid spacing
        int_dos[:,1]=int_dos[:,1].cumsum()*ediff
        return int_dos

    def get_partial_dos(self,emin,emax,nepoints,width,energy_shift=0.,lineshape="Gaussian",spins=None,atoms=None,kpoints=None,orbitals="all"):
        """Returns the partial density of states

        Returns the partial density of states as a numpy array over a given energy range by broadening
        states with a specified line shape. This can be done for one atom, or summed over a number of atoms,
        potentially even summed over specific k-points only.
        The atoms should be specified in atoms as indices starting from 0.
        Different angular momentum can be specified using orbitals. Currently "all", "s", "p", "px", "py", "pz",
        "sp", "sp2", "d", "dxy", "dxz", "dyz", "dz2" and "dx2" are recognised.

        emin, emax   : first and last energy of the grid
        nepoints     : number of equally spaced energy points (at least 2)
        width        : full width at half maximum of the broadening function
        energy_shift : constant added to every band energy before broadening
        lineshape    : "gaussian", "erf", "lorentzian" or "gausslorentz" (case insensitive)
        spins        : spin channel indices; None or empty selects all channels
        atoms        : atom index or indices; None or empty selects all atoms
        kpoints      : k-point indices; None or empty uses all k-points with
                       their stored weights, otherwise the weights of the
                       selected k-points are renormalised to sum to one and
                       all other k-points get zero weight
        orbitals     : orbital label understood by parse_orbitals

        Returns an array of shape (nepoints,2) holding the energies in column 0
        and the partial density of states in column 1. None is returned (after
        a printed message) for an unsupported line shape or orbital label.
        """
        #Check lineshape is supported
        shape=lineshape.lower()
        if shape not in ("gaussian","erf","lorentzian","gausslorentz"):
            print( " ERROR: Line shape %s is not currently supported." % (lineshape))
            return
        if shape=="erf":
            #Import erf function from scipy
            from scipy.special import erf
        #Gaussian width parameter chosen so that full-width at half maximum is correct
        sigma=width/2./math.sqrt(2.*math.log(2.))
        #If spins is empty then run over all spin channels present
        if spins is None or numpy.size(spins)==0:
            spins=range(self.nspin)
        # determine k-point weights; ravel gives a flat view whatever shape
        # self.weights currently has and leaves the attribute itself untouched
        all_weights=numpy.ravel(self.weights)
        if kpoints is None or numpy.size(kpoints)==0:
            kweights=all_weights
        else:   # take relative k-point weight but normalize it only over those points selected ...
            selected=numpy.atleast_1d(kpoints)
            wtotal=all_weights[selected].sum()
            kweights=numpy.zeros(self.nkpoints)    # unused k-points get zero weight ...
            kweights[selected]=all_weights[selected]/wtotal
        kweights=numpy.reshape(kweights,(1,self.nkpoints,1))
        pdos=numpy.zeros((nepoints,2),float)
        #Work out difference between energy points
        ediff=(emax-emin)/float(nepoints-1)
        #Parse orbitals string to select which projections to sum over
        proj=self.parse_orbitals(orbitals)
        if proj is None:
            #parse_orbitals has already reported the problem
            return
        proj=list(proj)
        #If atoms is empty, run over all atoms
        if atoms is None or numpy.size(atoms)==0:
            atoms=range(self.nions)
        atoms=numpy.atleast_1d(atoms)
        print( " Started calculation of partial density of states."  )
        #Reduce projections to just those needed and sum over orbitals and atoms
        site_weight=self.projections[:,:,:,atoms,:]
        site_weight=site_weight[:,:,:,:,proj]
        site_weight=site_weight.sum(axis=4).sum(axis=3)
        site_weight=site_weight[spins,...]
        #The same (selected) k-point weights are used for every line shape
        state_weight=kweights*site_weight
        shifted=self.energies[spins,...]+energy_shift
        g_prefactor=1./sigma/math.sqrt(2.*math.pi)
        g_constant=1./2./sigma/sigma
        l_prefactor=width/math.pi/2.
        l_constant=(width/2.)**2
        erf_constant=1.0/(math.sqrt(2.0)*sigma)
        #Loop over points in energy
        for ie in range(nepoints):
            energy=emin+ie*ediff
            pdos[ie,0]=energy
            if shape=="erf":
                #Gaussian integrated over the bin centred on this energy
                estart=energy-ediff/2.0
                eend=energy+ediff/2.0
                profile=erf(erf_constant*(shifted-estart))-erf(erf_constant*(shifted-eend))
                # factor 2 needed in transition to error functions from gaussians
                profile=profile/2.0/ediff
            else:
                offset_sq=(shifted-energy)**2
                if shape=="gaussian":
                    profile=g_prefactor*numpy.exp(-g_constant*offset_sq)
                elif shape=="lorentzian":
                    profile=l_prefactor/(offset_sq+l_constant)
                else:
                    profile=0.5*g_prefactor*numpy.exp(-g_constant*offset_sq)+0.5*l_prefactor/(offset_sq+l_constant)
            #Sum over all bands and projections
            pdos[ie,1]=numpy.sum(state_weight*profile)
        print( " Completed calculation of partial density of states."  )
        return pdos

    def get_projections(self,atoms=None,orbitals="all"):
        """Returns projections onto specific atoms and orbitals

        atoms    : atom index or indices starting from 0; None or an empty
                   sequence selects all atoms (same convention as get_partial_dos)
        orbitals : orbital label understood by parse_orbitals

        Returns the stored projections summed over the selected atoms and
        orbitals; the remaining axes are those of self.projections in front of
        the atom axis. None is returned when parse_orbitals rejects the label
        (it prints the reason).
        """
        #Parse orbitals string to select which projections to sum over
        proj=self.parse_orbitals(orbitals)
        if proj is None:
            return
        #If atoms is empty, run over all atoms
        if atoms is None or numpy.size(atoms)==0:
            atoms=range(self.nions)
        #A single index is treated like a one element list
        atoms=numpy.atleast_1d(atoms)
        #Reduce projections to just those needed
        temp_proj=self.projections[:,:,:,atoms,:]
        temp_proj=temp_proj[:,:,:,:,list(proj)]
        #Sum over those atoms and orbitals
        return temp_proj.sum(axis=-1).sum(axis=-1)

    def parse_orbitals(self,orbitals):
        """Translate an orbital label into the projection columns to sum over

        For an l decomposed file (itype 0) only the labels "s", "p", "d", "sp"
        and "all" are accepted, and "p", "d" and "sp" map onto the l resolved
        columns. For the other file types the lm resolved columns are used.
        Returns a list of column indices (a range for "all"). If the label
        cannot be used a message is printed and None is returned.
        """
        l_only = self.itype==0
        #Multi-letter labels other than "all" and "sp" need lm resolved columns
        if l_only and len(orbitals)>1 and not (orbitals=="all" or orbitals=="sp"):
            print( "Orbital choice %s is not consistent with this PROCAR file" % orbitals)
            return
        #label, columns in an lm decomposed file, columns in an l decomposed
        #file (None where the l decomposed case needs no separate entry)
        table=(
            ("s",   (0,),        None),
            ("p",   (1,2,3),     (1,)),
            ("px",  (3,),        None),
            ("py",  (1,),        None),
            ("pz",  (2,),        None),
            ("sp",  (0,1,2,3),   (0,1)),
            ("sp2", (0,1,3),     None),
            ("d",   (4,5,6,7,8), (2,)),
            ("dxy", (4,),        None),
            ("dxz", (7,),        None),
            ("dyz", (5,),        None),
            ("dz2", (6,),        None),
            ("dx2", (8,),        None),
        )
        for label,lm_columns,l_columns in table:
            if orbitals==label:
                if l_only and l_columns is not None:
                    return list(l_columns)
                return list(lm_columns)
        if orbitals=="all":
            return range(self.nproj)
        print( "Orbital choice %s is not valid." % orbitals)
        return
       
