2727
2828from proto .celestial import celestial_pb2 , celestial_pb2_grpc
2929
30- class ConnectionManager ():
30+
31+ class ConnectionManager :
3132 def __init__ (
3233 self ,
3334 hosts : typing .List [str ],
3435 peeringhosts : typing .List [str ],
35- allowed_concurrent : int = 512
36+ allowed_concurrent : int = 512 ,
3637 ):
37-
38- self .stubs : typing .List [celestial_pb2_grpc .CelestialStub ] = []
38+ stubs : typing .List [celestial_pb2_grpc .CelestialStub ] = []
3939
4040 for host in hosts :
4141 channel = grpc .insecure_channel (host )
42- self . stubs .append (celestial_pb2_grpc .CelestialStub (channel ))
42+ stubs .append (celestial_pb2_grpc .CelestialStub (channel ))
4343
4444 self .hosts = hosts
4545
@@ -56,31 +56,75 @@ def __init__(
5656
5757 for i in range (len (self .hosts )):
5858 irr .index = i
59- self . stubs [i ].InitRemotes (irr )
59+ stubs [i ].InitRemotes (irr )
6060
6161 for i in range (len (self .hosts )):
6262 e = celestial_pb2 .Empty ()
63- self . stubs [i ].StartPeering (e )
63+ stubs [i ].StartPeering (e )
6464
6565 def init_mutex (self ) -> None :
66+ self .stubs : typing .List [celestial_pb2_grpc .CelestialStub ] = []
67+
68+ for host in self .hosts :
69+ channel = grpc .insecure_channel (host )
70+ self .stubs .append (celestial_pb2_grpc .CelestialStub (channel ))
71+
6672 self .mutexes : typing .Dict [str , td .Semaphore ] = {}
6773
6874 for host in self .hosts :
6975 self .mutexes [host ] = td .Semaphore (self .allowed_concurrent )
7076
71- def __register (self , conn : MachineConnector , bandwidth : int , active : bool , vcpu_count : int , mem_size_mib : int , ht_enabled : bool , disk_size_mib : int , kernel : str , rootfs : str , bootparams : str ) -> None :
72-
77+ def __register (
78+ self ,
79+ conn : MachineConnector ,
80+ bandwidth : int ,
81+ active : bool ,
82+ vcpu_count : int ,
83+ mem_size_mib : int ,
84+ ht_enabled : bool ,
85+ disk_size_mib : int ,
86+ kernel : str ,
87+ rootfs : str ,
88+ bootparams : str ,
89+ ) -> None :
7390 self .mutexes [conn .host ].acquire ()
7491 try :
75- conn .create_machine (vcpu_count = vcpu_count , mem_size_mib = mem_size_mib , ht_enabled = ht_enabled , disk_size_mib = disk_size_mib , kernel = kernel , rootfs = rootfs , bootparams = bootparams , active = active , bandwidth = bandwidth )
92+ conn .create_machine (
93+ vcpu_count = vcpu_count ,
94+ mem_size_mib = mem_size_mib ,
95+ ht_enabled = ht_enabled ,
96+ disk_size_mib = disk_size_mib ,
97+ kernel = kernel ,
98+ rootfs = rootfs ,
99+ bootparams = bootparams ,
100+ active = active ,
101+ bandwidth = bandwidth ,
102+ )
76103 except Exception as e :
77- print ("❌ caught exception while trying to create machine %d shell %d:" % (conn .id , conn .shell ), e )
104+ print (
105+ "❌ caught exception while trying to create machine %d shell %d:"
106+ % (conn .id , conn .shell ),
107+ e ,
108+ )
78109
79110 self .mutexes [conn .host ].release ()
80111
81-
82- def register_machine (self , shell_no : int , id : int , bandwidth : int , active : bool , vcpu_count : int , mem_size_mib : int , ht_enabled : bool , disk_size_mib : int , kernel : str , rootfs : str , bootparams : str , host_affinity : typing .List [int ], name : str = "" ) -> MachineConnector :
83-
112+ def register_machine (
113+ self ,
114+ shell_no : int ,
115+ id : int ,
116+ bandwidth : int ,
117+ active : bool ,
118+ vcpu_count : int ,
119+ mem_size_mib : int ,
120+ ht_enabled : bool ,
121+ disk_size_mib : int ,
122+ kernel : str ,
123+ rootfs : str ,
124+ bootparams : str ,
125+ host_affinity : typing .List [int ],
126+ name : str = "" ,
127+ ) -> MachineConnector :
84128 # assign a random stub to this connection
85129 #
86130 # how do we get a host for a machine? serveral possibilities
@@ -97,23 +141,25 @@ def register_machine(self, shell_no: int, id: int, bandwidth: int, active: bool,
97141
98142 conn = MachineConnector (stub = stub , host = host , shell = shell_no , id = id , name = name )
99143
100- td .Thread (target = self .__register , kwargs = {
101- "conn" : conn ,
102- "vcpu_count" : vcpu_count ,
103- "mem_size_mib" : mem_size_mib ,
104- "ht_enabled" : ht_enabled ,
105- "disk_size_mib" : disk_size_mib ,
106- "kernel" : kernel ,
107- "rootfs" : rootfs ,
108- "bootparams" : bootparams ,
109- "active" : active ,
110- "bandwidth" : bandwidth
111- }).start ()
144+ td .Thread (
145+ target = self .__register ,
146+ kwargs = {
147+ "conn" : conn ,
148+ "vcpu_count" : vcpu_count ,
149+ "mem_size_mib" : mem_size_mib ,
150+ "ht_enabled" : ht_enabled ,
151+ "disk_size_mib" : disk_size_mib ,
152+ "kernel" : kernel ,
153+ "rootfs" : rootfs ,
154+ "bootparams" : bootparams ,
155+ "active" : active ,
156+ "bandwidth" : bandwidth ,
157+ },
158+ ).start ()
112159
113160 return conn
114161
115162 def collect_host_infos (self ) -> typing .Tuple [int , int , int ]:
116-
117163 cpu_count = 0
118164 mem = 0
119165 machine_count = 0
@@ -129,7 +175,7 @@ def collect_host_infos(self) -> typing.Tuple[int, int, int]:
129175
130176 machine_count += 1
131177 cpu_count += info .cpu
132- mem += info .mem / 1000000
178+ mem += info .mem / 1000000
133179
134180 return machine_count , cpu_count , mem
135181
@@ -141,7 +187,6 @@ def block_host_ready(self, tbar: tqdm.tqdm, total_machines: int) -> None:
141187
142188 while not all (ready ):
143189 for i in range (len (self .hosts )):
144-
145190 if ready [i ]:
146191 continue
147192
@@ -167,8 +212,13 @@ def block_host_ready(self, tbar: tqdm.tqdm, total_machines: int) -> None:
167212 if not sum (total ) == total_machines :
168213 raise ValueError ("reported created machines not equal total machines" )
169214
170- def init (self , db : bool , db_host : typing .Optional [str ], shell_count : int , shells : typing .List [ShellConfig ]) -> None :
171-
215+ def init (
216+ self ,
217+ db : bool ,
218+ db_host : typing .Optional [str ],
219+ shell_count : int ,
220+ shells : typing .List [ShellConfig ],
221+ ) -> None :
172222 isr = celestial_pb2 .InitRequest ()
173223
174224 isr .database = db
@@ -189,4 +239,4 @@ def init(self, db: bool, db_host: typing.Optional[str], shell_count: int, shells
189239
190240 stub = celestial_pb2_grpc .CelestialStub (channel )
191241
192- res = stub .Init (isr )
242+ res = stub .Init (isr )
0 commit comments