1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import json
1716import os
18-
17+ import subprocess
18+ import tempfile
19+ import typing
20+ import uuid
1921from abc import ABC , abstractmethod
22+
2023from autosynth import git , github
2124from autosynth .log import logger
22- import subprocess
23- import uuid
24- import tempfile
25+
26+
27+ class AbstractPullRequest (ABC ):
28+ """Abstractly, manipulates an existing pull request."""
29+
30+ @abstractmethod
31+ def add_labels (self , labels : typing .Sequence [str ]) -> None :
32+ """Adds labels to an existing pull request."""
33+ pass
2534
2635
2736class AbstractChangePusher (ABC ):
@@ -30,15 +39,37 @@ class AbstractChangePusher(ABC):
3039 @abstractmethod
3140 def push_changes (
3241 self , commit_count : int , branch : str , pr_title : str , synth_log : str = ""
33- ) -> None :
34- """Creates a pull request from commits in current working directory."""
42+ ) -> AbstractPullRequest :
43+ """Creates a pull request from commits in current working directory.
44+
45+ Arguments:
46+ commit_count {int} -- How many commits are in this pull request?
47+ branch {str} -- The name of the local branch to push.
48+ pr_title {str} -- The title for the pull request.
49+
50+ Keyword Arguments:
51+ synth_log {str} -- The full log of the call to synth. (default: {""})
52+
53+ Returns:
54+ A pull request.
55+ """
3556 pass
3657
3758 @abstractmethod
3859 def check_if_pr_already_exists (self , branch ) -> bool :
3960 pass
4061
4162
63+ class PullRequest (AbstractPullRequest ):
64+ def __init__ (self , gh : github .GitHub , pr : typing .Dict [str , typing .Any ]):
65+ self ._gh = gh
66+ self ._pr = pr
67+
68+ def add_labels (self , labels : typing .Sequence [str ]) -> None :
69+ """Adds labels to an existing pull request."""
70+ self ._gh .update_pull_labels (self ._pr , add = labels )
71+
72+
4273class ChangePusher (AbstractChangePusher ):
4374 """Actually pushes changes to github."""
4475
@@ -49,7 +80,7 @@ def __init__(self, repository: str, gh: github.GitHub, synth_path: str):
4980
5081 def push_changes (
5182 self , commit_count : int , branch : str , pr_title : str , synth_log : str = ""
52- ) -> None :
83+ ) -> AbstractPullRequest :
5384 git .push_changes (branch )
5485
5586 pr = self ._gh .create_pull_request (
@@ -64,7 +95,8 @@ def push_changes(
6495 api_label = self ._gh .get_api_label (self ._repository , self ._synth_path )
6596
6697 if api_label :
67- self ._gh .update_pull_labels (json .loads (pr ), add = [api_label ])
98+ self ._gh .update_pull_labels (pr , add = [api_label ])
99+ return PullRequest (self ._gh , pr )
68100
69101 def check_if_pr_already_exists (self , branch ) -> bool :
70102 repo = self ._repository
@@ -87,13 +119,12 @@ def __init__(self, inner_change_pusher: AbstractChangePusher):
87119
88120 def push_changes (
89121 self , commit_count : int , branch : str , pr_title : str , synth_log : str = ""
90- ) -> None :
122+ ) -> AbstractPullRequest :
91123 if commit_count < 2 :
92124 # Only one change, no need to squash.
93- self .inner_change_pusher .push_changes (
125+ return self .inner_change_pusher .push_changes (
94126 commit_count , branch , pr_title , synth_log
95127 )
96- return
97128
98129 subprocess .check_call (["git" , "checkout" , branch ]) # Probably redundant.
99130 with tempfile .NamedTemporaryFile () as message_file :
@@ -111,7 +142,7 @@ def push_changes(
111142 subprocess .check_call (["git" , "checkout" , "-b" , branch ])
112143 subprocess .check_call (["git" , "merge" , "--squash" , temp_branch ])
113144 subprocess .check_call (["git" , "commit" , "-F" , message_file .name ])
114- self .inner_change_pusher .push_changes (1 , branch , pr_title , synth_log )
145+ return self .inner_change_pusher .push_changes (1 , branch , pr_title , synth_log )
115146
116147 def check_if_pr_already_exists (self , branch ) -> bool :
117148 return self .inner_change_pusher .check_if_pr_already_exists (branch )
0 commit comments