# STM-BJ_scanSelect
# # v 1.2
# Richard J. Brooke <rb7318@bristol.ac.uk> 12.02.15

# Import python modules
from Tkinter import *
import tkFileDialog
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import os
from scipy.signal import medfilt
import platform
from shutil import copy
import cPickle as pickle

###### These are the parameters of the plateau finding code which you may find need changing ##############################################

# A step is flagged where the windowed mean of the gradient is below -limit (see Scan_filter.step_find).
limit = 0.1 # Minimum gradient criteria for finding conductance steps
# findgrad fits a line to the samples [i-diff, i+diff) around each point, i.e. 2*diff samples per fit.
diff = 5 # Data window size for calculating the gradient
# degree is passed to scipy.signal.medfilt as the kernel size.
median_filter, degree = True, 15 # Option for smoothing each scan with a moving median filter.
# noise_check is compared with the std_err of a linear fit to the plateau;
# slope_check is compared with the difference in log10 conductance between the two ends of the plateau.
Plateau_checks, noise_check, slope_check = True, 0.0015, 2.5 # Option for additional checks on the plateau itself. Parameters for noise and slope can be varied.

# Other parameters you might want to change

No_Data_sets = 1 # No. of data sets to be evaluated. So you can leave the program to run over lots of experiments while you go for lunch!
# Both thresholds are compared with log10 of each value read from the scan files (see Scan_filter.readFile).
u_th,l_th = 0.2,-5 # Discard values of conductance larger than u_th and lower than l_th. Saves looking through data above G0 and in the noise level. Saves time.
Plotting = False #  Option to visually check individual scans as they are evaluated.

# Scan parameters
sampsec = 10000 # Data acquisition rate
srange = 4 # Scan range
sduration = 0.3 # Scan duration
# float() keeps this a true division even if all three scan parameters are set to integers
# (integer division would otherwise truncate the step distance to 0 under Python 2).
interval = float(srange) / sduration / sampsec # Step distance

# What to save ???
Save_Scans = True
Save_Plateaus = True
Save_Histograms = True

#################################################################################################################################################################

class Scan_filter:
	"""Find conductance plateaus in STM-BJ scans and save the scans that contain them.

	start() processes one data-set directory: every scan file is read with
	readFile(), searched for plateaus with Scan_check() (which relies on
	step_find()), after which histograms are plotted and the selected scans,
	plateaus and histograms are written next to the data-set directory,
	depending on the module-level Save_* options.
	"""

	def __init__(self):
		self.scans = [] # Scan filenames; each entry is replaced by that scan's conductance list once it has been read
		self.steps = [] # Locations of the steps found in the scan currently being checked
		self.PlateauData = [] # One list of plateaus per scan (an empty list for scans without a plateau)
		self.KeepList = [] # List of all scan numbers to be kept
		self.prefix = '' # Absolute path of the data-set directory being processed

	def Scan_check(self,scanNo,index, G):
		"""Look for plateaus in one scan and record the outcome.

		scanNo -- zero-based scan number (reported as scanNo+1 in the printed message)
		index  -- displacement index of every sample, as returned by readFile()
		G      -- log10 conductance values of the scan, as returned by readFile()

		The plateaus found are appended to self.PlateauData (an empty list when
		there are none) and scanNo is appended to self.KeepList when at least
		one plateau was accepted. One summary line is printed per scan.
		"""
		plateaus = [] # Plateaus accepted for this scan
		s = 'Checking scan ' + str(scanNo+1) # Summary message, extended below and printed once
		# Plateaus in the STM-BJ scans are characterised by the presence of 2 or more sharp steps in the conductance.
		# Steps can be found by evaluating the gradient. Large spikes in the gradient are expected at conductance steps.
		# The derivative and second derivative of the scans is required to check for these steps.
		G1 = None # Median-filtered scan, only computed when median_filter is selected
		if median_filter: # Moving median smoothing can be used if scans are noisy
			G1 = medfilt(G,degree)
			dG = findgrad(index, G1, diff) # Derivative of the smoothed scan
		else:
			dG = findgrad(index, G, diff) # Derivative of the raw scan
		d2G = findgrad(index, dG, diff) # 2nd derivative

		self.step_find(dG,d2G) # Fills self.steps with the locations of the conductance steps

		# Every pair of consecutive steps bounds one candidate plateau
		for first, last in zip(self.steps[:-1], self.steps[1:]):
			if Plateau_checks: # Perform checks on the plateau if Plateau_checks is selected
				lo, hi = first+1, last-1
				if hi - lo < 2:
					# Fewer than two samples between the steps: a regression line cannot be fitted, so reject it
					s += ', A plateau FAILED length check'
					continue
				# First fit a linear regression line to the Plateau
				slope,intercept,r_value,p_value,std_err = stats.linregress(range(lo,hi),G[lo:hi])
				# Check that 'plateau' is not noisy.
				noiseflag = std_err > noise_check
				if noiseflag:
					s += ', A plateau FAILED noise check, std_err = ' + "%.4f" % std_err
				# Checks the Conductance range spanned by the plateau is not too large.
				DG = G[lo]-G[hi]
				slopeflag = abs(DG) > slope_check
				if slopeflag:
					s += ', A plateau FAILED slope check, DG = ' + "%.2f" % DG
				# End of plateau checks.
				# We could add more checks here or use different checks. If you have any ideas please let me know.
				if noiseflag or slopeflag:
					continue
			# The plateau has passed the checks (or no checks were requested): add it to the list of plateaus.
			plateaus.append([[index[first],index[last]],G[first:last]])
			s += ', Plateau Found'

		KeepFlag = len(plateaus) > 0
		if KeepFlag:
			s += ' ... Scan kept!'
			self.KeepList.append(scanNo)
		else:
			s += ' ... Scan discarded'
		print(s)
		self.PlateauData.append(plateaus) # add the plateaus to the global plateau list

		# Plots individual scans to visually check plateau selection
		if Plotting:
			self._plot_scan(index, G, G1, dG, d2G, plateaus)

	def _plot_scan(self, index, G, G1, dG, d2G, plateaus):
		"""Plot one scan, its first and second derivative, the steps found and the accepted plateaus.

		G1 is the median-filtered scan, or None when no filtering was applied.
		"""
		plt.close()
		f,axarr = plt.subplots(3, sharex=True)
		if G1 is not None:
			axarr[0].plot(index,G,linewidth = 0.7, color='k')
			axarr[0].plot(index,G1,linewidth = 1.5, color='b')
		else:
			axarr[0].plot(index,G, color='b')
		# Accepted plateaus are drawn in red on top of the scan
		for (platstart,platstop),platG in plateaus:
			axarr[0].plot(range(platstart,platstop),platG,color='r')
		# A green vertical line marks every step in all three panels
		for step in self.steps:
			z = index[step]
			axarr[0].plot([z,z],[1,-6], 'g',linewidth = 1)
			axarr[1].plot([z,z],[0.1,-0.4], 'g',linewidth = 1)
			axarr[2].plot([z,z],[0.06,-0.06], 'g',linewidth = 1)
		axarr[1].plot(index,dG,color='b')
		axarr[1].plot(index,[-limit]*len(index), 'k',linewidth = 1.2,ls='--') # gradient limit used by step_find
		axarr[2].plot(index,d2G, color='b')
		axarr[2].plot(index,[0]*len(index), 'k',linewidth = 1.2, ls='--') # zero line of the second derivative
		axarr[0].set_ylim([-5.8,0.5])
		axarr[1].set_ylim([-0.4,0.1])
		axarr[2].set_ylim([-0.06,0.06])
		axarr[0].set_ylabel(r'$Log_{10}[G/G_{0}]$')
		axarr[1].set_ylabel(r'$d(Log_{10}[G/G_{0}])/dz$')
		axarr[2].set_ylabel(r'$d^{2}(Log_{10}[G/G_{0}])/dz^{2}$')
		axarr[2].set_xlabel(r'$z$')
		for ax in axarr:
			ax.set_xlim([index[0],index[-1]])
		plt.show()

	def step_find(self, x, dx):
		"""Append the locations of all conductance steps in a scan to self.steps.

		x  -- gradient of the scan
		dx -- second derivative of the scan (assumed to have the same length as x)

		Conductance steps occur where the second derivative goes through zero
		from below. The second derivative is noisy and crosses zero in many
		places, so the gradient averaged around that point must also be
		steeper than -limit.

		The search restarts at every step found and ignores one more sample at
		the end of the scan after each step. It is written as a loop rather
		than a recursion, so scans with many steps cannot exceed the
		interpreter's recursion limit and the data is not copied per step.
		"""
		# Locations are recorded relative to the last step already stored (0 when self.steps is empty)
		offset = self.steps[-1] if self.steps else 0
		begin, end = 0, len(x)
		while True:
			step = None
			# Stay 2 samples clear of both ends of the window so the averages below are always complete
			for i in range(begin+2, end-2):
				if (average(dx[i-2:i]) < 0.0) and (average(dx[i:i+2]) > 0.0) and (average(x[i-2:i+2]) < -limit):
					step = i
					break
			if step is None:
				return
			self.steps.append(offset+step)
			# Continue searching from the step just found, dropping one sample from the end
			begin, end = step, end-1

	def start(self,prefix):
		"""Evaluate every scan file in the directory prefix, then plot and save the results.

		Every file in the directory is read and checked with Scan_check().
		Histograms of all conductance data and of the plateau data are plotted,
		and the selected scans, plateaus and histograms are saved in
		Selected_Scans, Selected_Plateaus and Selected_Histograms inside the
		parent directory of prefix, as enabled by the Save_* options. All
		per-data-set state is cleared before returning.
		"""
		self.prefix = os.path.abspath(prefix)
		# os.listdir returns names in arbitrary order: sort them so scan numbers are reproducible, and skip anything that is not a file
		self.scans = sorted(name for name in os.listdir(self.prefix) if os.path.isfile(os.path.join(self.prefix, name)))
		no_scans = len(self.scans) # total number of scans
		if no_scans == 0:
			print('No scans found in ' + self.prefix)
			self._reset()
			return
		for scanNo in range(no_scans):
			self.steps = [] # Reset list to contain locations of steps in each scan
			index, G = self.readFile(scanNo)
			self.Scan_check(scanNo,index, G)
			self.scans[scanNo] = G # Replace scan filename entry with conductance list

		plt.close()
		# Calculate 'X', the percentage of scans which were selected
		plateaus = [p for p in self.PlateauData if p] # remove empty lists from plateaus (scans which have no plateau)
		no_selected = len(plateaus)
		X = float(no_selected)/float(no_scans)*100
		# Histogram bin width should be the same for all plots which are to be compared
		bin_width = 0.01
		# Plot histogram of all scans conductance
		flattened = [val for scan in self.scans for val in scan]
		n1, bins1 = self._plot_histogram(flattened, no_scans, bin_width, 'k', 'Unfiltered Data')
		# Plot histogram of selected plateau conductance
		flattened = [val for scan in plateaus for plateau in scan for val in plateau[1]]
		n2, bins2 = self._plot_histogram(flattened, no_selected, bin_width, 'r', 'Plateaus')
		plt.legend()
		print(' '.join(["%.2f" % X, '% of ', str(no_scans), ' scans selected']))
		if No_Data_sets > 1:
			plt.show(block=False) # Do not wait for the window to be closed when more data sets follow
		else:
			plt.show()

		# Results are saved in the parent directory of the data set
		directory = os.path.dirname(self.prefix)
		if Save_Scans:
			self._save_scans(directory)
		if Save_Plateaus:
			self._save_plateaus(directory, plateaus)
		if Save_Histograms:
			self._save_histograms(directory, bins1, n1, bins2, n2)

		self._reset()

	def _plot_histogram(self, values, no_scans, bin_width, colour, label):
		"""Plot a step histogram of values and return (heights, bin edges).

		Every value is weighted by 1/no_scans. Nothing is plotted and two empty
		arrays are returned when values is empty.
		"""
		if not values:
			return np.array([]), np.array([])
		hist_weights = [1.0/no_scans]*len(values)
		# At least one bin, also when all values lie within one bin width
		numBins = max(1, int(abs(max(values)-min(values))/bin_width))
		# NOTE(review): 'normed' is the keyword of the Matplotlib releases this Python 2 script targets; newer releases name it 'density' - confirm before upgrading
		n, bins, patches = plt.hist(values,bins= numBins, weights = hist_weights, color=colour,histtype='step', normed=1, label = label)
		return n, bins

	def _make_dir(self, directory, name):
		"""Create the folder name inside directory unless it already exists and return its path."""
		path = os.path.join(directory, name)
		try:
			os.makedirs(path)
		except OSError:
			if not os.path.isdir(path):
				raise
		return path

	def _save_scans(self, directory):
		"""Write every kept scan to Selected_Scans as lines of 'distance, conductance'."""
		folder = self._make_dir(directory, 'Selected_Scans')
		for Fnum, scanNo in enumerate(self.KeepList):
			G = self.scans[scanNo]
			# Reconstruct distance data
			s_dat = synchronizeScans(list(range(len(G))), G)
			# Distance is scaled by the step distance; conductance is changed back to a linear scale
			lines = [str(s*interval) + ', ' + str(10**g) + '\n' for s, g in zip(s_dat, G)]
			Fname = os.path.join(folder, 'Selected_Scan_' + str(Fnum).zfill(4) + '.txt')
			with open(Fname,'wb') as f:
				f.write(''.join(lines))

	def _save_plateaus(self, directory, plateaus):
		"""Write the plateaus of every selected scan to Selected_Plateaus, one file per scan."""
		folder = self._make_dir(directory, 'Selected_Plateaus')
		for Fnum, scan in enumerate(plateaus):
			lines = []
			for (z_start, z_stop), G in scan:
				# Reconstruct distance
				z = np.arange(z_start, z_stop)
				lines.extend(str(zi*interval) + ', ' + str(10**float(g)) + '\n' for zi, g in zip(z, G))
			Fname = os.path.join(folder, 'Plateau_' + str(Fnum).zfill(4) + '.txt')
			with open(Fname,'wb') as f:
				f.write(''.join(lines))

	def _save_histograms(self, directory, bins1, n1, bins2, n2):
		"""Write both histograms to Selected_Histograms as lines of 'bin centre, height'.

		The two files are written independently of each other, so a difference
		in length between the histograms neither truncates nor duplicates lines.
		"""
		folder = self._make_dir(directory, 'Selected_Histograms')
		histograms = (('Unfiltered_histogram.txt', bins1, n1), ('Selected_histogram.txt', bins2, n2))
		for name, bins, n in histograms:
			centres, heights = bin_centre(bins, n)
			lines = [str(b) + ', ' + str(h) + '\n' for b, h in zip(centres, heights)]
			with open(os.path.join(folder, name), 'wb') as f:
				f.write(''.join(lines))

	def _reset(self):
		"""Reset all lists for the next data set."""
		self.scans = []
		self.steps = []
		self.PlateauData = []
		self.KeepList = []
		self.prefix = ''

	def readFile(self, scanNo):
		"""Read scan number scanNo and return (index, G).

		The file is expected to hold one number per line. G is the list of
		log10 values of all positive entries whose log10 lies strictly between
		l_th and u_th; index is the matching sample index, shifted by
		synchronizeScans(). Blank lines are skipped.
		"""
		I = []
		filename = self.scans[scanNo]
		with open(os.path.join(self.prefix, filename)) as F:
			for line in F:
				if not line.strip():
					continue # e.g. an empty line at the end of the file
				value = float(line)
				if value > 0: # log10 is only defined for positive values
					logG = np.log10(value)
					if l_th < logG < u_th:
						I.append(logG)
		i = np.arange(len(I))
		i = synchronizeScans(i,I)
		return i,I


    
############################## Define FUNCTIONS ###################################### 

def average(List): # Calculates the average of a list
	"""Return the arithmetic mean of the numbers in List.

	The sum is divided by a float so that a list of integers is not subject
	to integer (floor) division under Python 2.
	Raises ZeroDivisionError when List is empty.
	"""
	return sum(List)/float(len(List))

def findgrad(x, y, diff):
	"""Return the local gradient of y with respect to x as a list with len(x) entries.

	For every point i the gradient is the slope of a linear regression through
	the samples [i-diff, i+diff). The first and last diff points do not have a
	full window and are given a gradient of 0. When x holds fewer than 2*diff
	points no gradient can be calculated and every entry is 0.
	"""
	points = len(x)
	# Start from zeros so the result has one entry per point, however short the scan is
	gradient = [0]*points
	for i in range(diff, points-diff):
		slope,intercept,r_value,p_value,std_err = stats.linregress(x[(i-diff):(i+diff)],y[(i-diff):(i+diff)])
		gradient[i] = slope
	return gradient

def bin_centre(bins, n):
	"""Return (bin centres, heights) for a histogram given its bin edges and heights.

	bins -- the bin edges (one more entry than there are bins)
	n    -- the height of every bin

	The centre of a bin is the mean of its two edges. The heights already
	belong to the bins, so they are returned unchanged (as an array) and line
	up one-to-one with the centres; averaging neighbouring heights would leave
	one height too few and shift every height by half a bin.
	"""
	bins = np.asarray(bins)
	bincenters = 0.5*(bins[1:]+bins[:-1])
	return bincenters, np.asarray(n)

def synchronizeScans(s_dat,I_dat, threshold=-0.6):
	"""Shift the displacement of a scan so that it is zero where the scan first drops to threshold.

	s_dat     -- displacement of every sample; modified in place and returned
	I_dat     -- value of every sample (the callers in this file pass log10 conductance)
	threshold -- the origin is the first sample with I_dat <= threshold. On a
	             log10 scale the default of -0.6 corresponds to 10**-0.6,
	             i.e. roughly 0.25.

	When no sample reaches the threshold (or the scan is empty) there is
	nothing to synchronise on and s_dat is returned unshifted.
	"""
	s_origin = 0 # No shift unless the threshold is crossed
	for i, value in enumerate(I_dat):
		if value <= threshold:
			s_origin = i
			break
	for i in range(len(I_dat)):
		s_dat[i] = s_dat[i] - s_origin
	return s_dat

#######################################################################################


# Path separator used when file paths are put together.
# os.sep is the separator of the platform the script runs on, so every
# non-Windows system (not only Linux) gets a forward slash.
system = platform.system()
slash = os.sep

Filter = Scan_filter()
# Ask for all data-set directories first, so the evaluation can then run unattended
Data = []
for i in range(No_Data_sets):
	chosen = tkFileDialog.askdirectory()
	# Nothing usable comes back when the dialog is cancelled: skip it instead of evaluating an empty path
	if chosen:
		# normpath rewrites the path with the separators of this platform, matching 'slash'
		Data.append(os.path.normpath(chosen))

for prefix in Data:
	Filter.start(prefix)

