Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Internal
---------
* Factor `main.py` into several files using mixins.
* Move CLI argument handling back to `main.py`.
* Clean up needless imports in `main.py`.
* Update Python versions used in CI.
* Add CI on macOS.
* Add limited CI on Windows.
Expand Down
24 changes: 15 additions & 9 deletions mycli/cli_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@

from mycli.config import str_to_bool
from mycli.constants import EMPTY_PASSWORD_FLAG_SENTINEL, ISSUES_URL, REPO_URL
from mycli.main_modes.batch import main_batch_from_stdin, main_batch_with_progress_bar, main_batch_without_progress_bar
from mycli.main_modes.checkup import main_checkup
from mycli.main_modes.execute import main_execute_from_cli
from mycli.main_modes.list_dsn import main_list_dsn
from mycli.main_modes.list_ssh_config import main_list_ssh_config
from mycli.packages.cli_utils import is_valid_connection_scheme
from mycli.packages.ssh_utils import read_ssh_config

if TYPE_CHECKING:
Expand All @@ -21,7 +27,7 @@
def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> None:
from mycli import main as main_module

cli_verbosity = main_module.preprocess_cli_args(cli_args, main_module.is_valid_connection_scheme)
cli_verbosity = main_module.preprocess_cli_args(cli_args, is_valid_connection_scheme)

mycli = client_factory(
prompt=cli_args.prompt,
Expand All @@ -38,7 +44,7 @@ def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> Non
)

if cli_args.checkup:
main_module.main_checkup(mycli)
main_checkup(mycli)
sys.exit(0)

if cli_args.csv and cli_args.format not in [None, 'csv']:
Expand Down Expand Up @@ -86,10 +92,10 @@ def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> Non
)

if cli_args.list_dsn:
sys.exit(main_module.main_list_dsn(mycli))
sys.exit(main_list_dsn(mycli))

if cli_args.list_ssh_config:
sys.exit(main_module.main_list_ssh_config(mycli, cli_args))
sys.exit(main_list_ssh_config(mycli, cli_args))

if 'MYSQL_UNIX_PORT' in os.environ:
# deprecated 2026-03
Expand Down Expand Up @@ -141,7 +147,7 @@ def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> Non
try:
dsn_uri = mycli.config["alias_dsn"][cli_args.dsn]
except KeyError:
is_valid_scheme, scheme = main_module.is_valid_connection_scheme(cli_args.dsn)
is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.dsn)
if is_valid_scheme:
dsn_uri = cli_args.dsn
else:
Expand Down Expand Up @@ -410,16 +416,16 @@ def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> Non
)

if cli_args.execute is not None:
sys.exit(main_module.main_execute_from_cli(mycli, cli_args))
sys.exit(main_execute_from_cli(mycli, cli_args))

if cli_args.batch is not None and cli_args.batch != '-' and cli_args.progress and sys.stderr.isatty():
sys.exit(main_module.main_batch_with_progress_bar(mycli, cli_args))
sys.exit(main_batch_with_progress_bar(mycli, cli_args))

if cli_args.batch is not None:
sys.exit(main_module.main_batch_without_progress_bar(mycli, cli_args))
sys.exit(main_batch_without_progress_bar(mycli, cli_args))

if not sys.stdin.isatty():
sys.exit(main_module.main_batch_from_stdin(mycli, cli_args))
sys.exit(main_batch_from_stdin(mycli, cli_args))

mycli.run_cli()
mycli.close()
41 changes: 22 additions & 19 deletions mycli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
from typing import IO, Literal

from cli_helpers.tabular_output import TabularOutputFormatter
from prompt_toolkit.formatted_text import to_formatted_text
from prompt_toolkit.shortcuts import PromptSession
import sqlparse
Expand All @@ -21,12 +22,17 @@
from mycli.client_commands import ClientCommandsMixin
from mycli.client_connection import ClientConnectionMixin
from mycli.client_query import ClientQueryMixin
from mycli.clistyle import style_factory_helpers, style_factory_ptoolkit
from mycli.completion_refresher import CompletionRefresher
from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, write_default_config
from mycli.constants import DEFAULT_PROMPT
from mycli.main_modes import repl as repl_package
from mycli.output import OutputMixin
from mycli.packages import special
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.tabular_output import sql_format
from mycli.schema_prefetcher import SchemaPrefetcher
from mycli.sqlcompleter import SQLCompleter
from mycli.sqlexecute import SQLExecute
from mycli.types import Query

Expand Down Expand Up @@ -94,16 +100,15 @@ def __init__(

# Load config.
config_files: list[str | IO[str]] = self.system_config_files + [myclirc] + [self.pwd_config_file]
from mycli import main as main_module

c = self.config = main_module.read_config_files(config_files)
c = self.config = read_config_files(config_files)
# this parallel config exists to
# * compare with my.cnf
# * support the --checkup feature
# todo: after removing my.cnf, create the parallel configs only when --checkup is set
self.config_without_package_defaults = main_module.read_config_files(config_files, ignore_package_defaults=True)
self.config_without_package_defaults = read_config_files(config_files, ignore_package_defaults=True)
# this parallel config exists to compare with my.cnf support the --checkup feature
self.config_without_user_options = main_module.read_config_files(config_files, ignore_user_options=True)
self.config_without_user_options = read_config_files(config_files, ignore_user_options=True)
self.multi_line = c["main"].as_bool("multi_line")
self.key_bindings = c["main"]["key_bindings"]
self.emacs_ttimeoutlen = c['keys'].as_float('emacs_ttimeoutlen')
Expand All @@ -120,8 +125,8 @@ def __init__(
FavoriteQueries.instance = FavoriteQueries.from_config(self.config)

self.dsn_alias: str | None = None
self.main_formatter = main_module.TabularOutputFormatter(format_name=c["main"]["table_format"])
self.redirect_formatter = main_module.TabularOutputFormatter(format_name=c["main"].get("redirect_format", "csv"))
self.main_formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
self.redirect_formatter = TabularOutputFormatter(format_name=c["main"].get("redirect_format", "csv"))
sql_format.register_new_formatter(self.main_formatter)
sql_format.register_new_formatter(self.redirect_formatter)
self.main_formatter.mycli = self
Expand All @@ -131,9 +136,9 @@ def __init__(
if cli_verbosity:
self.verbosity = cli_verbosity
self.cli_style = c["colors"]
self.ptoolkit_style = main_module.style_factory_ptoolkit(self.syntax_style, self.cli_style)
self.helpers_style = main_module.style_factory_helpers(self.syntax_style, self.cli_style)
self.helpers_warnings_style = main_module.style_factory_helpers(self.syntax_style, self.cli_style, warnings=True)
self.ptoolkit_style = style_factory_ptoolkit(self.syntax_style, self.cli_style)
self.helpers_style = style_factory_helpers(self.syntax_style, self.cli_style)
self.helpers_warnings_style = style_factory_helpers(self.syntax_style, self.cli_style, warnings=True)
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
c_dest_warning = c["main"].as_bool("destructive_warning")
self.destructive_warning = c_dest_warning if warn is None else warn
Expand All @@ -153,7 +158,7 @@ def __init__(

# Write user config if system config wasn't the last config loaded.
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
main_module.write_default_config(myclirc)
write_default_config(myclirc)

# audit log
if self.logfile is None and "audit_log" in c["main"]:
Expand All @@ -163,11 +168,11 @@ def __init__(
self.echo("Error: Unable to open the audit log file. Your queries will not be logged.", err=True, fg="red")
self.logfile = False

self.completion_refresher = main_module.CompletionRefresher()
self.completion_refresher = CompletionRefresher()
self.prefetch_schemas_mode = c["main"].get("prefetch_schemas_mode", "always") or "always"
raw_prefetch_list = c["main"].as_list("prefetch_schemas_list") if "prefetch_schemas_list" in c["main"] else []
self.prefetch_schemas_list = [s.strip() for s in raw_prefetch_list if s and s.strip()]
self.schema_prefetcher = main_module.SchemaPrefetcher(self)
self.schema_prefetcher = SchemaPrefetcher(self)

self.logger = logging.getLogger(__name__)
self.initialize_logging()
Expand All @@ -180,7 +185,7 @@ def __init__(

# Initialize completer.
self.smart_completion = c["main"].as_bool("smart_completion")
self.completer = main_module.SQLCompleter(
self.completer = SQLCompleter(
self.smart_completion, supported_formats=self.main_formatter.supported_formats, keyword_casing=keyword_casing
)
self._completer_lock = threading.Lock()
Expand All @@ -195,17 +200,17 @@ def __init__(
self.register_special_commands()

# Load .mylogin.cnf if it exists.
mylogin_cnf_path = main_module.get_mylogin_cnf_path()
mylogin_cnf_path = get_mylogin_cnf_path()
if mylogin_cnf_path:
mylogin_cnf = main_module.open_mylogin_cnf(mylogin_cnf_path)
mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
if mylogin_cnf_path and mylogin_cnf:
# .mylogin.cnf gets read last, even if defaults_file is specified.
self.cnf_files.append(mylogin_cnf)
elif mylogin_cnf_path and not mylogin_cnf:
# There was an error reading the login path file.
print("Error: Unable to read login path file.")

self.my_cnf = main_module.read_config_files(self.cnf_files, list_values=False)
self.my_cnf = read_config_files(self.cnf_files, list_values=False)
ensure_my_cnf_sections(self.my_cnf)
prompt_cnf = self.read_my_cnf(self.my_cnf, ["prompt"])["prompt"]
configure_prompt_state(self, c, prompt, prompt_cnf, toolbar_format)
Expand All @@ -220,6 +225,4 @@ def close(self) -> None:
self.sqlexecute.close()

def run_cli(self) -> None:
from mycli import main as main_module

main_module.main_repl(self)
repl_package.main_repl(self)
24 changes: 11 additions & 13 deletions mycli/client_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

import click

from mycli.main_modes.repl import set_all_external_titles
from mycli.packages import special
from mycli.packages.filepaths import dir_path_exists
from mycli.packages.interactive_utils import confirm_destructive_query
from mycli.packages.special.main import ArgType, SpecialCommandAlias
from mycli.packages.sqlresult import SQLResult
from mycli.sqlexecute import SQLExecute


class ClientCommandsMixin:
Expand Down Expand Up @@ -124,19 +128,17 @@ def change_db(self, arg: str, **_) -> Generator[SQLResult, None, None]:
click.secho("No database selected", err=True, fg="red")
return

# todo: this jump back to repl.py is a sign that separation is incomplete.
# also: it should not be needed. Don't titles update on every new prompt?
from mycli import main as main_module

assert isinstance(self.sqlexecute, main_module.SQLExecute)
assert isinstance(self.sqlexecute, SQLExecute)

if self.sqlexecute.dbname == arg:
msg = f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'
else:
self.sqlexecute.change_db(arg)
msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'

main_module.set_all_external_titles(cast(Any, self))
# todo: this jump back to repl.py is a sign that separation is incomplete.
# also: it should not be needed. Don't titles update on every new prompt?
set_all_external_titles(cast(Any, self))

yield SQLResult(status=msg)

Expand All @@ -150,13 +152,11 @@ def execute_from_file(self, arg: str, **_) -> Iterable[SQLResult]:
except IOError as e:
return [SQLResult(status=str(e))]

from mycli import main as main_module

if self.destructive_warning and main_module.confirm_destructive_query(self.destructive_keywords, query) is False:
if self.destructive_warning and confirm_destructive_query(self.destructive_keywords, query) is False:
message = "Wise choice. Command execution stopped."
return [SQLResult(status=message)]

assert isinstance(self.sqlexecute, main_module.SQLExecute)
assert isinstance(self.sqlexecute, SQLExecute)
return self.sqlexecute.run(query)

def change_prompt_format(self, arg: str, **_) -> list[SQLResult]:
Expand All @@ -182,14 +182,12 @@ def initialize_logging(self) -> None:
"DEBUG": logging.DEBUG,
}

from mycli import main as main_module

# Disable logging if value is NONE by switching to a no-op handler
# Set log level to a high value so it doesn't even waste cycles getting called.
if log_level.upper() == "NONE":
handler: logging.Handler = logging.NullHandler()
log_level = "CRITICAL"
elif main_module.dir_path_exists(log_file):
elif dir_path_exists(log_file):
handler = logging.FileHandler(log_file)
else:
self.echo(f'Error: Unable to open the log file "{log_file}".', err=True, fg="red")
Expand Down
18 changes: 10 additions & 8 deletions mycli/client_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
from typing import TYPE_CHECKING, Any

import click
import keyring
import pymysql
from pymysql.constants.CR import CR_SERVER_LOST
from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR

from mycli.compat import WIN
from mycli.config import str_to_bool
from mycli.constants import (
DEFAULT_CHARSET,
DEFAULT_HOST,
DEFAULT_PORT,
EMPTY_PASSWORD_FLAG_SENTINEL,
ER_MUST_CHANGE_PASSWORD_LOGIN,
)
from mycli.packages.filepaths import guess_socket_location
from mycli.sqlexecute import SQLExecute

try:
from pwd import getpwuid
Expand Down Expand Up @@ -61,8 +65,6 @@ def connect(
reset_keyring: bool | None = None,
keepalive_ticks: int | None = None,
) -> None:
from mycli import main as main_module

cnf = {
"database": None,
"user": None,
Expand Down Expand Up @@ -101,7 +103,7 @@ def connect(
or user_connection_config.get("default_socket")
or cnf["socket"]
or cnf["default_socket"]
or main_module.guess_socket_location()
or guess_socket_location()
)

passwd = passwd if isinstance(passwd, (str, int)) else cnf["password"]
Expand Down Expand Up @@ -133,7 +135,7 @@ def connect(
False,
):
try:
use_local_infile = main_module.str_to_bool(local_infile_option or '')
use_local_infile = str_to_bool(local_infile_option or '')
break
except (TypeError, ValueError):
pass
Expand Down Expand Up @@ -176,7 +178,7 @@ def connect(
keyring_retrieved_cleanly = False

if passwd is None and use_keyring and not reset_keyring:
passwd = main_module.keyring.get_password(keyring_domain, keyring_identifier)
passwd = keyring.get_password(keyring_domain, keyring_identifier)
if passwd is not None:
keyring_retrieved_cleanly = True

Expand Down Expand Up @@ -212,9 +214,9 @@ def _update_keyring(password: str | None, keyring_retrieved_cleanly: bool):
return
if reset_keyring or (use_keyring and not keyring_retrieved_cleanly):
try:
saved_pw = main_module.keyring.get_password(keyring_domain, keyring_identifier)
saved_pw = keyring.get_password(keyring_domain, keyring_identifier)
if password != saved_pw or reset_keyring:
main_module.keyring.set_password(keyring_domain, keyring_identifier, password)
keyring.set_password(keyring_domain, keyring_identifier, password)
click.secho(f'Password saved to the system keyring at {keyring_domain}/{keyring_identifier}', err=True)
except Exception as e:
click.secho(f'Password not saved to the system keyring: {e}', err=True, fg='red')
Expand All @@ -228,7 +230,7 @@ def _connect(
try:
if keyring_save_eligible:
_update_keyring(connection_info["password"], keyring_retrieved_cleanly=keyring_retrieved_cleanly)
self.sqlexecute = main_module.SQLExecute(**connection_info)
self.sqlexecute = SQLExecute(**connection_info)
except pymysql.OperationalError as e1:
if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto":
# if we already tried and failed to connect without SSL, raise the error
Expand Down
Loading
Loading