# Copyright (C) 2021-2022 Apple Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1.  Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
# 2.  Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in the
#     documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import re
import sys

from .command import Command
from .branch import Branch
from .squash import Squash

from webkitbugspy import Tracker
from webkitcorepy import arguments, run, Terminal
from webkitscmpy import local, log, remote


class PullRequest(Command):
    """Push the current checkout state to a remote and create (or update) a pull-request for it."""

    name = 'pull-request'
    aliases = ['pr', 'pfr', 'upload']
    help = 'Push the current checkout state as a pull-request'
    # Label stripped from an existing pull-request on re-upload only when `unblock` is set.
    BLOCKED_LABEL = 'merging-blocked'
    # Merge-queue labels, always stripped from an existing pull-request on re-upload.
    MERGE_LABELS = ['merge-queue']
    UNSAFE_MERGE_LABELS = ['unsafe-merge-queue']

    @classmethod
    def parser(cls, parser, loggers=None):
        """Register pull-request arguments (plus those of `Branch` and `Squash`) on `parser`."""
        Branch.parser(parser, loggers=loggers)
        Squash.parser(parser, loggers=loggers)
        parser.add_argument(
            '--add', '--no-add',
            dest='will_add', default=None,
            help='When drafting a change, add (or never add) modified files to set of staged changes to be committed',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--rebase', '--no-rebase', '--update', '--no-update',
            dest='rebase', default=None,
            help='Rebase (or do not rebase) the pull-request on the source branch before pushing',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--squash', '--no-squash',
            dest='squash', default=None,
            help='Combine all commits on the current development branch into a single commit before pushing',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--defaults', '--no-defaults', action=arguments.NoAction, default=None,
            help='Do not prompt the user for defaults, always use (or do not use) them',
        )
        parser.add_argument(
            '--overwrite', '--amend', action='store_const', const='overwrite',
            dest='technique', default=None,
            help='When creating a pull request, overwrite the existing commit by default',
        )
        parser.add_argument(
            '--append', action='store_const', const='append',
            dest='technique', default=None,
            help='When creating a pull request, append a new commit on the existing branch by default',
        )
        parser.add_argument(
            '--with-history', '--no-history',
            dest='history', default=None,
            help='Create numbered branches to track the history of a change',
            action=arguments.NoAction,
        )
        parser.add_argument(
            '--draft', dest='draft', action='store_true', default=None,
            help='Mark a pull request as a draft when creating it',
        )
        parser.add_argument(
            '--remote', dest='remote', type=str, default=None,
            help='Make a pull request against a specific remote',
        )
        parser.add_argument(
            '--checks', '--no-checks',
            dest='checks', default=None,
            help='Explicitly enable or disable automatic pre-flight checks',
            action=arguments.NoAction,
        )

    @classmethod
    def create_commit(cls, args, repository, **kwargs):
        """Stage modified files and create (or amend) the commit for the current branch.

        Returns 0 on success, including when an existing commit on the branch is reused
        because nothing was modified. Returns 1 on failure.
        """
        # First, find the set of files to be modified. `--no-add` skips discovery entirely;
        # `--add` additionally includes `repository.modified(staged=False)`.
        modified = [] if args.will_add is False else repository.modified()
        if args.will_add:
            modified = list(set(modified).union(set(repository.modified(staged=False))))

        # Next, stage every modified file which is not already staged
        for path in set(modified) - set(repository.modified(staged=True)):
            log.info('    Adding {}...'.format(path))
            if run([repository.executable(), 'add', path], cwd=repository.root_path).returncode:
                sys.stderr.write("Failed to add '{}'\n".format(path))
                return 1

        # Then, see if we already have a commit associated with this branch we need to modify
        head = repository.commit(include_log=False, include_identifier=False)
        has_commit = head.branch == repository.branch and repository.branch != repository.default_branch
        if not modified and has_commit:
            log.info('Using committed changes...')
            return 0
        if not modified:
            sys.stderr.write('No modified files\n')
            return 1

        bug_urls = getattr(args, '_bug_urls', None) or ''
        if isinstance(bug_urls, (list, tuple)):
            bug_urls = '\n'.join(bug_urls)

        # Otherwise, we need to create a commit; only amend when the user asked to overwrite
        will_amend = has_commit and args.technique == 'overwrite'
        log.info('Amending commit...' if will_amend else 'Creating commit...')

        # Copy the environment so these commit-message variables are handed to `git commit`
        # without leaking into this process (and every subprocess it spawns afterwards).
        env = dict(os.environ)
        env['COMMIT_MESSAGE_TITLE'] = getattr(args, '_title', None) or ''
        env['COMMIT_MESSAGE_BUG'] = bug_urls
        if run(
            [repository.executable(), 'commit', '--date=now'] + (['--amend'] if will_amend else []),
            cwd=repository.root_path,
            env=env,
        ).returncode:
            sys.stderr.write('Failed to generate commit\n')
            return 1

        return 0

    @classmethod
    def title_for(cls, commits):
        """Return a pull-request title derived from the first lines of `commits`' messages.

        Uses the common prefix of all commit titles, falling back to the first commit's title
        when they share none. A trailing '(Part' left over from a numbered series is dropped.
        Empty or missing messages contribute an empty title rather than raising.
        """
        titles = [((commit.message or '').splitlines() or [''])[0] for commit in commits]
        title = os.path.commonprefix(titles)
        if not title and titles:
            title = titles[0]
        title = title.strip()
        return title[:-5].rstrip() if title.endswith('(Part') else title

    @classmethod
    def check_pull_request_args(cls, repository, args):
        """Fill `technique` and `history` from repository config, returning False if they conflict."""
        config = repository.config()
        history_mode = config.get('webkitscmpy.history')
        if not args.technique:
            args.technique = config.get('webkitscmpy.pull-request')
        if args.history is None:
            args.history = dict(
                always=True,
                disabled=False,
                never=False,
            ).get(history_mode)
        if args.history and history_mode == 'never':
            sys.stderr.write('History retention was requested, but repository configuration forbids it\n')
            return False
        return True

    @classmethod
    def pull_request_branch_point(cls, repository, args, **kwargs):
        """Make sure we are on a pull-request branch and return its branch point.

        When the checkout is on a default or production branch (or detached), a new branch is
        created through `Branch.main`. The local copy of the branch point is then force-reset
        to the source remote's copy. Returns None on any failure.
        """
        # FIXME: We can do better by infering the remote from the branch point, if it's not specified
        source_remote = args.remote or 'origin'

        if repository.branch is None or repository.branch in repository.DEFAULT_BRANCHES or repository.PROD_BRANCHES.match(repository.branch):
            if Branch.main(
                args, repository,
                why="'{}' is not a pull request branch".format(repository.branch),
                redact=source_remote != 'origin', **kwargs
            ):
                sys.stderr.write("Abandoning pushing pull-request because '{}' could not be created\n".format(args.issue))
                return None
        elif args.issue and repository.branch != args.issue:
            sys.stderr.write("Creating a pull-request for '{}' but we're on '{}'\n".format(args.issue, repository.branch))
            return None

        if not repository.config().get('remote.{}.url'.format(source_remote)):
            sys.stderr.write("'{}' is not a remote in this repository\n".format(source_remote))
            return None

        branch_point = Branch.branch_point(repository)
        # Point the local branch-point branch at the remote's copy so later rebases use the remote state
        if run([
            repository.executable(), 'branch', '-f',
            branch_point.branch,
            'remotes/{}/{}'.format(source_remote, branch_point.branch),
        ], cwd=repository.root_path).returncode:
            sys.stderr.write("Failed to match '{}' to it's remote '{}'\n".format(branch_point.branch, source_remote))
            return None
        return branch_point

    @classmethod
    def find_existing_pull_request(cls, repository, remote):
        """Return a pull-request whose head is the current branch, preferring an open one.

        If no open pull-request is found, the last one returned by the search is used (or None).
        """
        existing_pr = None
        for pr in remote.pull_requests.find(opened=None, head=repository.branch):
            existing_pr = pr
            if existing_pr.opened:
                break
        return existing_pr

    @classmethod
    def pre_pr_checks(cls, repository):
        """Run every command configured as `webkitscmpy.pre-pr.<name>`.

        On failure the user is asked whether to continue. Returns False if they decline,
        True otherwise (including when no checks are configured).
        """
        num_checks = 0
        log.info('Running pre-PR checks...')
        for key, command_line in repository.config().items():
            if not key.startswith('webkitscmpy.pre-pr.'):
                continue
            num_checks += 1
            name = key.split('.')[-1]
            log.info('    Running {}...'.format(name))
            # split() (not split(' ')) so repeated or trailing spaces don't produce empty arguments
            command = run(command_line.split(), cwd=repository.root_path)
            if not command.returncode:
                log.info('    Ran {}!'.format(name))
                continue
            if Terminal.choose(
                '{} failed, continue uploading pull request?'.format(name),
                default='No',
            ) == 'No':
                sys.stderr.write('Pre-PR check {} failed\n'.format(name))
                return False
            log.info('    {} failed, continuing PR upload anyway'.format(name))

        if num_checks:
            log.info('All pre-PR checks run!')
        else:
            log.info('No pre-PR checks to run')
        return True

    @classmethod
    def is_revert_commit(cls, commit):
        """Return True if the first word of `commit`'s message starts with 'Revert'."""
        words = (commit.message or '').split() if commit else []
        return bool(words) and words[0].startswith('Revert')

    @classmethod
    def add_comment_to_reverted_commit_bug_tracker(cls, repository, args, pr, commit):
        """Comment 'Reverted by <pr>' on, and re-open, every issue referenced in `commit`'s message.

        Returns 0 on success and 1 if the source remote is unknown or cannot host pull-requests.
        """
        source_remote = args.remote or 'origin'
        rmt = repository.remote(name=source_remote)
        if not rmt:
            sys.stderr.write("'{}' doesn't have a recognized remote\n".format(repository.root_path))
            return 1
        if not rmt.pull_requests:
            sys.stderr.write("'{}' cannot generate pull-requests\n".format(rmt.url))
            return 1

        log.info('Adding comment for reverted commits...')
        for word in (commit.message or '').split():
            issue = Tracker.from_string(word)
            if issue:
                issue.add_comment('Reverted by {}'.format(pr.link))
                issue.set(opened=True)
        return 0

    @classmethod
    def create_pull_request(cls, repository, args, branch_point, callback=None, unblock=True):
        """Push the current branch and create or update its pull-request.

        Optionally rebases and runs pre-PR checks first. Active merge-queue labels (and, when
        `unblock` is set, the blocked label) are stripped from an existing pull-request.
        History branches are pushed when requested, and the associated issue's assignee,
        pull-request link and component/version labels are synced.

        Returns `callback(pr)` when a callback is given. Otherwise returns 0 on success and
        1 on failure.
        """
        # FIXME: We can do better by inferring the remote from the branch point, if it's not specified
        source_remote = args.remote or 'origin'
        if not repository.config().get('remote.{}.url'.format(source_remote)):
            sys.stderr.write("'{}' is not a remote in this repository\n".format(source_remote))
            return 1

        rebasing = args.rebase if args.rebase is not None else repository.config().get(
            'webkitscmpy.auto-rebase-branch',
            repository.config().get('pull.rebase', 'true'),
        ) == 'true'

        if rebasing:
            log.info("Rebasing '{}' on '{}'...".format(repository.branch, branch_point.branch))
            if repository.pull(rebase=True, branch=branch_point.branch):
                sys.stderr.write("Failed to rebase '{}' on '{},' please resolve conflicts\n".format(repository.branch, branch_point.branch))
                return 1
            log.info("Rebased '{}' on '{}!'".format(repository.branch, branch_point.branch))
            # The rebase may have moved the branch point, so recompute it
            branch_point = Branch.branch_point(repository)

        if args.checks is None:
            args.checks = repository.config().get('webkitscmpy.auto-check', 'false') == 'true'
        if args.checks and not cls.pre_pr_checks(repository):
            sys.stderr.write('Checks have failed, aborting pull request.\n')
            return 1

        remote_repo = repository.remote(name=source_remote)
        if not remote_repo:
            sys.stderr.write("'{}' doesn't have a recognized remote\n".format(repository.root_path))
            return 1

        existing_pr = None
        if remote_repo.pull_requests:
            existing_pr = cls.find_existing_pull_request(repository, remote_repo)
            # For a closed pull-request: `--defaults` reuses (re-opens) it, `--no-defaults` starts
            # a new one, and otherwise the user is asked.
            if existing_pr and not existing_pr.opened and not args.defaults and (args.defaults is False or Terminal.choose(
                    "'{}' is already associated with '{}', which is closed.\nWould you like to create a new pull-request?".format(
                        repository.branch, existing_pr,
                    ), default='No',
            ) == 'Yes'):
                existing_pr = None

        # Remove any active labels
        if existing_pr and existing_pr._metadata and existing_pr._metadata.get('issue'):
            log.info("Checking PR labels for active labels...")
            pr_issue = existing_pr._metadata['issue']
            # Work on a copy so the issue's cached labels only change through set_labels()
            labels = list(pr_issue.labels or [])
            to_strip = cls.MERGE_LABELS + cls.UNSAFE_MERGE_LABELS + ([cls.BLOCKED_LABEL] if unblock else [])
            did_remove = False
            for to_remove in to_strip:
                if to_remove in labels:
                    log.info("Removing '{}' from PR {}...".format(to_remove, existing_pr.number))
                    labels.remove(to_remove)
                    did_remove = True
            if did_remove:
                pr_issue.set_labels(labels)

        # GitHub pull-requests are pushed to the user's fork rather than to the source remote itself
        if isinstance(remote_repo, remote.GitHub):
            target = 'fork' if source_remote == 'origin' else '{}-fork'.format(source_remote)
            if not repository.config().get('remote.{}.url'.format(target)):
                # Report the fork remote which is actually missing, not the source remote
                sys.stderr.write("'{}' is not a remote in this repository. Have you run `{} setup` yet?\n".format(
                    target, os.path.basename(sys.argv[0]),
                ))
                return 1
        else:
            target = source_remote

        log.info("Pushing '{}' to '{}'...".format(repository.branch, target))
        if run([repository.executable(), 'push', '-f', target, repository.branch], cwd=repository.root_path).returncode:
            sys.stderr.write("Failed to push '{}' to '{}' (alias of '{}')\n".format(repository.branch, target, repository.url(name=target)))
            sys.stderr.write("Your checkout may be mis-configured, try re-running 'git-webkit setup' or\n")
            sys.stderr.write("your checkout may not have permission to push to '{}'\n".format(repository.url(name=target)))
            return 1

        if rebasing and target.endswith('fork') and repository.config().get('webkitscmpy.update-fork', 'false') == 'true':
            log.info("Syncing '{}' to remote '{}'".format(branch_point.branch, target))
            if run([repository.executable(), 'push', target, '{branch}:{branch}'.format(branch=branch_point.branch)], cwd=repository.root_path).returncode:
                sys.stderr.write("Failed to sync '{}' to '{}.' Error is non fatal, continuing...\n".format(branch_point.branch, target))

        if args.history or (target != source_remote and args.history is None and args.technique == 'overwrite'):
            # Escape the branch name: characters like '.' or '+' must match literally
            regex = re.compile(r'^{}-(?P<count>\d+)$'.format(re.escape(repository.branch)))
            matches = (regex.match(branch) for branch in repository.branches_for(remote=target))
            count = max([int(match.group('count')) for match in matches if match] + [0]) + 1

            history_branch = '{}-{}'.format(repository.branch, count)
            log.info("Creating '{}' as a reference branch".format(history_branch))
            if run([
                repository.executable(), 'branch', history_branch, repository.branch,
            ], cwd=repository.root_path).returncode or run([
                repository.executable(), 'push', '-f', target, history_branch,
            ], cwd=repository.root_path).returncode:
                sys.stderr.write("Failed to create and push '{}' to '{}'\n".format(history_branch, target))

        if not remote_repo.pull_requests:
            sys.stderr.write("'{}' cannot generate pull-requests\n".format(remote_repo.url))
            return 1
        if args.draft and not remote_repo.pull_requests.SUPPORTS_DRAFTS:
            sys.stderr.write("'{}' does not support draft pull requests, aborting\n".format(remote_repo.url))
            return 1

        commits = list(repository.commits(begin=dict(hash=branch_point.hash), end=dict(branch=repository.branch)))
        if not commits:
            # Nothing to title or describe; commits[0] below would raise IndexError
            sys.stderr.write("No commits on '{}' to create a pull-request from\n".format(repository.branch))
            return 1

        # The first issue referenced by the first commit's message is the associated issue
        issue = None
        for word in commits[0].message.split() if commits[0] and commits[0].message else []:
            issue = Tracker.from_string(word)
            if issue:
                break

        if existing_pr:
            log.info("Updating pull-request for '{}'...".format(repository.branch))
            pr = remote_repo.pull_requests.update(
                pull_request=existing_pr,
                title=cls.title_for(commits),
                commits=commits,
                base=branch_point.branch,
                head=repository.branch,
                opened=None if existing_pr.opened else True,
                draft=args.draft,
            )
            if not pr:
                sys.stderr.write("Failed to update pull-request '{}'\n".format(existing_pr))
                return 1
            print("Updated '{}'!".format(pr))
        else:
            log.info("Creating pull-request for '{}'...".format(repository.branch))
            pr = remote_repo.pull_requests.create(
                title=cls.title_for(commits),
                commits=commits,
                base=branch_point.branch,
                head=repository.branch,
                draft=args.draft,
            )
            if not pr:
                sys.stderr.write("Failed to create pull-request for '{}'\n".format(repository.branch))
                return 1
            print("Created '{}'!".format(pr))
            if cls.is_revert_commit(commits[0]):
                cls.add_comment_to_reverted_commit_bug_tracker(repository, args, pr, commits[0])

        if issue:
            log.info('Checking issue assignee...')
            me = issue.tracker.me()
            if issue.assignee != me:
                issue.assign(me)
                print('Assigning associated issue to {}'.format(me))
            log.info('Checking for pull request link in associated issue...')
            if pr.url and not any(pr.url in comment.content for comment in issue.comments):
                if issue.opened:
                    issue.add_comment('Pull request: {}'.format(pr.url))
                else:
                    issue.open(why='Re-opening for pull request {}'.format(pr.url))
                print('Posted pull request link to {}'.format(issue.link))

        if issue and pr._metadata and pr._metadata.get('issue'):
            log.info('Syncing PR labels with issue component...')
            pr_issue = pr._metadata['issue']
            project = pr_issue.tracker.name
            # Only sync values that differ and that the PR's tracker actually knows about
            component = issue.component
            if pr_issue.component == component or component not in pr_issue.tracker.projects.get(project, {}).get('components', {}):
                component = None
            version = issue.version
            if pr_issue.version == version or version not in pr_issue.tracker.projects.get(project, {}).get('versions', []):
                version = None
            if component or version:
                pr_issue.set_component(component=component, version=version)
                log.info('Synced PR labels with issue component!')
            else:
                log.info('No label syncing required')

        if pr.url:
            print(pr.url)

        if callback:
            return callback(pr)
        return 0

    @classmethod
    def main(cls, args, repository, **kwargs):
        """Entry point: commit local changes (optionally squashing), then push and open a pull-request.

        Returns 0 on success and a non-zero value on failure.
        """
        if not isinstance(repository, local.Git):
            sys.stderr.write("Can only '{}' on a native Git repository\n".format(cls.name))
            return 1
        if not cls.check_pull_request_args(repository, args):
            return 1

        branch_point = cls.pull_request_branch_point(repository, args, **kwargs)
        if not branch_point:
            return 1

        result = cls.create_commit(args, repository, **kwargs)
        if result:
            return result
        if args.squash:
            result = Squash.squash_commit(args, repository, branch_point, **kwargs)
            if result:
                return result

        return cls.create_pull_request(repository, args, branch_point)
