Source code for lib.style

from rich import print as rprint
import plotly.graph_objects as go
import plotly.express as px

def format_coustom_plotly(
    fig,
    title=None,
    legend=None,
    fontsize=16,
    figsize=None,
    ranges=(None, None),
    matches=("x", "y"),
    tickformat=(".s", ".s"),
    log=(False, False),
    margin=None,
    add_units=False,
    debug=False,
):
    """
    Format a plotly figure

    Applies the "presentation" template, font size, axis styling, figure
    size, optional log axes, optional explicit margins and optional unit
    suffixes on the axis titles. The figure is modified in place and also
    returned. The ``legend`` and ``margin`` arguments are never mutated.

    Args:
        fig (plotly.graph_objects.Figure): plotly figure
        title (str): title of the figure (default: None)
        legend (dict): legend options; None is treated as an empty dict (default: None)
        fontsize (int): font size (default: 16)
        figsize (tuple): figure size as (width, height); when None it is derived
            from the detected subplot grid as
            (800 + 400 * (cols - 1), 600 + 200 * (rows - 1)) (default: None)
        ranges (tuple): (x, y) axis ranges; a None entry leaves that axis range untouched (default: (None,None))
        matches (tuple): (x, y) axis matches (default: ("x","y"))
        tickformat (tuple): (x, y) axis tick format (default: ('.s','.s'))
        log (tuple): (x, y) axis log scale (default: (False,False))
        margin (dict): figure margin options. Missing keys (or None) fall back to
            {"auto":True,"color":"white","margin":(0,0,0,0)}. "color" is used as the
            paper background colour; when "auto" is falsy, "margin" is applied as
            (left, right, top, bottom). (default: None)
        add_units (bool): True to add units to axis labels, False otherwise (default: False)
        debug (bool): True to print debug statements, False otherwise (default: False)

    Returns:
        fig (plotly.graph_objects.Figure): plotly figure
    """
    # Find the number of subplots
    if isinstance(fig, go.Figure):
        try:
            rows, cols = fig._get_subplot_rows_columns()
            rows, cols = rows[-1], cols[-1]
        except Exception:
            # NOTE(review): this branch is taken whenever the private plotly
            # helper raises, which may include ordinary single-plot figures;
            # the message text is kept as-is to preserve existing output.
            rows, cols = 1, 1
            rprint("[red]Error: unknown figure type[/red]")
    else:
        rows, cols = 1, 1
        rprint("[red]Error: unknown figure type[/red]")

    if debug:
        rprint("[blue]Detected number of subplots: " + str(rows * cols) + "[/blue]")

    if figsize is None:
        figsize = (800 + 400 * (cols - 1), 600 + 200 * (rows - 1))

    # Merge the user's margin options over the defaults in a fresh dict so
    # neither the caller's dict nor a shared default is ever modified.
    margin_cfg = {"auto": True, "color": "white", "margin": (0, 0, 0, 0)}
    if margin is not None:
        margin_cfg.update(margin)

    legend_cfg = {} if legend is None else legend

    # font size and template
    fig.update_layout(
        title=title,
        legend=legend_cfg,
        template="presentation",
        font=dict(size=fontsize),
        paper_bgcolor=margin_cfg["color"],
    )

    # tickformat=",.1s" for scientific notation
    fig.update_xaxes(
        matches=matches[0],
        showline=True,
        mirror="ticks",
        showgrid=True,
        minor_ticks="inside",
        tickformat=tickformat[0],
    )

    # Identity checks: "!= None" is ambiguous for array-like ranges.
    if ranges[0] is not None:
        fig.update_xaxes(range=ranges[0])
    if ranges[1] is not None:
        fig.update_yaxes(range=ranges[1])

    # tickformat=",.1s" for scientific notation
    fig.update_yaxes(
        matches=matches[1],
        showline=True,
        mirror="ticks",
        showgrid=True,
        minor_ticks="inside",
        tickformat=tickformat[1],
    )

    # figsize is always set at this point (explicit or derived above).
    fig.update_layout(width=figsize[0], height=figsize[1])

    if log[0]:
        fig.update_xaxes(type="log", tickmode="linear")
    if log[1]:
        fig.update_yaxes(type="log", tickmode="linear")

    if not margin_cfg["auto"]:
        left, right, top, bottom = margin_cfg["margin"][:4]
        fig.update_layout(margin=dict(l=left, r=right, t=top, b=bottom))

    # Update axis labels to include units
    if add_units:
        axis_updaters = (
            ("xaxis", fig.update_xaxes),
            ("yaxis", fig.update_yaxes),
        )
        for axis_name, update_axes in axis_updaters:
            try:
                label = getattr(fig.layout, axis_name).title.text
            except AttributeError:
                # Best effort: no such axis/title on this object.
                continue
            if label is None:
                # Untitled axis: nothing to append units to.
                continue
            update_axes(title_text=label + get_units(label, debug=debug))

    return fig
def get_units(var, debug=False):
    """
    Look up the unit suffix for a variable from the end of its name

    The name is compared against every known suffix ("T" -> " (s) ",
    "Amplitude" -> " (mV) "). All suffixes are checked, so if more than one
    matches, the last matching entry is the one returned.

    Args:
        var (str): variable name
        debug (bool): True to print each check and each match, False otherwise (default: False)

    Returns:
        unit (str): unit string including surrounding spaces, or "" when no suffix matches
    """
    suffix_to_unit = {
        "T": " (s) ",
        "Amplitude": " (mV) ",
    }

    found = ""
    for suffix, label in suffix_to_unit.items():
        if debug:
            print("Checking for " + suffix + " in " + var)
        if not var.endswith(suffix):
            continue
        found = label
        if debug:
            print("Unit found for " + var)
    return found