In [1]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
import ipywidgets as widgets
from datetime import datetime

# Constants (SI units)
h = 6.626e-34   # Planck constant [J s]
m_e = 9.11e-31  # free-electron rest mass [kg]


def _material(bandgap, m_n, m_p):
    """Build one material record with the valence-band edge at 0 eV.

    ``bandgap`` is in eV; ``m_n``/``m_p`` are effective masses in units of
    ``m_e`` and are stored in kg. ``E_c`` is derived from the gap so the two
    values can never disagree.
    """
    E_v = 0.0
    return {
        "bandgap": bandgap,
        "E_v": E_v,
        "E_c": E_v + bandgap,
        "m_n": m_n * m_e,
        "m_p": m_p * m_e,
    }


# Material categories.
# Masses are density-of-states effective masses, which is what the parabolic
# DoS formula needs. Si previously used its conductivity masses (0.26/0.39),
# inconsistent with e.g. Ge's DoS mass of 0.55; it now uses the DoS values
# 1.08/0.81 (Sze & Ng, 300 K).
semiconductors = {
    #                 bandgap [eV], m_n/m_e, m_p/m_e
    "Si": _material(1.12, 1.08, 0.81),
    "GaAs": _material(1.43, 0.067, 0.48),
    "Ge": _material(0.66, 0.55, 0.37),
    "InP": _material(1.35, 0.08, 0.6),
    "CdTe": _material(1.5, 0.11, 0.35),
    "ZnSe": _material(2.7, 0.16, 0.6),
    "GaN": _material(3.4, 0.2, 0.6),
    "AlN": _material(6.2, 0.4, 0.8),
    "InAs": _material(0.36, 0.023, 0.41),
    "InSb": _material(0.17, 0.014, 0.45),
    "ZnO": _material(3.37, 0.24, 0.59),
    "CdS": _material(2.42, 0.21, 0.8),
}

# NOTE(review): every insulator uses 0.5 m_e for both carriers, which looks
# like a placeholder rather than measured data — confirm before relying on
# absolute DoS magnitudes.
insulators = {
    "SiO2": _material(9.0, 0.5, 0.5),
    "Al2O3": _material(8.8, 0.5, 0.5),
    "HfO2": _material(5.6, 0.5, 0.5),
}

# Zero-gap placeholder: both band edges sit at 0 eV with free-electron masses.
metals = {
    "Metals": _material(0.0, 1.0, 1.0),
}

all_materials = {**semiconductors, **insulators, **metals}

# Dropdown label -> material table shown under that filter.
categories = {
    "Semiconductors": semiconductors,
    "Insulators": insulators,
    "Metals": metals,
}

# Generate DoS
def D(E_v, E_c, m_n, m_p, *, span=1.0, n_points=1000):
    """Return parabolic-band densities of states in cm^-3 eV^-1.

    Uses g(E) = 8*pi*m*sqrt(2*m*|E - E_edge|) / h**3. The formula is
    evaluated in SI units (energy in joules) and then converted, so the
    values match the ``cm^-3 eV^-1`` axis label used by the plot. Previously
    eV values were fed straight into the SI formula and the raw m^-3 J^-1
    result was returned, so magnitudes were wrong by ~35 orders.

    Parameters
    ----------
    E_v, E_c : float
        Valence-band maximum and conduction-band minimum [eV].
    m_n, m_p : float
        Electron and hole effective masses [kg].
    span : float, keyword-only
        Energy range [eV] sampled above ``E_c`` and below ``E_v`` (default 1).
    n_points : int, keyword-only
        Number of samples per band (default 1000).

    Returns
    -------
    D_c, D_v, E_above, E_below : numpy.ndarray
        Conduction/valence DoS [cm^-3 eV^-1] and their energy grids [eV].
        ``E_above`` runs from ``E_c`` to ``E_c + span``; ``E_below`` runs from
        ``E_v - span`` to ``E_v``.
    """
    q = 1.602176634e-19   # elementary charge = joules per eV
    per_cm3 = 1e-6        # 1 m^-3 = 1e-6 cm^-3

    def _band_dos(mass, delta_e_ev):
        # sqrt term needs energy in J; the SI result is per m^3 per J.
        g_si = 8 * np.pi * mass * np.sqrt(2 * mass * delta_e_ev * q) / h**3
        return g_si * q * per_cm3  # -> per cm^3 per eV

    E_above = np.linspace(E_c, E_c + span, n_points)
    E_below = np.linspace(E_v - span, E_v, n_points)
    D_c = _band_dos(m_n, E_above - E_c)
    D_v = _band_dos(m_p, E_v - E_below)
    return D_c, D_v, E_above, E_below

# Store all checkboxes (one per material, across every category).
# The first three semiconductors in table order start ticked so the first
# plot is not empty. The set is built once instead of rebuilding the key
# list for every material.
_default_checked = set(list(semiconductors)[:3])
checkboxes = {
    material: widgets.Checkbox(value=material in _default_checked, description=material)
    for material in all_materials
}

# UI elements
category_dropdown = widgets.Dropdown(
    options=list(categories.keys()),
    value="Semiconductors",
    description="Filter:",
)

plot_button = widgets.Button(description="Plot Selected")
clear_button = widgets.Button(description="Clear Output", button_style='danger')
# Holds the checkboxes of the currently selected category.
checkbox_container = widgets.VBox()

def update_checkbox_display(category):
    """Show in ``checkbox_container`` only the checkboxes of *category*.

    The Checkbox objects themselves are reused, so a material that is hidden
    by the filter keeps its ticked state.
    """
    visible_table = categories[category]
    checkbox_container.children = [checkboxes[name] for name in visible_table]


# Populate the container once for the dropdown's starting category.
update_checkbox_display(category_dropdown.value)

# Plot handler
def on_plot_click(b):
    """Redraw the DoS plot for every ticked material.

    Bound to ``plot_button.on_click``; *b* is the clicked button (unused).
    Ticked materials from all categories are plotted, not only those visible
    under the current filter. The figure is stored in the global ``last_fig``
    so the export buttons can save it.
    """
    global last_fig

    clear_output(wait=True)
    display(ui)  # clear_output removed the widgets as well; show them again
    # NOTE(review): on some front ends (e.g. JupyterLab) output produced inside
    # widget callbacks may not appear in the cell; routing it through a
    # widgets.Output would make this reliable — confirm on the target setup.

    selected = [mat for mat, cb in checkboxes.items() if cb.value]

    fig, ax = plt.subplots()

    # y-tick labels keyed by energy, so levels that coincide (e.g. a metal's
    # E_c at 0 eV on top of E_v) get one merged label instead of two
    # overlapping ones.
    tick_labels = {0.0: [r'$E_v$']}
    for mat in selected:
        props = all_materials[mat]
        D_c, D_v, E_above, E_below = D(
            props["E_v"], props["E_c"], props["m_n"], props["m_p"]
        )
        # Let the public colour cycle choose the colour, then reuse it for the
        # valence branch and the E_c marker (avoids private ax._get_lines).
        (conduction_line,) = ax.plot(D_c, E_above, label=mat)
        color = conduction_line.get_color()
        ax.plot(D_v, E_below, color=color)
        ax.axhline(props["E_c"], color=color, linestyle='--', alpha=0.5)
        tick_labels.setdefault(props["E_c"], []).append(f'$E_c$ ({mat})')

    ax.set_yticks(list(tick_labels))
    ax.set_yticklabels([", ".join(labels) for labels in tick_labels.values()])
    ax.set_xlabel(r'D(E) [cm$^{-3}$ eV$^{-1}$]')
    ax.set_ylabel('E [eV]')
    ax.axhline(0, color='black', linestyle='--', alpha=0.5, label=r'$E_v$')
    ax.legend(borderpad=0.5, framealpha=1, edgecolor='#000', borderaxespad=0,
              fancybox=False, loc='upper right')
    ax.grid(True)
    # Fix the axis limits before laying out, so tight_layout sees final ticks.
    ax.margins(x=0.0, y=0.0, tight=True)
    fig.tight_layout()

    # Keep the figure for the export handlers.
    last_fig = fig

    plt.show()

# Export handlers
def on_export(format):
    """Save the most recent plot as ``DoS_plot_<YYYYmmdd_HHMMSS>.<format>``.

    *format* is forwarded to ``Figure.savefig`` (the buttons pass "png" and
    "pdf"). When nothing has been plotted yet, a hint is printed instead of
    silently doing nothing.
    """
    if last_fig is None:
        print("⚠️ Nothing to export yet: click 'Plot Selected' first.")
        return
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"DoS_plot_{timestamp}.{format}"
    last_fig.savefig(filename, format=format)
    print(f"✅ Exported as: {filename}")

def _exporter(fmt):
    """Return a button callback that exports the last figure as *fmt*."""
    return lambda _button: on_export(fmt)


# One export button per output format.
export_png_btn = widgets.Button(description="🖼️ Export PNG")
export_pdf_btn = widgets.Button(description="📄 Export PDF")
export_png_btn.on_click(_exporter("png"))
export_pdf_btn.on_click(_exporter("pdf"))

# Clear handler
def on_clear_click(b):
    """Wipe the cell output (plots, export messages) but keep the controls.

    Bound to ``clear_button.on_click``; *b* is the clicked button (unused).
    """
    clear_output(wait=True)
    # clear_output also removed the widget panel, so put it back.
    display(ui)

# Most recently plotted figure; set by on_plot_click and read by on_export.
last_fig = None

# Event wiring: the dropdown refilters the checkboxes, and buttons run handlers.
category_dropdown.observe(
    lambda change: update_checkbox_display(change['new']),
    names='value',
)
plot_button.on_click(on_plot_click)
clear_button.on_click(on_clear_click)

# Layout: filter dropdown, filtered checkboxes, then a single row of buttons.
button_row = widgets.HBox([plot_button, clear_button, export_png_btn, export_pdf_btn])
ui = widgets.VBox([category_dropdown, checkbox_container, button_row])

display(ui)


VBox(children=(Dropdown(description='Filter:', options=('Semiconductors', 'Insulators', 'Metals'), value='Semi…