| # Copyright (C) 2021-2022 Apple Inc. All rights reserved. |
| # |
| # Redistribution and use in source and binary forms, with or without |
| # modification, are permitted provided that the following conditions |
| # are met: |
| # 1. Redistributions of source code must retain the above copyright |
| # notice, this list of conditions and the following disclaimer. |
| # 2. Redistributions in binary form must reproduce the above copyright |
| # notice, this list of conditions and the following disclaimer in the |
| # documentation and/or other materials provided with the distribution. |
| # |
| # THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS "AS IS" AND |
| # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| # DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS BE LIABLE FOR |
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
| import os |
| import re |
| import sys |
| |
| from .command import Command |
| from .branch import Branch |
| from .squash import Squash |
| |
| from webkitbugspy import Tracker |
| from webkitcorepy import arguments, run, Terminal |
| from webkitscmpy import local, log, remote |
| |
| |
class PullRequest(Command):
    """Push the current checkout state as a pull-request.

    The command stages and commits local changes when needed, optionally
    squashes and rebases the development branch, pushes it to a remote (the
    matching fork remote when the remote is GitHub) and then creates or
    updates the pull-request for that branch.
    """

    name = 'pull-request'
    aliases = ['pr', 'pfr', 'upload']
    help = 'Push the current checkout state as a pull-request'
    # Stripped from an existing pull-request on upload only when `unblock` is set
    # (presumably it marks a PR as blocked from merging).
    BLOCKED_LABEL = 'merging-blocked'
    # Merge-queue labels; always stripped from an existing pull-request when it is re-uploaded.
    MERGE_LABELS = ['merge-queue']
    UNSAFE_MERGE_LABELS = ['unsafe-merge-queue']

    @classmethod
    def parser(cls, parser, loggers=None):
        """Register the pull-request options (plus Branch and Squash options) on `parser`."""
        Branch.parser(parser, loggers=loggers)
        Squash.parser(parser, loggers=loggers)
        parser.add_argument(
            '--add', '--no-add',
            dest='will_add', default=None,
            help='When drafting a change, add (or never add) modified files to set of staged changes to be committed',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--rebase', '--no-rebase', '--update', '--no-update',
            dest='rebase', default=None,
            help='Rebase (or do not rebase) the pull-request on the source branch before pushing',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--squash', '--no-squash',
            dest='squash', default=None,
            help='Combine all commits on the current development branch into a single commit before pushing',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--defaults', '--no-defaults', action=arguments.NoAction, default=None,
            help='Do not prompt the user for defaults, always use (or do not use) them',
        )
        parser.add_argument(
            '--overwrite', '--amend', action='store_const', const='overwrite',
            dest='technique', default=None,
            help='When creating a pull request, overwrite the existing commit by default',
        )
        parser.add_argument(
            '--append', action='store_const', const='append',
            dest='technique', default=None,
            help='When creating a pull request, append a new commit on the existing branch by default',
        )
        parser.add_argument(
            '--with-history', '--no-history',
            dest='history', default=None,
            help='Create numbered branches to track the history of a change',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--draft', dest='draft', action='store_true', default=None,
            help='Mark a pull request as a draft when creating it',
        )
        parser.add_argument(
            '--remote', dest='remote', type=str, default=None,
            help='Make a pull request against a specific remote',
        )
        parser.add_argument(
            '--checks', '--no-checks',
            dest='checks', default=None,
            help='Explicitly enable or disable automatic pre-flight checks',
            action=arguments.NoAction,
        )

    @classmethod
    def create_commit(cls, args, repository, **kwargs):
        """Make sure the current branch has a commit holding the local changes.

        Modified files are staged according to ``--add``/``--no-add``. If nothing
        new is staged and the branch already has its own commit, that commit is
        reused. Otherwise a commit is created, or amended when the technique is
        ``overwrite`` and a branch commit exists.

        Returns 0 on success and 1 on failure.
        """
        # First, find the set of files to be modified
        modified = [] if args.will_add is False else repository.modified()
        if args.will_add:
            modified = list(set(modified).union(set(repository.modified(staged=False))))

        # Next, add every modified file which is not already staged
        for file in set(modified) - set(repository.modified(staged=True)):
            log.info('    Adding {}...'.format(file))
            if run([repository.executable(), 'add', file], cwd=repository.root_path).returncode:
                sys.stderr.write("Failed to add '{}'\n".format(file))
                return 1

        # Then, see if we already have a commit associated with this branch we need to modify
        has_commit = repository.commit(include_log=False, include_identifier=False).branch == repository.branch and repository.branch != repository.default_branch
        if not modified and has_commit:
            log.info('Using committed changes...')
            return 0

        bug_urls = getattr(args, '_bug_urls', None) or ''
        if isinstance(bug_urls, (list, tuple)):
            bug_urls = '\n'.join(bug_urls)

        # Otherwise, we need to create a commit
        will_amend = has_commit and args.technique == 'overwrite'
        if not modified:
            sys.stderr.write('No modified files\n')
            return 1
        log.info('Amending commit...' if will_amend else 'Creating commit...')

        # Work on a copy so the commit-message variables are only visible to this
        # `git commit` invocation, not to this process and every later subprocess.
        env = dict(os.environ)
        env['COMMIT_MESSAGE_TITLE'] = getattr(args, '_title', None) or ''
        env['COMMIT_MESSAGE_BUG'] = bug_urls
        if run(
            [repository.executable(), 'commit', '--date=now'] + (['--amend'] if will_amend else []),
            cwd=repository.root_path,
            env=env,
        ).returncode:
            sys.stderr.write('Failed to generate commit\n')
            return 1

        return 0

    @classmethod
    def _first_line(cls, message):
        """Return the first line of `message`, or '' for a missing or empty message."""
        lines = (message or '').splitlines()
        return lines[0] if lines else ''

    @classmethod
    def title_for(cls, commits):
        """Derive a pull-request title from the titles (first lines) of `commits`.

        Uses the longest common prefix of the commit titles. If that prefix is
        empty or only whitespace, the first commit's title is used instead. A
        dangling "(Part" left behind by "(Part N)"-style series is dropped.
        Returns '' when `commits` is empty.
        """
        titles = [cls._first_line(commit.message) for commit in commits]
        title = os.path.commonprefix(titles).strip()
        if not title and titles:
            title = titles[0].strip()
        return title[:-len('(Part')].rstrip() if title.endswith('(Part') else title

    @classmethod
    def check_pull_request_args(cls, repository, args):
        """Fill unset technique/history arguments from repository configuration.

        Returns False (after reporting to stderr) when history retention was
        requested but the repository configuration forbids it, True otherwise.
        """
        config = repository.config()
        if not args.technique:
            args.technique = config['webkitscmpy.pull-request']
        if args.history is None:
            args.history = dict(
                always=True,
                disabled=False,
                never=False,
            ).get(config['webkitscmpy.history'])
        if args.history and config['webkitscmpy.history'] == 'never':
            sys.stderr.write('History retention was requested, but repository configuration forbids it\n')
            return False
        return True

    @classmethod
    def pull_request_branch_point(cls, repository, args, **kwargs):
        """Ensure we are on a pull-request branch and return its branch point.

        From a default or production branch (or detached HEAD), a development
        branch is created via `Branch.main`. The local copy of the branch point's
        branch is then force-reset to the source remote's version. Returns the
        branch point (an object with `.branch`/`.hash`), or None on failure.
        """
        # FIXME: We can do better by inferring the remote from the branch point, if it's not specified
        source_remote = args.remote or 'origin'

        if repository.branch is None or repository.branch in repository.DEFAULT_BRANCHES or repository.PROD_BRANCHES.match(repository.branch):
            if Branch.main(
                args, repository,
                why="'{}' is not a pull request branch".format(repository.branch),
                redact=source_remote != 'origin', **kwargs
            ):
                sys.stderr.write("Abandoning pushing pull-request because '{}' could not be created\n".format(args.issue))
                return None
        elif args.issue and repository.branch != args.issue:
            sys.stderr.write("Creating a pull-request for '{}' but we're on '{}'\n".format(args.issue, repository.branch))
            return None

        if not repository.config().get('remote.{}.url'.format(source_remote)):
            sys.stderr.write("'{}' is not a remote in this repository\n".format(source_remote))
            return None

        branch_point = Branch.branch_point(repository)
        if run([
            repository.executable(), 'branch', '-f',
            branch_point.branch,
            'remotes/{}/{}'.format(source_remote, branch_point.branch),
        ], cwd=repository.root_path).returncode:
            sys.stderr.write("Failed to match '{}' to it's remote '{}'\n".format(branch_point.branch, source_remote))
            return None
        return branch_point

    @classmethod
    def find_existing_pull_request(cls, repository, remote):
        """Return the pull-request whose head is the current branch.

        An open pull-request is preferred. Otherwise the last one `find`
        returned is used, or None if there are none.
        """
        existing_pr = None
        for pr in remote.pull_requests.find(opened=None, head=repository.branch):
            existing_pr = pr
            if existing_pr.opened:
                break
        return existing_pr

    @classmethod
    def pre_pr_checks(cls, repository):
        """Run every configured `webkitscmpy.pre-pr.*` command from the repository root.

        When a check fails, the user is asked whether to continue. Returns False
        if they decline, True otherwise.
        """
        num_checks = 0
        log.info('Running pre-PR checks...')
        for key, path in repository.config().items():
            if not key.startswith('webkitscmpy.pre-pr.'):
                continue
            num_checks += 1
            name = key.split('.')[-1]
            log.info('    Running {}...'.format(name))
            # Split on any whitespace run so repeated spaces/tabs don't yield empty arguments
            command = run(path.split(), cwd=repository.root_path)
            if command.returncode:
                if Terminal.choose(
                    '{} failed, continue uploading pull request?'.format(name),
                    default='No',
                ) == 'No':
                    sys.stderr.write('Pre-PR check {} failed\n'.format(name))
                    return False
                else:
                    log.info('{} failed, continuing PR upload anyway'.format(name))
            else:
                log.info('Ran {}!'.format(name))

        if num_checks:
            log.info('All pre-PR checks run!')
        else:
            log.info('No pre-PR checks to run')
        return True

    @classmethod
    def is_revert_commit(cls, commit):
        """Return True when the first word of `commit`'s message starts with 'Revert'."""
        words = (commit.message or '').split()
        return bool(words) and words[0].startswith('Revert')

    @classmethod
    def add_comment_to_reverted_commit_bug_tracker(cls, repository, args, pr, commit):
        """Comment on every issue referenced in `commit`'s message that it was reverted by `pr`.

        Each referenced issue also gets `set(opened=True)`. Returns 1 if the
        source remote is unrecognized or cannot host pull-requests, 0 otherwise.
        """
        source_remote = args.remote or 'origin'
        rmt = repository.remote(name=source_remote)
        if not rmt:
            sys.stderr.write("'{}' doesn't have a recognized remote\n".format(repository.root_path))
            return 1
        if not rmt.pull_requests:
            sys.stderr.write("'{}' cannot generate pull-requests\n".format(rmt.url))
            return 1

        log.info('Adding comment for reverted commits...')
        seen = set()
        for word in (commit.message or '').split():
            # The same issue URL may appear several times; only comment once per reference
            if word in seen:
                continue
            seen.add(word)
            issue = Tracker.from_string(word)
            if issue:
                issue.add_comment('Reverted by {}'.format(pr.link))
                issue.set(opened=True)
        return 0

    @classmethod
    def _remove_active_labels(cls, existing_pr, unblock=True):
        """Strip merge-queue labels (and the blocked label when `unblock`) from `existing_pr`."""
        if not existing_pr or not existing_pr._metadata or not existing_pr._metadata.get('issue'):
            return
        log.info("Checking PR labels for active labels...")
        pr_issue = existing_pr._metadata['issue']
        # Copy so the issue's own label list is not mutated before set_labels() is called
        labels = list(pr_issue.labels or [])
        did_remove = False
        for to_remove in cls.MERGE_LABELS + cls.UNSAFE_MERGE_LABELS + ([cls.BLOCKED_LABEL] if unblock else []):
            if to_remove in labels:
                log.info("Removing '{}' from PR {}...".format(to_remove, existing_pr.number))
                labels.remove(to_remove)
                did_remove = True
        if did_remove:
            pr_issue.set_labels(labels)

    @classmethod
    def _create_history_branch(cls, repository, target):
        """Snapshot the current branch as '<branch>-<N+1>' and push it to `target`.

        N is the highest numeric suffix among `target`'s existing history
        branches. Failure is reported but not fatal.
        """
        # Escape the branch name: it is literal text, not a pattern
        regex = re.compile(r'^{}-(?P<count>\d+)$'.format(re.escape(repository.branch)))
        counts = [0]
        for branch in repository.branches_for(remote=target):
            match = regex.match(branch)
            if match:
                counts.append(int(match.group('count')))

        history_branch = '{}-{}'.format(repository.branch, max(counts) + 1)
        log.info("Creating '{}' as a reference branch".format(history_branch))
        if run([
            repository.executable(), 'branch', history_branch, repository.branch,
        ], cwd=repository.root_path).returncode or run([
            repository.executable(), 'push', '-f', target, history_branch,
        ], cwd=repository.root_path).returncode:
            sys.stderr.write("Failed to create and push '{}' to '{}'\n".format(history_branch, target))

    @classmethod
    def _link_pull_request_to_issue(cls, issue, pr):
        """Assign `issue` to the current user and make sure it links to `pr`.

        If the issue is closed it is reopened with the link as the reason.
        """
        log.info('Checking issue assignee...')
        me = issue.tracker.me()
        if issue.assignee != me:
            issue.assign(me)
            print('Assigning associated issue to {}'.format(me))
        log.info('Checking for pull request link in associated issue...')
        if pr.url and not any(pr.url in comment.content for comment in issue.comments):
            if issue.opened:
                issue.add_comment('Pull request: {}'.format(pr.url))
            else:
                issue.open(why='Re-opening for pull request {}'.format(pr.url))
            print('Posted pull request link to {}'.format(issue.link))

    @classmethod
    def _sync_issue_labels(cls, issue, pr):
        """Copy `issue`'s component/version onto the PR's own issue when the PR's tracker knows them."""
        if not pr._metadata or not pr._metadata.get('issue'):
            return
        log.info('Syncing PR labels with issue component...')
        pr_issue = pr._metadata['issue']
        project_info = pr_issue.tracker.projects.get(pr_issue.tracker.name, {})
        component = issue.component
        if pr_issue.component == component or component not in project_info.get('components', {}):
            component = None
        version = issue.version
        if pr_issue.version == version or version not in project_info.get('versions', []):
            version = None
        if component or version:
            pr_issue.set_component(component=component, version=version)
            log.info('Synced PR labels with issue component!')
        else:
            log.info('No label syncing required')

    @classmethod
    def create_pull_request(cls, repository, args, branch_point, callback=None, unblock=True):
        """Push the current branch and create or update its pull-request.

        Steps:
        - Optionally rebase onto `branch_point` and run pre-PR checks.
        - Strip merge labels from an existing PR, and the blocked label when
          `unblock` is set.
        - Push (to the fork for GitHub), optionally keeping a numbered history
          branch.
        - Create or update the PR, then link and sync the first commit's issue.

        Returns `callback(pr)` when a callback is given, otherwise 0 on success.
        Returns 1 on failure.
        """
        # FIXME: We can do better by inferring the remote from the branch point, if it's not specified
        source_remote = args.remote or 'origin'
        if not repository.config().get('remote.{}.url'.format(source_remote)):
            sys.stderr.write("'{}' is not a remote in this repository\n".format(source_remote))
            return 1

        if args.rebase is not None:
            rebasing = args.rebase
        else:
            rebasing = repository.config().get(
                'webkitscmpy.auto-rebase-branch',
                repository.config().get('pull.rebase', 'true'),
            ) == 'true'

        if rebasing:
            log.info("Rebasing '{}' on '{}'...".format(repository.branch, branch_point.branch))
            if repository.pull(rebase=True, branch=branch_point.branch):
                sys.stderr.write("Failed to rebase '{}' on '{},' please resolve conflicts\n".format(repository.branch, branch_point.branch))
                return 1
            log.info("Rebased '{}' on '{}!'".format(repository.branch, branch_point.branch))
            # The rebase moved the branch, so recompute where it forks from
            branch_point = Branch.branch_point(repository)

        if args.checks is None:
            args.checks = repository.config().get('webkitscmpy.auto-check', 'false') == 'true'
        if args.checks and not cls.pre_pr_checks(repository):
            sys.stderr.write('Checks have failed, aborting pull request.\n')
            return 1

        remote_repo = repository.remote(name=source_remote)
        if not remote_repo:
            sys.stderr.write("'{}' doesn't have a recognized remote\n".format(repository.root_path))
            return 1

        existing_pr = None
        if remote_repo.pull_requests:
            existing_pr = cls.find_existing_pull_request(repository, remote_repo)
            # A closed PR is reused (and reopened below) unless the user asks for a new one;
            # --no-defaults always starts a new PR, --defaults always reuses the closed one.
            if existing_pr and not existing_pr.opened and not args.defaults and (args.defaults is False or Terminal.choose(
                "'{}' is already associated with '{}', which is closed.\nWould you like to create a new pull-request?".format(
                    repository.branch, existing_pr,
                ), default='No',
            ) == 'Yes'):
                existing_pr = None

        # Remove any active labels
        cls._remove_active_labels(existing_pr, unblock=unblock)

        if isinstance(remote_repo, remote.GitHub):
            # GitHub pull-requests are pushed from a fork remote, not the source remote
            target = 'fork' if source_remote == 'origin' else '{}-fork'.format(source_remote)
            if not repository.config().get('remote.{}.url'.format(target)):
                # Name the missing fork remote; `source_remote` was already validated above
                sys.stderr.write("'{}' is not a remote in this repository. Have you run `{} setup` yet?\n".format(
                    target, os.path.basename(sys.argv[0]),
                ))
                return 1
        else:
            target = source_remote

        log.info("Pushing '{}' to '{}'...".format(repository.branch, target))
        if run([repository.executable(), 'push', '-f', target, repository.branch], cwd=repository.root_path).returncode:
            sys.stderr.write("Failed to push '{}' to '{}' (alias of '{}')\n".format(repository.branch, target, repository.url(name=target)))
            sys.stderr.write("Your checkout may be mis-configured, try re-running 'git-webkit setup' or\n")
            sys.stderr.write("your checkout may not have permission to push to '{}'\n".format(repository.url(name=target)))
            return 1

        if rebasing and target.endswith('fork') and repository.config().get('webkitscmpy.update-fork', 'false') == 'true':
            log.info("Syncing '{}' to remote '{}'".format(branch_point.branch, target))
            if run([repository.executable(), 'push', target, '{branch}:{branch}'.format(branch=branch_point.branch)], cwd=repository.root_path).returncode:
                sys.stderr.write("Failed to sync '{}' to '{}.' Error is non fatal, continuing...\n".format(branch_point.branch, target))

        if args.history or (target != source_remote and args.history is None and args.technique == 'overwrite'):
            cls._create_history_branch(repository, target)

        if not remote_repo.pull_requests:
            sys.stderr.write("'{}' cannot generate pull-requests\n".format(remote_repo.url))
            return 1
        if args.draft and not remote_repo.pull_requests.SUPPORTS_DRAFTS:
            sys.stderr.write("'{}' does not support draft pull requests, aborting\n".format(remote_repo.url))
            return 1

        commits = list(repository.commits(begin=dict(hash=branch_point.hash), end=dict(branch=repository.branch)))
        if not commits:
            sys.stderr.write("No commits on '{}' to create a pull-request from\n".format(repository.branch))
            return 1

        # The first issue reference in the first commit's message is treated as the associated issue
        issue = None
        for word in commits[0].message.split() if commits[0] and commits[0].message else []:
            issue = Tracker.from_string(word)
            if issue:
                break

        title = cls.title_for(commits)
        if existing_pr:
            log.info("Updating pull-request for '{}'...".format(repository.branch))
            pr = remote_repo.pull_requests.update(
                pull_request=existing_pr,
                title=title,
                commits=commits,
                base=branch_point.branch,
                head=repository.branch,
                opened=None if existing_pr.opened else True,
                draft=args.draft,
            )
            if not pr:
                sys.stderr.write("Failed to update pull-request '{}'\n".format(existing_pr))
                return 1
            print("Updated '{}'!".format(pr))
        else:
            log.info("Creating pull-request for '{}'...".format(repository.branch))
            pr = remote_repo.pull_requests.create(
                title=title,
                commits=commits,
                base=branch_point.branch,
                head=repository.branch,
                draft=args.draft,
            )
            if not pr:
                sys.stderr.write("Failed to create pull-request for '{}'\n".format(repository.branch))
                return 1
            print("Created '{}'!".format(pr))
            if cls.is_revert_commit(commits[0]):
                cls.add_comment_to_reverted_commit_bug_tracker(repository, args, pr, commits[0])

        if issue:
            cls._link_pull_request_to_issue(issue, pr)
            cls._sync_issue_labels(issue, pr)

        if pr.url:
            print(pr.url)

        if callback:
            return callback(pr)
        return 0

    @classmethod
    def main(cls, args, repository, **kwargs):
        """Entry point: commit local changes, optionally squash, then push and open/update the PR.

        Returns 0 on success and non-zero on failure.
        """
        if not isinstance(repository, local.Git):
            sys.stderr.write("Can only '{}' on a native Git repository\n".format(cls.name))
            return 1
        if not cls.check_pull_request_args(repository, args):
            return 1

        branch_point = cls.pull_request_branch_point(repository, args, **kwargs)
        if not branch_point:
            return 1

        result = cls.create_commit(args, repository, **kwargs)
        if result:
            return result
        if args.squash:
            result = Squash.squash_commit(args, repository, branch_point, **kwargs)
            if result:
                return result

        return cls.create_pull_request(repository, args, branch_point)