grrmpy.visualize.custum_viewer のソースコード

import sys
import traceback
from typing import Any,Dict,List,Optional,Union
import threading
import time
import numpy as np
from ase import Atoms
from ase.constraints import FixAtoms
from ase.io.vasp import write_vasp

from ase.io import Trajectory, write
from ase.visualize.ngl import NGLDisplay
from ipywidgets import (Button, Checkbox, Output,
                        Text, BoundedFloatText,RadioButtons,Image,ColorPicker,BoundedIntText,FloatSlider,
                        HBox,VBox,Tab,Dropdown,Layout)
from traitlets import Bunch

import nglview as nv
from nglview import NGLWidget
from nglview.component import ComponentViewer
from nglview.color import ColormakerRegistry

# USER
from grrmpy.visualize.functions import (update_tooltip_atoms,generate_js_code,
                                        get_struct,add_force_shape,rotate_view,spin_view)
from grrmpy.visualize.color import default,vesta,jmol

[ドキュメント]class View(NGLDisplay): """ Parameters: atoms: | AtomsまたはAtomsのリストまたはTrajectoryクラス xsize: | 横幅(px単位) ysize: | 縦幅(px単位) """ def __init__( self, atoms: Union[Atoms, Trajectory, List[Atoms]], xsize: int = 400, ysize: int = 500, ): super().__init__(atoms, xsize=xsize, ysize=ysize) self.v = self.gui.view # For backward compatibility... # del self.gui # デフォルトのGUIを削除 # self.gui = HBox([self.view, VBox()]) # GUIを再設定 # Make useful shortcuts for the user of the class self.gui.view = self.view self.gui.control_box = self.gui.children[1] self.gui.custom_colors = self.custom_colors ####### Property ##################################### self.replace_structure = False self._use_struct_cache = True self._struct_cache = [] self._force_components = [] self.force_color = [1, 0, 0] # Red vector for force color. self.pre_label = False if isinstance(atoms, Atoms): self._struct_cache = [None] else: # atoms is Trajectory or List[Atoms] self._struct_cache = [None for _ in range(len(atoms))] #色の設定 self.cm = ColormakerRegistry default_jscode = generate_js_code(default) vesta_jscode = generate_js_code(vesta) jmol_jscode = generate_js_code(jmol) self.cm.add_scheme_func('default',default_jscode) self.cm.add_scheme_func('vesta',vesta_jscode) self.cm.add_scheme_func('jmol',jmol_jscode) ################################################### # ---原子上にマウスを置いたときに,原子のindexと位置を表示する update_tooltip_atoms(self.view, self._get_current_atoms()) # GUI作成&表示 self.build_gui() # 初期表示 self.view.camera = "orthographic" if self.camera_style=='平行投影' else "perspective" self.view.add_spacefill() self.view.add_ball_and_stick() self.view.add_label( color="black", labelType="text", labelText=["" for _ in range(len(self._get_current_atoms()))], zOffset=2.0, attachment="middle_center", radius=1, ) self._update_repr() self.view.unobserve(NGLWidget._on_frame_changed) self.view.observe(self._on_frame_changed, names=["frame"]) @property def camera_style(self): return self.gui.camera_radio_btn.value @property def show_charge(self): return self.gui.show_charge_checkbox.value @property def show_force(self): return self.gui.show_force_checkbox.value @property def transparent(self): return self.gui.transparent.value @property def factor(self): return self.gui.factor.value def build_gui(self): ###カメラ(RadioButton)### self.gui.camera_radio_btn = RadioButtons( options=['平行投影','透視投影'], value='平行投影', description='カメラ') self.gui.camera_radio_btn.observe(self.change_camera) ###ラベル(RadioButton)### self.gui.label_radio_btn = RadioButtons( options=['なし','インデックス','元素','電荷','FixAtoms'], value='なし', description='ラベル') self.gui.label_radio_btn.observe(self.change_label) ###セル(CheckBox)#### self.gui.cell_check_box = Checkbox(value=True,description="セルユニット",) self.gui.cell_check_box.observe(self.show_unitcell) ###カラースキーム### self.csel = Dropdown(options=["default","vesta","element","jmol"], value='element', description='色') self.csel.observe(self._update_repr) ###モデル### self.gui.model_radio_btn = RadioButtons( options=['球棒モデル','空間充填モデル'], value='球棒モデル', description='モデル') self.gui.model_radio_btn.observe(self._update_repr) ###再配置(チェックボックス)### self.gui.replace_structure_checkbox = Checkbox( value=self.replace_structure, description="再配置") self.gui.replace_structure_checkbox.observe(self.change_replace_structure) ###アウトプット(エラー表示)### self.gui.out_widget = Output(layout={"border": "0px solid black"}) ##IMG## self.gui.a = Button(description='A',tooltip='A軸方向から見る',layout = Layout(width='30px')) self.gui.a.on_click(lambda e,x=180,y=-90,z=90: self.rotate_view(e,x,y,z)) self.gui.b = Button(description='B',tooltip='B軸方向から見る',layout = Layout(width='30px')) self.gui.b.on_click(lambda e,x=90,y=0,z=-90: self.rotate_view(e,x,y,z)) self.gui.c = Button(description='C',tooltip='C軸方向から見る',layout = Layout(width='30px')) self.gui.c.on_click(lambda e,x=180,y=0,z=0: self.rotate_view(e,x,y,z)) self.gui.center = Button(tooltip='中心',layout = Layout(width='30px'),icon='bullseye') self.gui.center.on_click(lambda e,:self.center(e)) # self.gui.up = Button(description='x',tooltip=' x軸周りで回転(上)',layout = Layout(width='30px')) # self.gui.up.on_click(lambda e,axis=0,angle=10: self.spin_view(e,axis,angle)) # self.gui.down = Button(description='x*',tooltip='x軸周りで回転(下))',layout = Layout(width='30px')) # self.gui.left = Button(description='y*',tooltip='y軸周りで回転(左)',layout = Layout(width='30px')) # self.gui.right = Button(description='y',tooltip='y軸周りで回転(右)',layout = Layout(width='30px')) # self.gui.cc = Button(description='z',tooltip='z軸周りで回転(時計周り)',layout = Layout(width='30px')) # self.gui.rc = Button(description='z*',tooltip='z軸周りで回転(反時計周り)',layout = Layout(width='30px')) # self.gui.step = BoundedFloatText(value=10,min=0,max=360,step=10, # description='Step(*):', # layout = Layout(width='150px')) # ファイル名 self.gui.transparent = Checkbox(value=True,description="PIG保存時に透過背景") if isinstance(self.atoms, Atoms): self.gui.filename_text = Text(value="Atoms", description="ファイル名: ",layout = Layout(width='200px')) self.extention_dict = {0:'traj', 1:'cif', 2:'xyz', 3:'html', 4:"vasp", 5:'png'} self.gui.file_extention = Dropdown(options=[(val,key) for key,val in self.extention_dict.items()],value=1,layout = Layout(width='70px')) else:# list of Atoms or Traj self.gui.filename_text = Text(value="Images", description="ファイル名: ",layout = Layout(width='200px')) self.extention_dict = {0:'traj',1:'traj+',2:'cif',3:'cif+',4:'xyz',5:'xyz+',6:'html',7:'html+',8:"vasp",9:'png'} self.gui.file_extention = Dropdown(options=[(val,key) for key,val in self.extention_dict.items()],value=1,layout = Layout(width='70px')) ##ダウンロード## self.gui.download = Button(description='PNGをダウンロード', tooltip='ローカルPCにPNGをダウンロードする') self.gui.download.on_click(self.download_image) ##保存## self.gui.save = Button(description='ファイルへ保存') self.gui.save.on_click(self.save_image) # 電荷表示 self.gui.show_charge_checkbox = Checkbox(value=False,description="電荷:色",) self.gui.show_charge_checkbox.observe(self.show_charge_event) self.gui.charge_scale_slider = BoundedFloatText( value=1.0, min=0.0, max=100.0, step=0.1, description="電荷:スケール",style = {'description_width': 'initial'}) self.gui.charge_scale_slider.observe(self.show_charge_event) self.gui.show_force_checkbox = Checkbox(value=False,description="力",) self.gui.show_force_checkbox.observe(self.show_force_event) self.gui.force_scale_slider = FloatSlider( value=0.5, min=0.0, max=100.0, step=0.1, description="力:スケール") self.gui.force_scale_slider.observe(self.show_force_event) # その他 self.gui.color_picker = ColorPicker(concise=False,description='ラベルカラー:',value='black', layout = Layout(width='200px'),style = {'description_width': 'initial'}) self.gui.color_picker.observe(self.update_label) self.gui.label_size = BoundedFloatText(value=1.0,min=0,max=3.0,step=0.1,description='ラベルサイズ:', layout = Layout(width='200px'),style = {'description_width': 'initial'}) self.gui.label_size.observe(self.update_label) self.gui.charge_round = BoundedIntText(value=2,min=0,max=10,step=1,description='電荷 小数点以下:', layout = Layout(width='200px'),style = {'description_width': 'initial'}) self.gui.charge_round.observe(self.change_label) self.gui.factor = BoundedIntText(value=4,min=0,max=30,step=1,description='PIGの解像度:', layout = Layout(width='200px'),style = {'description_width': 'initial'}) ###表示#### # r = list(self.gui.control_box.children) img1 = HBox([ self.gui.a, self.gui.b, self.gui.c, self.gui.center, # self.gui.step # 回転の中心を変える方法が分からないので非表示 ]) # img2 = HBox([ # self.gui.up, # self.gui.down, # self.gui.right, # self.gui.left, # self.gui.cc, # self.gui.rc, # ]) general = VBox([ img1, # img2, # 回転の中心を変える方法が分からないので非表示 self.gui.transparent, HBox([self.gui.filename_text,self.gui.file_extention]), HBox([self.gui.download,self.gui.save]), self.gui.label_radio_btn, self.gui.show_charge_checkbox, self.gui.charge_scale_slider, self.gui.show_force_checkbox, self.gui.force_scale_slider, ]) other = VBox([ self.csel, self.rad, self.gui.model_radio_btn, self.gui.camera_radio_btn, self.gui.cell_check_box, ]) detail = VBox([ self.gui.color_picker, self.gui.label_size, self.gui.charge_round, self.gui.factor, ]) self.tab = Tab([general,other,detail],_titles={0:"プロパティなど", 1:"スタイル",2:"その他"}) self.gui.control_box.children = tuple([self.tab, self.gui.replace_structure_checkbox, self.gui.out_widget]) def save_image(self,e=None): name = self.gui.filename_text.value atoms = self._get_current_atoms() # 現在表示中のAtoms if self.extention_dict[self.gui.file_extention.value] == "traj": write(f"{name}.traj",atoms,format="traj") elif self.extention_dict[self.gui.file_extention.value] == "traj+": write(f"{name}.traj",self.atoms,format="traj") elif self.extention_dict[self.gui.file_extention.value] == "cif": write(f"{name}.cif",atoms,format="cif") elif self.extention_dict[self.gui.file_extention.value] == "cif+": write(f"{name}.cif",self.atoms,format="cif") elif self.extention_dict[self.gui.file_extention.value] == "xyz": write(f"{name}.xyz",atoms) elif self.extention_dict[self.gui.file_extention.value] == "xyz+": write(f"{name}.xyz",self.atoms) elif self.extention_dict[self.gui.file_extention.value] == "html": nv.write_html(f"{name}.html",self.view,tuple([self.view.frame])) elif self.extention_dict[self.gui.file_extention.value] == "html+": nv.write_html(f"{name}.html",self.view,(0,len(self.atoms)-1)) elif self.extention_dict[self.gui.file_extention.value] == "vasp": write_vasp(f"{name}",atoms,sort=True,wrap=True,direct=True) elif self.extention_dict[self.gui.file_extention.value] == "png": thread = threading.Thread(target=self._save_image_png, args=(f"{name}.png", self.view), daemon=True) thread.start() def download_image(self,e=None): try: filename = self.gui.filename_text.value self.view.download_image(filename=filename,transparent=self.transparent,factor=self.factor) except Exception as e: with self.gui.out_widget: print(traceback.format_exc(), file=sys.stderr) def _save_image_png(self,filename: str, v: NGLWidget): try: image = v.render_image(transparent=self.transparent,factor=self.factor) except Exception as e: with self.gui.out_widget: print(traceback.format_exc(), file=sys.stderr) while not image.value: time.sleep(0.1) with open(filename, "wb") as fh: fh.write(image.value) def rotate_view(self,e,x,y,z): rotate_view(self.view,x=x,y=y,z=z) self.view.center() def spin_view(self,e, axis, angle): spin_view(self.view, axis, angle) self.view.center() def center(self,e=None): self.view.center() def update_label(self,e=None): self.view.update_label( color=self.gui.color_picker.value, labelType="text", labelText=self.labelText, zOffset=2.0, attachment="middle_center", radius=self.gui.label_size.value, ) def _change_label(self,atoms,option): if option == "インデックス": self.labelText=[str(i) for i in range(len(atoms))] elif option == "元素": self.labelText=[i for i in atoms.get_chemical_symbols()] elif option == "電荷": try: self.labelText = np.round(atoms.get_charges().ravel(), self.gui.charge_round.value).astype("str").tolist() except: self.labelText=["" for _ in range(len(atoms))] with self.gui.out_widget: raise Exception("Calculatorを設定してください") elif option == "FixAtoms": self.labelText = self._get_fix_atoms_label_text(atoms) self.update_label() def change_label(self,e=None): self.gui.out_widget.clear_output() option = self.gui.label_radio_btn.value atoms = self._get_current_atoms() if option == "なし": self.labelText=["" for _ in range(len(atoms))] self.update_label() return else: self._change_label(atoms,option) return def change_camera(self,e=None): option = self.gui.camera_radio_btn.value if option == "平行投影": self.view.camera = 'orthographic' elif option == "透視投影": self.view.camera = 'perspective' def _update_repr(self,e=None): option = self.gui.model_radio_btn.value if option == "球棒モデル": self.view.update_spacefill(radiusType='covalent', radiusScale=self.rad.value, color_scheme=self.csel.value)#color_scale='rainbow') self.view.update_ball_and_stick(color_scheme=self.csel.value) elif option == "空間充填モデル": self.view.remove_spacefill() self.view.add_spacefill() self.view.update_spacefill(radiusType="vwf",color_scheme=self.csel.value) def show_unitcell(self,e=None): if self.gui.cell_check_box.value: self.view.add_unitcell() # Cellの表示 else: self.view.remove_unitcell() def _get_current_atoms(self) -> Atoms: if isinstance(self.atoms, Atoms): return self.atoms else: return self.atoms[self.view.frame] def _get_fix_atoms_label_text(self,atoms): indices_list = [] for constraint in atoms.constraints: if isinstance(constraint, FixAtoms): indices_list.extend(constraint.index.tolist()) label_text = [] for i in range(len(atoms)): if i in indices_list: label_text.append("Fix") else: label_text.append("") return label_text def change_replace_structure(self,event: Optional[Bunch] = None): if self.gui.replace_structure_checkbox.value: self.replace_structure = True self._on_frame_changed(None) else: self.replace_structure = False def _on_frame_changed(self, change: Dict[str, Any]): """set and send coordinates at current frame""" v: NGLWidget = self.view atoms: Atoms = self._get_current_atoms() self.clear_force() if self.replace_structure: # set and send coordinates at current frame struct = self._struct_cache[v.frame] if struct is None: struct = get_struct(atoms) if self._use_struct_cache: self._struct_cache[v.frame] = struct # Cache v._remote_call("replaceStructure", target="Widget", args=struct) else: # Only update position info v._set_coordinates(v.frame) if self.show_force: self.add_force() if self.show_charge: self.show_charge_event() # Tooltip: update `var atoms_pos` inside javascript. atoms = self._get_current_atoms() if atoms.get_pbc().any(): _, Q = atoms.cell.standard_form() else: Q = np.eye(3) Q_str = str(Q.tolist()) var_str = f"this._Q = {Q_str}" v._execute_js_code(var_str) ###ラベルの更新 option = self.gui.label_radio_btn.value if option != "なし": atoms = self._get_current_atoms() self._change_label(atoms,option) def _ipython_display_(self, **kwargs): """viewプロパティを書かなくてもjupyter上で勝手に表示してくれる""" return self.gui._ipython_display_(**kwargs) def show_charge_event(self, event: Optional[Bunch] = None, refresh: bool = True): self.gui.out_widget.clear_output() if self.show_charge: atoms = self._get_current_atoms() # TODO: How to change `scale` and `radiusScale` by user? # Register "atomStore.partialCharge" attribute inside javascript charge_scale: float = self.gui.charge_scale_slider.value # Note that Calculator must be set here! try: charge_str = str((atoms.get_charges().ravel() * charge_scale).tolist()) except Exception as e: with self.gui.out_widget: print(traceback.format_exc(), file=sys.stderr) # `append_stderr` method shows same text twice somehow... # self.gui.out_widget.append_stderr(str(e)) return var_code = f"var chargeArray = {charge_str}" js_code = """ var component = this.stage.compList[0] var atomStore = component.structure.atomStore if (atomStore.partialCharge === undefined) { atomStore.addField('partialCharge', 1, 'float32') } for (let i = 0; i < chargeArray.length; ++i) { atomStore.partialCharge[i] = chargeArray[i]; } """ self.view._execute_js_code(var_code + js_code) # Show charge color # TODO: More efficient way: # We must set other "color_scheme" at first, to update "partialcharge" color scheme... # color_schme="element" is chosen here, but any color_scheme except "partialcharge" is ok. # Skip this procedure to avoid heavy process, user must turn on and off "show charge" now. if refresh: self.view._update_representations_by_name( "spacefill", radiusType="covalent", radiusScale=self.rad.value, color_scheme="element", color_scale="rwb", ) self.view._update_representations_by_name( "spacefill", radiusType="covalent", radiusScale=self.rad.value, color_scheme="partialcharge", color_scale="rwb", ) else: # Revert to original color scheme. self._update_repr() def show_force_event(self, event: Optional[Bunch] = None): self.gui.out_widget.clear_output() self.clear_force() if self.show_force: self.add_force() def add_force(self): force_scale: float = self.gui.force_scale_slider.value try: atoms = self._get_current_atoms() c = add_force_shape(atoms, self.v, force_scale, self.force_color) self._force_components.append(c) except Exception as e: with self.gui.out_widget: print(traceback.format_exc(), file=sys.stderr) # `append_stderr` method shows same text twice somehow... # self.gui.out_widget.append_stderr(str(e)) return def clear_force(self): # Remove existing force components. for c in self._force_components: self.v.remove_component(c) # Same with c.clear() self._force_components = []
def add_force_shape( atoms: Atoms, view: NGLWidget, force_scale: float = 0.5, force_color: Optional[List[int]] = None, ) -> ComponentViewer: if force_color is None: force_color = [1, 0, 0] # Defaults to red color. # Add force components forces = atoms.get_forces() pos = atoms.positions if atoms.get_pbc().any(): rcell, rot_t = atoms.cell.standard_form() rot = rot_t.T pos_frac = pos.dot(rot) force_frac = forces.dot(rot) else: pos_frac = pos force_frac = forces shapes = [] for i in range(atoms.get_global_number_of_atoms()): pos1 = pos_frac[i] pos2 = pos1 + force_frac[i] * force_scale pos1_list = pos1.tolist() pos2_list = pos2.tolist() shapes.append(("arrow", pos1_list, pos2_list, force_color, 0.2)) c = view._add_shape(shapes, name="Force") return c