From f3d2004868ee5df9a0cb33df9731e9f73c82f416 Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Thu, 1 Feb 2024 19:04:56 +0000 Subject: [PATCH] Rename `certificates` to `root_certificates` in `Driver` (#2890) --- examples/mt-pytorch/driver.py | 2 +- examples/secaggplus-mt/driver.py | 2 +- src/py/flwr/driver/app.py | 4 +++- src/py/flwr/driver/driver.py | 7 ++++--- src/py/flwr/driver/grpc_driver.py | 8 ++++---- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index ad4d5e1caab..184ee683818 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index d0d9a75f1b7..4a7d629924f 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -71,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index cbacd6b53ab..2c0576bde8f 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -110,7 +110,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals # Create the Driver if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() - driver = GrpcDriver(driver_service_address=address, certificates=root_certificates) + driver = GrpcDriver( + driver_service_address=address, root_certificates=root_certificates + ) driver.connect() lock = threading.Lock() diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 512a2001165..b68f7c8de5f 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -49,10 +49,10 @@ class Driver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, ) -> None: self.addr = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.grpc_driver: Optional[GrpcDriver] = None self.run_id: Optional[int] = None self.node = Node(node_id=0, anonymous=True) @@ -62,7 +62,8 @@ def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: if self.grpc_driver is None or self.run_id is None: # Connect and create run self.grpc_driver = GrpcDriver( - driver_service_address=self.addr, certificates=self.certificates + driver_service_address=self.addr, + root_certificates=self.root_certificates, ) self.grpc_driver.connect() res = self.grpc_driver.create_run(CreateRunRequest()) diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index 23d44979009..c3f66f7343d 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -51,10 +51,10 @@ class GrpcDriver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, ) -> None: self.driver_service_address = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.channel: Optional[grpc.Channel] = None self.stub: Optional[DriverStub] = None @@ -66,8 +66,8 @@ def connect(self) -> None: return self.channel = create_channel( server_address=self.driver_service_address, - insecure=(self.certificates is None), - root_certificates=self.certificates, + insecure=(self.root_certificates is None), + root_certificates=self.root_certificates, ) self.stub = DriverStub(self.channel) log(INFO, "[Driver] Connected to %s", self.driver_service_address)