From 327e663a320371aaeda729777db6d533f158a144 Mon Sep 17 00:00:00 2001
From: daler3
Date: Sun, 18 Oct 2020 13:29:19 +0200
Subject: [PATCH 01/21] Started exploration notebook with synthea data

---
 .../Diabetes_prediction_preprocessing.ipynb | 700 ++++++++++++++++++
 1 file changed, 700 insertions(+)
 create mode 100644 examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb

diff --git a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb
new file mode 100644
index 0000000..b1db37e
--- /dev/null
+++ b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb
@@ -0,0 +1,700 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd \n",
+    "import numpy as np\n",
+    "\n",
+    "# load data into pandas dataframes\n",
+    "data_dir = \"../../data/synthea/\"\n",
+    "conditions_file = data_dir+\"conditions.csv\"\n",
+    "medications_file = data_dir+\"medications.csv\"\n",
+    "observations_file = data_dir+\"observations.csv\"\n",
+    "patients_file = data_dir+\"patients.csv\"\n",
+    "\n",
+    "df_cond = pd.read_csv(conditions_file)\n",
+    "df_med = pd.read_csv(medications_file)\n",
+    "df_obs = pd.read_csv(observations_file)\n",
+    "df_pat = pd.read_csv(patients_file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
STARTSTOPPATIENTENCOUNTERCODEDESCRIPTION
02012-03-312012-04-307d3e489a-7789-9cd6-2a1b-711074af481b1f2b8067-61bd-88ca-a497-b177756efe62307731004Injury of tendon of the rotator cuff of shoulder
12014-10-082014-10-177d3e489a-7789-9cd6-2a1b-711074af481bc0043d0a-e6b1-7d0a-ab72-263d9591b1b1195662009Acute viral pharyngitis (disorder)
22017-12-082017-12-157d3e489a-7789-9cd6-2a1b-711074af481b9a2ce31d-bf2d-0f0e-f5e9-945602e19b0c444814009Viral sinusitis (disorder)
32020-03-152020-03-297d3e489a-7789-9cd6-2a1b-711074af481b1402ddca-c6d3-3bf0-2369-997840511cfb49727002Cough (finding)
42020-03-152020-03-297d3e489a-7789-9cd6-2a1b-711074af481b1402ddca-c6d3-3bf0-2369-997840511cfb248595008Sputum finding (finding)
\n", + "
" + ], + "text/plain": [ + " START STOP PATIENT \\\n", + "0 2012-03-31 2012-04-30 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "1 2014-10-08 2014-10-17 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "2 2017-12-08 2017-12-15 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "3 2020-03-15 2020-03-29 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "4 2020-03-15 2020-03-29 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "\n", + " ENCOUNTER CODE \\\n", + "0 1f2b8067-61bd-88ca-a497-b177756efe62 307731004 \n", + "1 c0043d0a-e6b1-7d0a-ab72-263d9591b1b1 195662009 \n", + "2 9a2ce31d-bf2d-0f0e-f5e9-945602e19b0c 444814009 \n", + "3 1402ddca-c6d3-3bf0-2369-997840511cfb 49727002 \n", + "4 1402ddca-c6d3-3bf0-2369-997840511cfb 248595008 \n", + "\n", + " DESCRIPTION \n", + "0 Injury of tendon of the rotator cuff of shoulder \n", + "1 Acute viral pharyngitis (disorder) \n", + "2 Viral sinusitis (disorder) \n", + "3 Cough (finding) \n", + "4 Sputum finding (finding) " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_cond.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
STARTSTOPPATIENTPAYERENCOUNTERCODEDESCRIPTIONBASE_COSTPAYER_COVERAGEDISPENSESTOTALCOSTREASONCODEREASONDESCRIPTION
01989-12-09T22:06:58ZNaNa3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c76959d732e-e052-d01d-4b6d-428a208c93fd106258Hydrocortisone 10 MG/ML Topical Cream5.050.03751893.7540275004.0Contact dermatitis
11989-12-23T22:35:58ZNaNa3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c767c1f93ba-a68f-03b3-3999-38ea56e424c2141918Terfenadine 60 MG Oral Tablet7.970.03752988.75NaNNaN
21989-12-23T22:35:58ZNaNa3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c767c1f93ba-a68f-03b3-3999-38ea56e424c21870230NDA020800 0.3 ML Epinephrine 1 MG/ML Auto-Inje...406.870.0375152576.25NaNNaN
32015-04-01T22:27:58Z2015-04-15T22:27:58Za3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c763af8003b-e9c7-27aa-cbe9-038d73a3ac21313782Acetaminophen 325 MG Oral Tablet10.030.0110.0310509002.0Acute bronchitis (disorder)
42008-08-27T19:55:43Z2009-09-02T19:55:43Zd7acfddb-f4c2-69f4-2081-ad1fb849044842c4fca7-f8a9-3cd1-982a-dd9751bf3e2a2c054c1f-a06a-06a4-c3d3-e33daeaf6560310798Hydrochlorothiazide 25 MG Oral Tablet0.010.0120.1259621000.0Hypertension
\n", + "
" + ], + "text/plain": [ + " START STOP \\\n", + "0 1989-12-09T22:06:58Z NaN \n", + "1 1989-12-23T22:35:58Z NaN \n", + "2 1989-12-23T22:35:58Z NaN \n", + "3 2015-04-01T22:27:58Z 2015-04-15T22:27:58Z \n", + "4 2008-08-27T19:55:43Z 2009-09-02T19:55:43Z \n", + "\n", + " PATIENT PAYER \\\n", + "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", + "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", + "2 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", + "3 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", + "4 d7acfddb-f4c2-69f4-2081-ad1fb8490448 42c4fca7-f8a9-3cd1-982a-dd9751bf3e2a \n", + "\n", + " ENCOUNTER CODE \\\n", + "0 959d732e-e052-d01d-4b6d-428a208c93fd 106258 \n", + "1 7c1f93ba-a68f-03b3-3999-38ea56e424c2 141918 \n", + "2 7c1f93ba-a68f-03b3-3999-38ea56e424c2 1870230 \n", + "3 3af8003b-e9c7-27aa-cbe9-038d73a3ac21 313782 \n", + "4 2c054c1f-a06a-06a4-c3d3-e33daeaf6560 310798 \n", + "\n", + " DESCRIPTION BASE_COST \\\n", + "0 Hydrocortisone 10 MG/ML Topical Cream 5.05 \n", + "1 Terfenadine 60 MG Oral Tablet 7.97 \n", + "2 NDA020800 0.3 ML Epinephrine 1 MG/ML Auto-Inje... 406.87 \n", + "3 Acetaminophen 325 MG Oral Tablet 10.03 \n", + "4 Hydrochlorothiazide 25 MG Oral Tablet 0.01 \n", + "\n", + " PAYER_COVERAGE DISPENSES TOTALCOST REASONCODE \\\n", + "0 0.0 375 1893.75 40275004.0 \n", + "1 0.0 375 2988.75 NaN \n", + "2 0.0 375 152576.25 NaN \n", + "3 0.0 1 10.03 10509002.0 \n", + "4 0.0 12 0.12 59621000.0 \n", + "\n", + " REASONDESCRIPTION \n", + "0 Contact dermatitis \n", + "1 NaN \n", + "2 NaN \n", + "3 Acute bronchitis (disorder) \n", + "4 Hypertension " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_med.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DATEPATIENTENCOUNTERCODEDESCRIPTIONVALUEUNITSTYPE
02011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c91498302-2Body Height167.0cmnumeric
12011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914972514-3Pain severity - 0-10 verbal numeric rating [Sc...3.0{score}numeric
22011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914929463-7Body Weight71.1kgnumeric
32011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914939156-5Body Mass Index25.5kg/m2numeric
42011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914959576-9Body mass index (BMI) [Percentile] Per age and...83.6%numeric
\n", + "
" + ], + "text/plain": [ + " DATE PATIENT \\\n", + "0 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "1 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "2 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "3 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "4 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", + "\n", + " ENCOUNTER CODE \\\n", + "0 814174f3-2e0e-1625-de48-9c40732c9149 8302-2 \n", + "1 814174f3-2e0e-1625-de48-9c40732c9149 72514-3 \n", + "2 814174f3-2e0e-1625-de48-9c40732c9149 29463-7 \n", + "3 814174f3-2e0e-1625-de48-9c40732c9149 39156-5 \n", + "4 814174f3-2e0e-1625-de48-9c40732c9149 59576-9 \n", + "\n", + " DESCRIPTION VALUE UNITS TYPE \n", + "0 Body Height 167.0 cm numeric \n", + "1 Pain severity - 0-10 verbal numeric rating [Sc... 3.0 {score} numeric \n", + "2 Body Weight 71.1 kg numeric \n", + "3 Body Mass Index 25.5 kg/m2 numeric \n", + "4 Body mass index (BMI) [Percentile] Per age and... 83.6 % numeric " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_obs.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
IdBIRTHDATEDEATHDATESSNDRIVERSPASSPORTPREFIXFIRSTLASTSUFFIX...BIRTHPLACEADDRESSCITYSTATECOUNTYZIPLATLONHEALTHCARE_EXPENSESHEALTHCARE_COVERAGE
07d3e489a-7789-9cd6-2a1b-711074af481b1993-01-28NaN999-95-8631S99916705X24646789XMr.Jon665Pacocha935NaN...Lawrence Massachusetts US942 Fahey Overpass Apt 21NatickMassachusettsMiddlesex CountyNaN42.309347-71.349633569019.692293.12
1a3795ec8-54f3-e99e-a4b1-4c067f3141d71971-12-01NaN999-62-4431S99941017X38787090XMr.Dick869Streich926NaN...Swansea Massachusetts US1064 Hickle View Apt 7ChicopeeMassachusettsHampden County1020.042.198239-72.55475218755.460.00
23829c803-1f4c-74ed-0d8f-36e502cadd0f2005-01-07NaN999-21-2332NaNNaNNaNCordell41Eichmann909NaN...Chelmsford Massachusetts US560 Ritchie Way Suite 68SwanseaMassachusettsBristol CountyNaN41.748125-71.182914361770.002768.96
3d7acfddb-f4c2-69f4-2081-ad1fb84904481990-07-04NaN999-53-1990S99932677X67053099XMrs.Cheri871Oberbrunner298NaN...Cambridge Massachusetts US268 Hansen Loaf Apt 62LowellMassachusettsMiddlesex County1850.042.662520-71.368933703332.775551.19
4474766f3-ee93-f5d6-84c3-db38ba8033942012-04-03NaN999-57-2653NaNNaNNaNDesmond566O'Conner199NaN...Cohasset Massachusetts US831 Schumm Lock Apt 62WestboroughMassachusettsWorcester CountyNaN42.253951-71.563825206450.272284.86
\n", + "

5 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " Id BIRTHDATE DEATHDATE SSN \\\n", + "0 7d3e489a-7789-9cd6-2a1b-711074af481b 1993-01-28 NaN 999-95-8631 \n", + "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 1971-12-01 NaN 999-62-4431 \n", + "2 3829c803-1f4c-74ed-0d8f-36e502cadd0f 2005-01-07 NaN 999-21-2332 \n", + "3 d7acfddb-f4c2-69f4-2081-ad1fb8490448 1990-07-04 NaN 999-53-1990 \n", + "4 474766f3-ee93-f5d6-84c3-db38ba803394 2012-04-03 NaN 999-57-2653 \n", + "\n", + " DRIVERS PASSPORT PREFIX FIRST LAST SUFFIX ... \\\n", + "0 S99916705 X24646789X Mr. Jon665 Pacocha935 NaN ... \n", + "1 S99941017 X38787090X Mr. Dick869 Streich926 NaN ... \n", + "2 NaN NaN NaN Cordell41 Eichmann909 NaN ... \n", + "3 S99932677 X67053099X Mrs. Cheri871 Oberbrunner298 NaN ... \n", + "4 NaN NaN NaN Desmond566 O'Conner199 NaN ... \n", + "\n", + " BIRTHPLACE ADDRESS CITY \\\n", + "0 Lawrence Massachusetts US 942 Fahey Overpass Apt 21 Natick \n", + "1 Swansea Massachusetts US 1064 Hickle View Apt 7 Chicopee \n", + "2 Chelmsford Massachusetts US 560 Ritchie Way Suite 68 Swansea \n", + "3 Cambridge Massachusetts US 268 Hansen Loaf Apt 62 Lowell \n", + "4 Cohasset Massachusetts US 831 Schumm Lock Apt 62 Westborough \n", + "\n", + " STATE COUNTY ZIP LAT LON \\\n", + "0 Massachusetts Middlesex County NaN 42.309347 -71.349633 \n", + "1 Massachusetts Hampden County 1020.0 42.198239 -72.554752 \n", + "2 Massachusetts Bristol County NaN 41.748125 -71.182914 \n", + "3 Massachusetts Middlesex County 1850.0 42.662520 -71.368933 \n", + "4 Massachusetts Worcester County NaN 42.253951 -71.563825 \n", + "\n", + " HEALTHCARE_EXPENSES HEALTHCARE_COVERAGE \n", + "0 569019.69 2293.12 \n", + "1 18755.46 0.00 \n", + "2 361770.00 2768.96 \n", + "3 703332.77 5551.19 \n", + "4 206450.27 2284.86 \n", + "\n", + "[5 rows x 25 columns]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_pat.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From d92667a1c6e8beddc86b6cd1eeeb76eadf2daa3c Mon Sep 17 00:00:00 2001 From: daler3 Date: Mon, 19 Oct 2020 06:51:47 +0200 Subject: [PATCH 02/21] Feature loading --- .../Diabetes_prediction_preprocessing.ipynb | 553 ++++++++---------- 1 file changed, 253 insertions(+), 300 deletions(-) diff --git a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb index b1db37e..d194c64 100644 --- a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb +++ b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb @@ -1,8 +1,24 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Diabetes prediction with synthea data\n", + "\n", + "###### Mostly from https://github.com/IBM/example-health-machine-learning/blob/master/diabetes-prediction.ipynb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import data in pandas dataframes " + ] + }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 2, "metadata": {}, "outputs": [], "source": 
[ @@ -24,302 +40,7 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
STARTSTOPPATIENTENCOUNTERCODEDESCRIPTION
02012-03-312012-04-307d3e489a-7789-9cd6-2a1b-711074af481b1f2b8067-61bd-88ca-a497-b177756efe62307731004Injury of tendon of the rotator cuff of shoulder
12014-10-082014-10-177d3e489a-7789-9cd6-2a1b-711074af481bc0043d0a-e6b1-7d0a-ab72-263d9591b1b1195662009Acute viral pharyngitis (disorder)
22017-12-082017-12-157d3e489a-7789-9cd6-2a1b-711074af481b9a2ce31d-bf2d-0f0e-f5e9-945602e19b0c444814009Viral sinusitis (disorder)
32020-03-152020-03-297d3e489a-7789-9cd6-2a1b-711074af481b1402ddca-c6d3-3bf0-2369-997840511cfb49727002Cough (finding)
42020-03-152020-03-297d3e489a-7789-9cd6-2a1b-711074af481b1402ddca-c6d3-3bf0-2369-997840511cfb248595008Sputum finding (finding)
\n", - "
" - ], - "text/plain": [ - " START STOP PATIENT \\\n", - "0 2012-03-31 2012-04-30 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "1 2014-10-08 2014-10-17 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "2 2017-12-08 2017-12-15 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "3 2020-03-15 2020-03-29 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "4 2020-03-15 2020-03-29 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "\n", - " ENCOUNTER CODE \\\n", - "0 1f2b8067-61bd-88ca-a497-b177756efe62 307731004 \n", - "1 c0043d0a-e6b1-7d0a-ab72-263d9591b1b1 195662009 \n", - "2 9a2ce31d-bf2d-0f0e-f5e9-945602e19b0c 444814009 \n", - "3 1402ddca-c6d3-3bf0-2369-997840511cfb 49727002 \n", - "4 1402ddca-c6d3-3bf0-2369-997840511cfb 248595008 \n", - "\n", - " DESCRIPTION \n", - "0 Injury of tendon of the rotator cuff of shoulder \n", - "1 Acute viral pharyngitis (disorder) \n", - "2 Viral sinusitis (disorder) \n", - "3 Cough (finding) \n", - "4 Sputum finding (finding) " - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_cond.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
STARTSTOPPATIENTPAYERENCOUNTERCODEDESCRIPTIONBASE_COSTPAYER_COVERAGEDISPENSESTOTALCOSTREASONCODEREASONDESCRIPTION
01989-12-09T22:06:58ZNaNa3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c76959d732e-e052-d01d-4b6d-428a208c93fd106258Hydrocortisone 10 MG/ML Topical Cream5.050.03751893.7540275004.0Contact dermatitis
11989-12-23T22:35:58ZNaNa3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c767c1f93ba-a68f-03b3-3999-38ea56e424c2141918Terfenadine 60 MG Oral Tablet7.970.03752988.75NaNNaN
21989-12-23T22:35:58ZNaNa3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c767c1f93ba-a68f-03b3-3999-38ea56e424c21870230NDA020800 0.3 ML Epinephrine 1 MG/ML Auto-Inje...406.870.0375152576.25NaNNaN
32015-04-01T22:27:58Z2015-04-15T22:27:58Za3795ec8-54f3-e99e-a4b1-4c067f3141d7b1c428d6-4f07-31e0-90f0-68ffa6ff8c763af8003b-e9c7-27aa-cbe9-038d73a3ac21313782Acetaminophen 325 MG Oral Tablet10.030.0110.0310509002.0Acute bronchitis (disorder)
42008-08-27T19:55:43Z2009-09-02T19:55:43Zd7acfddb-f4c2-69f4-2081-ad1fb849044842c4fca7-f8a9-3cd1-982a-dd9751bf3e2a2c054c1f-a06a-06a4-c3d3-e33daeaf6560310798Hydrochlorothiazide 25 MG Oral Tablet0.010.0120.1259621000.0Hypertension
\n", - "
" - ], - "text/plain": [ - " START STOP \\\n", - "0 1989-12-09T22:06:58Z NaN \n", - "1 1989-12-23T22:35:58Z NaN \n", - "2 1989-12-23T22:35:58Z NaN \n", - "3 2015-04-01T22:27:58Z 2015-04-15T22:27:58Z \n", - "4 2008-08-27T19:55:43Z 2009-09-02T19:55:43Z \n", - "\n", - " PATIENT PAYER \\\n", - "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", - "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", - "2 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", - "3 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 b1c428d6-4f07-31e0-90f0-68ffa6ff8c76 \n", - "4 d7acfddb-f4c2-69f4-2081-ad1fb8490448 42c4fca7-f8a9-3cd1-982a-dd9751bf3e2a \n", - "\n", - " ENCOUNTER CODE \\\n", - "0 959d732e-e052-d01d-4b6d-428a208c93fd 106258 \n", - "1 7c1f93ba-a68f-03b3-3999-38ea56e424c2 141918 \n", - "2 7c1f93ba-a68f-03b3-3999-38ea56e424c2 1870230 \n", - "3 3af8003b-e9c7-27aa-cbe9-038d73a3ac21 313782 \n", - "4 2c054c1f-a06a-06a4-c3d3-e33daeaf6560 310798 \n", - "\n", - " DESCRIPTION BASE_COST \\\n", - "0 Hydrocortisone 10 MG/ML Topical Cream 5.05 \n", - "1 Terfenadine 60 MG Oral Tablet 7.97 \n", - "2 NDA020800 0.3 ML Epinephrine 1 MG/ML Auto-Inje... 406.87 \n", - "3 Acetaminophen 325 MG Oral Tablet 10.03 \n", - "4 Hydrochlorothiazide 25 MG Oral Tablet 0.01 \n", - "\n", - " PAYER_COVERAGE DISPENSES TOTALCOST REASONCODE \\\n", - "0 0.0 375 1893.75 40275004.0 \n", - "1 0.0 375 2988.75 NaN \n", - "2 0.0 375 152576.25 NaN \n", - "3 0.0 1 10.03 10509002.0 \n", - "4 0.0 12 0.12 59621000.0 \n", - "\n", - " REASONDESCRIPTION \n", - "0 Contact dermatitis \n", - "1 NaN \n", - "2 NaN \n", - "3 Acute bronchitis (disorder) \n", - "4 Hypertension " - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_med.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 17, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -436,7 +157,7 @@ "4 Body mass index (BMI) [Percentile] Per age and... 
83.6 % numeric " ] }, - "execution_count": 17, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -447,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -659,7 +380,7 @@ "[5 rows x 25 columns]" ] }, - "execution_count": 18, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -668,6 +389,238 @@ "df_pat.head()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Feature selection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Select the features of interests: \n", + "\n", + "- Systolic blood pressure readings from the observations (code 8480-6).\n", + "- Select diastolic blood pressure readings (code 8462-4).\n", + "- Select HDL cholesterol readings (code 2085-9).\n", + "- Select LDL cholesterol readings (code 18262-6).\n", + "- Select BMI (body mass index) readings (code 39156-5).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [], + "source": [ + "def feature_selection_obs(df, code):\n", + " return df[df[\"CODE\"]==code][[\"PATIENT\", \"DATE\", \"VALUE\"]].drop_duplicates().reset_index(drop=True)\n", + "\n", + "#select feautures from observations\n", + "df_systolic = feature_selection_obs(df_obs, \"8480-6\").rename(columns={\"VALUE\": \"SYSTOLIC_BP\"})\n", + "df_diastolic = feature_selection_obs(df_obs, \"8462-4\").rename(columns={\"VALUE\": \"DIASTOLIC_BP\"})\n", + "df_hdl = feature_selection_obs(df_obs, \"2085-9\").rename(columns={\"VALUE\": \"HDL\"})\n", + "df_ldl = feature_selection_obs(df_obs, \"18262-6\").rename(columns={\"VALUE\": \"LDL\"})\n", + "df_bmi = feature_selection_obs(df_obs, \"39156-5\").rename(columns={\"VALUE\": \"BMI\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(83540, 83541, 26900, 26900, 57880)" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(df_systolic), len(df_diastolic), len(df_hdl), len(df_ldl), len(df_bmi)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Merge the dataframes (inner join for now, to avoid dealing with missing values)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "df1 = pd.merge(df_systolic, df_diastolic, on=[\"PATIENT\", \"DATE\"], how='inner')\n", + "df2 = pd.merge(df1, df_hdl, on=[\"PATIENT\", \"DATE\"], how='inner')\n", + "df3 = pd.merge(df2, df_ldl, on=[\"PATIENT\", \"DATE\"], how='inner')\n", + "df4 = pd.merge(df3, df_bmi, on=[\"PATIENT\", \"DATE\"], how='inner')" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "21224" + ] + }, + "execution_count": 92, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(df4)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PATIENTDATESYSTOLIC_BPDIASTOLIC_BPHDLLDLBMI
0a3795ec8-54f3-e99e-a4b1-4c067f3141d72013-01-16T22:06:58Z128.088.064.589.222.4
1a3795ec8-54f3-e99e-a4b1-4c067f3141d72017-12-20T22:06:58Z116.071.064.978.022.4
29bafdf36-6e60-e93e-7925-c8d15a49ea622012-11-25T09:32:01Z125.082.072.997.327.6
39bafdf36-6e60-e93e-7925-c8d15a49ea622015-12-13T09:32:01Z104.089.064.371.327.6
49bafdf36-6e60-e93e-7925-c8d15a49ea622018-12-30T09:32:01Z121.077.061.277.827.6
\n", + "
" + ], + "text/plain": [ + " PATIENT DATE SYSTOLIC_BP \\\n", + "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2013-01-16T22:06:58Z 128.0 \n", + "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2017-12-20T22:06:58Z 116.0 \n", + "2 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2012-11-25T09:32:01Z 125.0 \n", + "3 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2015-12-13T09:32:01Z 104.0 \n", + "4 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2018-12-30T09:32:01Z 121.0 \n", + "\n", + " DIASTOLIC_BP HDL LDL BMI \n", + "0 88.0 64.5 89.2 22.4 \n", + "1 71.0 64.9 78.0 22.4 \n", + "2 82.0 72.9 97.3 27.6 \n", + "3 89.0 64.3 71.3 27.6 \n", + "4 77.0 61.2 77.8 27.6 " + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df4.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, From fd333ca4ae621ceb623148e809da23f4986b64b7 Mon Sep 17 00:00:00 2001 From: daler3 Date: Fri, 23 Oct 2020 09:35:13 +0200 Subject: [PATCH 03/21] Added model and training loop --- .../Diabetes_prediction_preprocessing.ipynb | 971 +++++++++++++++++- 1 file changed, 958 insertions(+), 13 deletions(-) diff --git a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb index d194c64..3fb382d 100644 --- a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb +++ b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -157,7 +157,7 @@ "4 Body mass index (BMI) [Percentile] Per age and... 
83.6 % numeric " ] }, - "execution_count": 9, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -168,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -380,7 +380,7 @@ "[5 rows x 25 columns]" ] }, - "execution_count": 29, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -411,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -428,7 +428,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -437,7 +437,7 @@ "(83540, 83541, 26900, 26900, 57880)" ] }, - "execution_count": 90, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -455,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -467,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -476,7 +476,7 @@ "21224" ] }, - "execution_count": 92, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -487,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -591,7 +591,7 @@ "4 77.0 61.2 77.8 27.6 " ] }, - "execution_count": 93, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -600,6 +600,951 @@ "df4.head()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Join also the age (derived from birth date in PATIENT dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "df5 = pd.merge(df4, df_pat[[\"Id\", \"BIRTHDATE\"]].rename(columns={\"Id\": \"PATIENT\"}), on =[\"PATIENT\"], how='inner')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "df5[\"DATE\"] = [x.split(\"T\")[0] for x in list(df5[\"DATE\"])]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "#from https://stackoverflow.com/questions/8419564/difference-between-two-dates-in-python\n", + "def days_between(d1, d2):\n", + " d1 = datetime.strptime(d1, \"%Y-%m-%d\")\n", + " d2 = datetime.strptime(d2, \"%Y-%m-%d\")\n", + " return abs((d2 - d1).days)\n", + "\n", + "def age_calculation(l1, l2):\n", + " age_list = []\n", + " i = 0\n", + " for i in range(0, len(l1)):\n", + " age_list.append(days_between(l1[i], l2[i]) / 365.00)\n", + " return age_list\n", + "\n", + "df5[\"AGE\"] = age_calculation(list(df5[\"DATE\"]), list(df5[\"BIRTHDATE\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "df5.drop([\"BIRTHDATE\"], axis=1, inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PATIENTDATESYSTOLIC_BPDIASTOLIC_BPHDLLDLBMIAGE
0a3795ec8-54f3-e99e-a4b1-4c067f3141d72013-01-16128.088.064.589.222.441.156164
1a3795ec8-54f3-e99e-a4b1-4c067f3141d72017-12-20116.071.064.978.022.446.084932
29bafdf36-6e60-e93e-7925-c8d15a49ea622012-11-25125.082.072.997.327.658.167123
39bafdf36-6e60-e93e-7925-c8d15a49ea622015-12-13104.089.064.371.327.661.216438
49bafdf36-6e60-e93e-7925-c8d15a49ea622018-12-30121.077.061.277.827.664.265753
\n", + "
" + ], + "text/plain": [ + " PATIENT DATE SYSTOLIC_BP DIASTOLIC_BP \\\n", + "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2013-01-16 128.0 88.0 \n", + "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2017-12-20 116.0 71.0 \n", + "2 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2012-11-25 125.0 82.0 \n", + "3 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2015-12-13 104.0 89.0 \n", + "4 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2018-12-30 121.0 77.0 \n", + "\n", + " HDL LDL BMI AGE \n", + "0 64.5 89.2 22.4 41.156164 \n", + "1 64.9 78.0 22.4 46.084932 \n", + "2 72.9 97.3 27.6 58.167123 \n", + "3 64.3 71.3 27.6 61.216438 \n", + "4 61.2 77.8 27.6 64.265753 " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df5.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we find the patient with diabetes diagnosis, and select the start date column (equivalent to first diagnosis), in the CONDITION dataset. " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "df_pat_diab = df_cond[df_cond.DESCRIPTION == \"Diabetes\"][[\"PATIENT\", \"START\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "df6 = pd.merge(df5, df_pat_diab, on=[\"PATIENT\"], how=\"left\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "df6[\"HAS_DIABETES\"] = [(0 if (type(el) == float and np.isnan(el)) else 1) for el in list(df6[\"START\"])]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PATIENTDATESYSTOLIC_BPDIASTOLIC_BPHDLLDLBMIAGESTARTHAS_DIABETES
0a3795ec8-54f3-e99e-a4b1-4c067f3141d72013-01-16128.088.064.589.222.441.156164NaN0
1a3795ec8-54f3-e99e-a4b1-4c067f3141d72017-12-20116.071.064.978.022.446.084932NaN0
29bafdf36-6e60-e93e-7925-c8d15a49ea622012-11-25125.082.072.997.327.658.167123NaN0
39bafdf36-6e60-e93e-7925-c8d15a49ea622015-12-13104.089.064.371.327.661.216438NaN0
49bafdf36-6e60-e93e-7925-c8d15a49ea622018-12-30121.077.061.277.827.664.265753NaN0
\n", + "
" + ], + "text/plain": [ + " PATIENT DATE SYSTOLIC_BP DIASTOLIC_BP \\\n", + "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2013-01-16 128.0 88.0 \n", + "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2017-12-20 116.0 71.0 \n", + "2 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2012-11-25 125.0 82.0 \n", + "3 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2015-12-13 104.0 89.0 \n", + "4 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2018-12-30 121.0 77.0 \n", + "\n", + " HDL LDL BMI AGE START HAS_DIABETES \n", + "0 64.5 89.2 22.4 41.156164 NaN 0 \n", + "1 64.9 78.0 22.4 46.084932 NaN 0 \n", + "2 72.9 97.3 27.6 58.167123 NaN 0 \n", + "3 64.3 71.3 27.6 61.216438 NaN 0 \n", + "4 61.2 77.8 27.6 64.265753 NaN 0 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df6.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filter data\n", + "For this example, we filter the positive observations taken before diagnosis and then we reduce the observations to a single one per patient. In the future, it might be valuable keeping the time and try to predict diabetes before it occurs (e.g. with RNN). In this case, however, we need to check better the generative model underlying synthea, as in the notebook we are trying to reproduce here: \"The impact of the condition (diabetes) is not reflected in the observations until the patient is diagnosed with the condition in a wellness visit\". However, there is a condition called \"Prediabetes\" which we could take into account. " + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def date_to_int(string_date):\n", + " #a date can also be nan (float type)\n", + " return int(string_date.replace(\"-\", \"\")) if type(string_date) == str else 0\n", + "def col_date_to_int(col_date):\n", + " return list(map(date_to_int, col_date))\n", + "\n", + "df6[\"temp_date\"] = col_date_to_int(df6[\"DATE\"])\n", + "df6[\"temp_start\"] = col_date_to_int(df6[\"START\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "57" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_with_diab = df6[df6.HAS_DIABETES == 1]\n", + "df_to_discard = df_with_diab[df_with_diab[\"temp_start\"] > df_with_diab[\"temp_date\"]]\n", + "len(df_to_discard)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "df7 = df6.drop(index = df_to_discard.index, inplace=False).reset_index().drop(columns=[\"index\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's now reduce the observations to a single observation per patient (the earliest available observation)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PATIENTSYSTOLIC_BPDIASTOLIC_BPHDLLDLBMIAGEHAS_DIABETES
028de1ba3-efdc-1797-5a32-d4f8d3be0936109.078.075.174.425.230.3561640
16686010d-3d7b-d69d-d2cd-8bbe4b3e6041160.0104.062.182.628.230.3561640
242217d99-02e1-2cc8-83ed-5101c246a559112.084.069.294.630.131.2219180
3c429d985-225b-f380-4462-57852cf61186121.074.063.1109.029.631.2219180
406365dfa-6203-5413-2c6d-553a4a988c1f117.082.064.694.134.334.2328770
...........................
3821391d2527-19ca-8b83-2cd0-5c484946b2b7156.095.076.576.434.130.3589040
3822964603c6-e5bd-6d86-5b21-34056a4a651a124.078.068.0102.125.331.2219180
38233d44d71e-2241-29c2-6aed-7f7f4d65d8d4123.084.075.066.530.030.3589040
3824bbcfa21f-caa2-4540-d2b6-929f48ae27c2130.071.069.199.329.331.2219180
3825d9f4d286-8db4-201d-c1ed-80909f6b3927121.083.068.1104.422.730.3589040
\n", + "

3826 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " PATIENT SYSTOLIC_BP DIASTOLIC_BP HDL \\\n", + "0 28de1ba3-efdc-1797-5a32-d4f8d3be0936 109.0 78.0 75.1 \n", + "1 6686010d-3d7b-d69d-d2cd-8bbe4b3e6041 160.0 104.0 62.1 \n", + "2 42217d99-02e1-2cc8-83ed-5101c246a559 112.0 84.0 69.2 \n", + "3 c429d985-225b-f380-4462-57852cf61186 121.0 74.0 63.1 \n", + "4 06365dfa-6203-5413-2c6d-553a4a988c1f 117.0 82.0 64.6 \n", + "... ... ... ... ... \n", + "3821 391d2527-19ca-8b83-2cd0-5c484946b2b7 156.0 95.0 76.5 \n", + "3822 964603c6-e5bd-6d86-5b21-34056a4a651a 124.0 78.0 68.0 \n", + "3823 3d44d71e-2241-29c2-6aed-7f7f4d65d8d4 123.0 84.0 75.0 \n", + "3824 bbcfa21f-caa2-4540-d2b6-929f48ae27c2 130.0 71.0 69.1 \n", + "3825 d9f4d286-8db4-201d-c1ed-80909f6b3927 121.0 83.0 68.1 \n", + "\n", + " LDL BMI AGE HAS_DIABETES \n", + "0 74.4 25.2 30.356164 0 \n", + "1 82.6 28.2 30.356164 0 \n", + "2 94.6 30.1 31.221918 0 \n", + "3 109.0 29.6 31.221918 0 \n", + "4 94.1 34.3 34.232877 0 \n", + "... ... ... ... ... \n", + "3821 76.4 34.1 30.358904 0 \n", + "3822 102.1 25.3 31.221918 0 \n", + "3823 66.5 30.0 30.358904 0 \n", + "3824 99.3 29.3 31.221918 0 \n", + "3825 104.4 22.7 30.358904 0 \n", + "\n", + "[3826 rows x 8 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df7.sort_values(by=[\"temp_date\"], axis=0, inplace=True)\n", + "df7.reset_index(inplace=True)\n", + "df7 = df7.drop(columns=\"index\")\n", + "df7[\"OBS_INDEX\"] = df7.groupby([\"PATIENT\"]).cumcount()+1\n", + "df8 = df7[df7.OBS_INDEX == 1]\n", + "df8.reset_index(inplace=True)\n", + "df8 = df8.drop(columns=[\"index\", \"temp_date\", \"temp_start\", \"OBS_INDEX\", \"START\", \"DATE\"])\n", + "df8" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3480, 346)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(df8[\"HAS_DIABETES\"]).count(0), list(df8[\"HAS_DIABETES\"]).count(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "#prepare dataset for pytorch\n", + "df9 = df8.drop(columns=\"PATIENT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['SYSTOLIC_BP', 'DIASTOLIC_BP', 'HDL', 'LDL', 'BMI', 'AGE', 'HAS_DIABETES']" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(df9.columns)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Divide into training and test set, and define train and test dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "df_input = df9.drop(columns=\"HAS_DIABETES\")\n", + "df_y = df9[\"HAS_DIABETES\"]\n", + "\n", + "for col in list(df_input.columns):\n", + " df_input[col] = list(map(float, df_input[col]))\n", + " \n", + "train_ratio = 0.7\n", + "\n", + "msk = np.random.rand(len(df9)) < train_ratio\n", + "train_set = df_input[msk].values\n", + "train_labels = df_y[msk].values\n", + "test_set = df_input[~msk].values\n", + "test_labels = df_y[~msk].values" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "\n", + "train_target = torch.tensor(train_labels.astype(np.float32))\n", + "train = torch.tensor(train_set.astype(np.float32)) \n", + "\n", + "test_target 
= torch.tensor(test_labels.astype(np.float32))\n", + "test = torch.tensor(test_set.astype(np.float32)) \n", + "\n", + "bs = 10\n", + "train_tensor = torch.utils.data.TensorDataset(train, train_target) \n", + "train_loader = torch.utils.data.DataLoader(dataset = train_tensor, batch_size = bs, shuffle = True)\n", + "\n", + "test_tensor = torch.utils.data.TensorDataset(test, test_target) \n", + "test_loader = torch.utils.data.DataLoader(dataset = test_tensor, batch_size = bs, shuffle = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simple logistic regression model" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "class LogisticRegression(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(LogisticRegression, self).__init__()\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, x):\n", + " outputs = self.linear(x)\n", + " return outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10\n", + "input_dim = 6\n", + "output_dim = 2\n", + "lr_rate = 0.001\n", + "\n", + "model = LogisticRegression(input_dim, output_dim)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training loop and evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 0. Iteration: 200. Loss: 0.5905064940452576. Accuracy: 91.9602529358627.\n", + "Epoch: 1. Iteration: 400. Loss: 0.016342487186193466. Accuracy: 93.4959349593496.\n", + "Epoch: 2. Iteration: 600. Loss: 0.32516592741012573. Accuracy: 92.05058717253839.\n", + "Epoch: 2. Iteration: 800. Loss: 0.003679836168885231. Accuracy: 91.9602529358627.\n", + "Epoch: 3. Iteration: 1000. Loss: 0.0001037298352457583. Accuracy: 90.96657633243.\n", + "Epoch: 4. Iteration: 1200. Loss: 1.4651803970336914. Accuracy: 71.90605239385727.\n", + "Epoch: 5. Iteration: 1400. Loss: 2.4813601970672607. Accuracy: 90.51490514905149.\n", + "Epoch: 5. Iteration: 1600. Loss: 0.7050111889839172. Accuracy: 90.51490514905149.\n", + "Epoch: 6. Iteration: 1800. Loss: 0.23979775607585907. Accuracy: 71.27371273712737.\n", + "Epoch: 7. Iteration: 2000. Loss: 0.9797161817550659. Accuracy: 37.39837398373984.\n", + "Epoch: 8. Iteration: 2200. Loss: 0.0. Accuracy: 90.51490514905149.\n", + "Epoch: 8. Iteration: 2400. Loss: 0.851937472820282. Accuracy: 81.39114724480578.\n", + "Epoch: 9. Iteration: 2600. Loss: 5.144383430480957. 
Accuracy: 90.51490514905149.\n" + ] + } + ], + "source": [ + "from torch.autograd import Variable\n", + "iter = 0\n", + "for epoch in range(int(epochs)):\n", + " for i, (point, labels) in enumerate(train_loader):\n", + " points = Variable(point.view(-1, 6))\n", + " labels = Variable(labels.type(torch.LongTensor))\n", + " \n", + " optimizer.zero_grad()\n", + " outputs = model(points)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " iter+=1\n", + " if iter%200==0:\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for points, labels in test_loader:\n", + " points = Variable(points.view(-1, 6))\n", + " outputs = model(points)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * float(correct)/float(total)\n", + " print(\"Epoch: {}. Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, iter, loss.item(), accuracy))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, From 30a2a383bc9862c12acea178035892cac526aa51 Mon Sep 17 00:00:00 2001 From: daler3 Date: Fri, 23 Oct 2020 09:56:46 +0200 Subject: [PATCH 04/21] Added confusion matrix --- .../Diabetes_prediction_preprocessing.ipynb | 105 +++++++++++++----- 1 file changed, 77 insertions(+), 28 deletions(-) diff --git a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb index 3fb382d..55e3697 100644 --- a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb +++ b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb @@ -1303,7 
+1303,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ @@ -1320,7 +1320,7 @@ "train_loader = torch.utils.data.DataLoader(dataset = train_tensor, batch_size = bs, shuffle = True)\n", "\n", "test_tensor = torch.utils.data.TensorDataset(test, test_target) \n", - "test_loader = torch.utils.data.DataLoader(dataset = test_tensor, batch_size = bs, shuffle = True)" + "test_loader = torch.utils.data.DataLoader(dataset = test_tensor, batch_size = bs, shuffle = False)" ] }, { @@ -1332,7 +1332,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ @@ -1348,7 +1348,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 123, "metadata": {}, "outputs": [], "source": [ @@ -1372,32 +1372,33 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 177, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0. Iteration: 200. Loss: 0.5905064940452576. Accuracy: 91.9602529358627.\n", - "Epoch: 1. Iteration: 400. Loss: 0.016342487186193466. Accuracy: 93.4959349593496.\n", - "Epoch: 2. Iteration: 600. Loss: 0.32516592741012573. Accuracy: 92.05058717253839.\n", - "Epoch: 2. Iteration: 800. Loss: 0.003679836168885231. Accuracy: 91.9602529358627.\n", - "Epoch: 3. Iteration: 1000. Loss: 0.0001037298352457583. Accuracy: 90.96657633243.\n", - "Epoch: 4. Iteration: 1200. Loss: 1.4651803970336914. Accuracy: 71.90605239385727.\n", - "Epoch: 5. Iteration: 1400. Loss: 2.4813601970672607. Accuracy: 90.51490514905149.\n", - "Epoch: 5. Iteration: 1600. Loss: 0.7050111889839172. Accuracy: 90.51490514905149.\n", - "Epoch: 6. Iteration: 1800. Loss: 0.23979775607585907. Accuracy: 71.27371273712737.\n", - "Epoch: 7. Iteration: 2000. Loss: 0.9797161817550659. Accuracy: 37.39837398373984.\n", - "Epoch: 8. Iteration: 2200. Loss: 0.0. Accuracy: 90.51490514905149.\n", - "Epoch: 8. Iteration: 2400. Loss: 0.851937472820282. Accuracy: 81.39114724480578.\n", - "Epoch: 9. Iteration: 2600. Loss: 5.144383430480957. Accuracy: 90.51490514905149.\n" + "Epoch: 0. Iteration: 200. Loss: 7.707485929131508e-05. Accuracy: 90.78590785907859.\n", + "Epoch: 1. Iteration: 400. Loss: 0.0004317985731177032. Accuracy: 94.76061427280939.\n", + "Epoch: 2. Iteration: 600. Loss: 0.0. Accuracy: 90.60523938572719.\n", + "Epoch: 2. Iteration: 800. Loss: 1.2545890808105469. Accuracy: 45.347786811201445.\n", + "Epoch: 3. Iteration: 1000. Loss: 0.07240111380815506. Accuracy: 94.03794037940379.\n", + "Epoch: 4. Iteration: 1200. Loss: 0.0018771663308143616. Accuracy: 93.13459801264679.\n", + "Epoch: 5. Iteration: 1400. Loss: 1.3109453916549683. Accuracy: 79.67479674796748.\n", + "Epoch: 5. Iteration: 1600. Loss: 1.7740905284881592. Accuracy: 92.6829268292683.\n", + "Epoch: 6. Iteration: 1800. Loss: 0.0184627752751112. Accuracy: 94.2186088527552.\n", + "Epoch: 7. Iteration: 2000. Loss: 0.6591505408287048. Accuracy: 75.51942186088527.\n", + "Epoch: 8. Iteration: 2200. Loss: 1.10377836227417. Accuracy: 90.60523938572719.\n", + "Epoch: 8. Iteration: 2400. Loss: 0.2253536880016327. Accuracy: 92.773261065944.\n", + "Epoch: 9. Iteration: 2600. Loss: 3.0842161178588867. 
Accuracy: 90.51490514905149.\n" ] } ], "source": [ "from torch.autograd import Variable\n", "iter = 0\n", + "loss_v = []\n", "for epoch in range(int(epochs)):\n", " for i, (point, labels) in enumerate(train_loader):\n", " points = Variable(point.view(-1, 6))\n", @@ -1406,6 +1407,7 @@ " optimizer.zero_grad()\n", " outputs = model(points)\n", " loss = criterion(outputs, labels)\n", + " loss_v.append(loss.detach().numpy())\n", " loss.backward()\n", " optimizer.step()\n", " \n", @@ -1426,24 +1428,73 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 141, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "#final test\n", + "labels = test_target\n", + "outputs = model(test)\n", + "_, predicted = torch.max(outputs.data,1)\n", + "total = labels.size(0)\n", + "correct = (predicted == labels).sum()\n", + "accuracy = float(correct)/float(total)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 166, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sn\n", + "\n", + "cm = confusion_matrix(labels, predicted, labels=[0,1])\n", + "tn, fp, fn, tp = cm.ravel()" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 170, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1002, 0],\n", + " [ 105, 0]])" + ] + }, + "execution_count": 170, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cm" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1002, 0, 105, 0)" + ] + }, + "execution_count": 171, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tn, fp, fn, tp" + ] }, { "cell_type": "code", @@ -1464,9 +1515,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "\n" - ] + "source": [] }, { "cell_type": "code", From 5154a38e9ec090691ad90eb3d793e6e9c7def6b4 Mon Sep 17 00:00:00 2001 From: daler3 Date: Wed, 18 Nov 2020 13:40:28 +0100 Subject: [PATCH 05/21] started experimenting with dualhead --- examples/dualheaded/MNIST-DualHeaded.ipynb | 430 +++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 examples/dualheaded/MNIST-DualHeaded.ipynb diff --git a/examples/dualheaded/MNIST-DualHeaded.ipynb b/examples/dualheaded/MNIST-DualHeaded.ipynb new file mode 100644 index 0000000..e62499e --- /dev/null +++ b/examples/dualheaded/MNIST-DualHeaded.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# prerequisites\n", + "from __future__ import print_function\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torchvision import datasets, transforms\n", + "from torch.autograd import Variable\n", + "from torchvision.utils import save_image\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import argparse\n", + "from torch.optim.lr_scheduler import StepLR\n", + "\n", + "transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + "bs = 64\n", + "# MNIST Dataset\n", + "train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n", + "test_dataset = datasets.MNIST(root='./mnist_data/', 
train=False, transform=transforms.ToTensor(), download=False)\n",
+    "\n",
+    "# Data Loader (Input Pipeline)\n",
+    "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)\n",
+    "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([64, 1, 28, 28])\n",
+      "torch.Size([64])\n"
+     ]
+    }
+   ],
+   "source": [
+    "dataiter = iter(train_loader)\n",
+    "images, labels = dataiter.next()\n",
+    "\n",
+    "print(images.shape)\n",
+    "print(labels.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "[base64-encoded PNG of a sample MNIST digit omitted]",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "plt.imshow(images[0].numpy().squeeze(), cmap='gray_r');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Net(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super(Net, self).__init__()\n",
+    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+    "        self.dropout1 = nn.Dropout(0.25)\n",
+    "        self.dropout2 = nn.Dropout(0.5)\n",
+    "        self.fc1 = nn.Linear(9216, 128)\n",
+    "        self.fc2 = nn.Linear(128, 10)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.conv1(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = self.conv2(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = F.max_pool2d(x, 2)\n",
+    "        x = self.dropout1(x)\n",
+    "        x = torch.flatten(x, 1)\n",
+    "        x = self.fc1(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = self.dropout2(x)\n",
+    "        x = self.fc2(x)\n",
+    "        output = F.log_softmax(x, dim=1)\n",
+    "        return output\n",
+    "\n",
+    "\n",
+    "def train(args, model, device, train_loader, optimizer, epoch):\n",
+    "    model.train()\n",
+    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
+    "        data, target = data.to(device), target.to(device)\n",
+    "        optimizer.zero_grad()\n",
+    "        output = model(data)\n",
+    "        loss = F.nll_loss(output, target)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        if batch_idx % args.log_interval == 0:\n",
+    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
+    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
+    "                100. * batch_idx / len(train_loader), loss.item()))\n",
+    "            if args.dry_run:\n",
+    "                break\n",
+    "\n",
+    "\n",
+    "def test(model, device, test_loader):\n",
+    "    model.eval()\n",
+    "    test_loss = 0\n",
+    "    correct = 0\n",
+    "    with torch.no_grad():\n",
+    "        for data, target in test_loader:\n",
+    "            data, target = data.to(device), target.to(device)\n",
+    "            output = model(data)\n",
+    "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
+    "            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n",
+    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
+    "\n",
+    "    test_loss /= len(test_loader.dataset)\n",
+    "\n",
+    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
+    "        test_loss, correct, len(test_loader.dataset),\n",
+    "        100. * correct / len(test_loader.dataset)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lr = 0.001\n",
+    "epochs = 10\n",
+    "\n",
+    "# the train/test helpers above expect a device and an args namespace;\n",
+    "# define minimal ones here so this cell can run\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "args = argparse.Namespace(log_interval=100, dry_run=False, gamma=0.7)\n",
+    "\n",
+    "model = Net().to(device)\n",
+    "optimizer = optim.Adadelta(model.parameters(), lr=lr)\n",
+    "\n",
+    "scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n",
+    "for epoch in range(1, epochs + 1):\n",
+    "    train(args, model, device, train_loader, optimizer, epoch)\n",
+    "    test(model, device, test_loader)\n",
+    "    scheduler.step()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**DUAL HEADED**\n",
+    "\n",
+    "\n",
+    "* Client_1\n",
+    "  * Has Model Segment 1\n",
+    "  * Has the handwritten images segment 1 (vertical distr: left part)\n",
+    "* Client_2\n",
+    "  * Has Model Segment 1\n",
+    "  * Has the handwritten images segment 2 (vertical distr: right part)\n",
+    "* Server\n",
+    "  * Has Model Segment 2\n",
+    "  * Has the image labels\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "import pytest\n",
+    "import syft as sy\n",
+    "from syft.core.node.common.service.auth import AuthorizationException\n",
+    "\n",
+    "import torch as th\n",
+    "from torchvision import datasets, transforms\n",
+    "from torch import nn, optim\n",
+    "import numpy as np\n",
+    "from syft.util import key_emoji"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create some workers\n",
+    "dev_1 = sy.Device(name=\"client_1\")\n",
+    "dev_2 = sy.Device(name=\"client_2\")\n",
+    "client_1 = dev_1.get_client()\n",
+    "client_2 = dev_2.get_client()\n",
+    "\n",
+    "server_dev = sy.Device(name=\"server\")\n",
+    "server = server_dev.get_client()\n",
+    "\n",
+    "data_owners = (client_1, client_2)\n",
+    "model_locations = [client_1, client_2, server]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">\n",
+      ">\n",
+      ">\n"
+     ]
+    }
+   ],
+   "source": [
+    "for location in model_locations:\n",
+    "    print(location)\n",
+    "    x = location"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'client_2 Client'"
+      ]
+     },
+     "execution_count": 51,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model_locations[1].name.split()"
+   ]
+  },
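+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Each MNIST image is 28x28. With two data owners and a vertical (column-wise) split, every client holds a 28x14 half-image, i.e. 28*14 = 392 input features once flattened. The short check below is illustrative only (not part of the original experiment); it confirms the shapes that the per-client model segments defined next must accept."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# illustrative shape check for the vertical split (assumes train_dataset from above)\n",
+    "sample, _ = train_dataset[0]\n",
+    "left_half, right_half = sample[:, :, :14], sample[:, :, 14:]\n",
+    "print(left_half.shape, right_half.shape)  # torch.Size([1, 28, 14]) twice\n",
+    "print(left_half.reshape(-1).shape)        # torch.Size([392]) == 28*14"
+   ]
+  },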
nn.Linear(hidden_sizes[\"server\"][1], 10),\n", + " nn.LogSoftmax(dim=1)\n", + " )\n", + "}\n", + "\n", + "# Create optimisers for each segment and link to their segment\n", + "optimizers = [\n", + " optim.SGD(models[location.name.split(\" \")[0]].parameters(), lr=0.05,)\n", + " for location in model_locations\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "ename": "Exception", + "evalue": "Object has no serializable_wrapper_type", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#send model segement to each client and server\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlocation\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel_locations\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmodels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlocation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\" \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/ast/klass.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, client, searchable)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0;31m# Step 3: send message\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 189\u001b[0;31m \u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_immediate_msg_without_reply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;31m# Step 4: return pointer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/syft_decorator_impl.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;31m# try:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;31m# return function(*args, **kwargs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/typecheck.py\u001b[0m in \u001b[0;36mdecorator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprohibit_args\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcheck_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtypechecked\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecorated\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/typeguard/__init__.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 889\u001b[0m \u001b[0mmemo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CallMemo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpython_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_localns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 890\u001b[0m \u001b[0mcheck_argument_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 891\u001b[0;31m \u001b[0mretval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 892\u001b[0m \u001b[0mcheck_return_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 893\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/core/node/common/client.py\u001b[0m in \u001b[0;36msend_immediate_msg_without_reply\u001b[0;34m(self, msg, route_index)\u001b[0m\n\u001b[1;32m 255\u001b[0m )\n\u001b[1;32m 256\u001b[0m 
\u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msign\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msigning_key\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigning_key\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 258\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"> Sending {msg.pprint} {self.pprint} ➡️ {msg.address.pprint}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroutes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mroute_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_immediate_msg_without_reply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/core/common/message.py\u001b[0m in \u001b[0;36msign\u001b[0;34m(self, signing_key)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34mf\"> Signing with {self.address.key_emoji(key=signing_key.verify_key)}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m )\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0msigned_message\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msigning_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msign\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mto_bytes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# signed_type will be the final subclass callee's closest parent signed_type\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/syft_decorator_impl.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;31m# 
try:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;31m# return function(*args, **kwargs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/typecheck.py\u001b[0m in \u001b[0;36mdecorator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprohibit_args\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcheck_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtypechecked\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecorated\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/typeguard/__init__.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 889\u001b[0m \u001b[0mmemo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CallMemo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpython_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_localns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 890\u001b[0m \u001b[0mcheck_argument_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 891\u001b[0;31m \u001b[0mretval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 892\u001b[0m \u001b[0mcheck_return_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 893\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/core/common/serde/serializable.py\u001b[0m in \u001b[0;36mserialize\u001b[0;34m(self, to_proto, to_bytes)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;31m# indent=None means no white space or \\n in the serialized version\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 263\u001b[0m \u001b[0;31m# this is compatible with json.dumps(x, indent=None)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 264\u001b[0;31m \u001b[0mblob\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_object2proto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSerializeToString\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 265\u001b[0m blob = DataMessage(\n\u001b[1;32m 266\u001b[0m \u001b[0mobj_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_fully_qualified_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mblob\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/syft_decorator_impl.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;31m# try:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;31m# return function(*args, **kwargs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/typecheck.py\u001b[0m in \u001b[0;36mdecorator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprohibit_args\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcheck_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtypechecked\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecorated\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/typeguard/__init__.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, 
**kwargs)\u001b[0m\n\u001b[1;32m 889\u001b[0m \u001b[0mmemo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CallMemo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpython_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_localns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 890\u001b[0m \u001b[0mcheck_argument_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 891\u001b[0;31m \u001b[0mretval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 892\u001b[0m \u001b[0mcheck_return_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 893\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/core/node/common/action/save_object_action.py\u001b[0m in \u001b[0;36m_object2proto\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0mid_at_location\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid_at_location\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 75\u001b[0;31m \u001b[0mobj_ob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 76\u001b[0m \u001b[0maddr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/ast/klass.py\u001b[0m in \u001b[0;36mserialize\u001b[0;34m(self, to_proto, to_bytes)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0mto_proto\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mto_proto\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 228\u001b[0;31m \u001b[0mto_bytes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mto_bytes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 229\u001b[0m )\n\u001b[1;32m 230\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/syft_decorator_impl.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;31m# try:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;31m# return function(*args, **kwargs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/typecheck.py\u001b[0m in \u001b[0;36mdecorator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprohibit_args\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcheck_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtypechecked\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecorated\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__annotations__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/typeguard/__init__.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 889\u001b[0m \u001b[0mmemo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CallMemo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpython_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_localns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 890\u001b[0m \u001b[0mcheck_argument_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 891\u001b[0;31m \u001b[0mretval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 892\u001b[0m \u001b[0mcheck_return_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 893\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/core/common/serde/serialize.py\u001b[0m in \u001b[0;36m_serialize\u001b[0;34m(obj, to_proto, to_bytes)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mis_serializable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserializable_wrapper_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Object {type(obj)} has no serializable_wrapper_type\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0mis_serializable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mException\u001b[0m: Object has no serializable_wrapper_type" + ] + } + ], + "source": [ + "\n", + "\n", + "\n", + "#send model segement to each client and server\n", + "for location in model_locations:\n", + " models[location.name.split(\" \")[0]].send(location)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "bob_vm = client_1.get_client()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "xp = th.tensor([1,2,3]).tag(\"some\", \"diabetes\", \"data\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xp.send(vm_sever_cli)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From c782ddc5bae5a806345b64244808ef1b548732f6 Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Fri, 4 Dec 2020 15:54:26 +0100 Subject: [PATCH 06/21] Added verticalfederateddataset class --- .../dualheaded/verticalfederateddataset.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 examples/dualheaded/verticalfederateddataset.py diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py new file mode 100644 index 
0000000..068c135 --- /dev/null +++ b/examples/dualheaded/verticalfederateddataset.py @@ -0,0 +1,83 @@ +import syft as sy +from __future__ import print_function +import torch +from torch.utils.data import Dataset + + +def split_data(n_workers) + idx = 0 + dic_single_datasets = {} + for i in range(0, n_workers): + dic_single_datasets[i] = [] + + for tensor, label in train_dataset: + height = tensor.shape[-1]//n_workers + data_parts_list = [] + #put in a list the parts to give to single workers + for i in range(0,n_workers-1): + dic_single_datasets[i].append(tuple([tensor[:, :, height * i : height * (i + 1)], label, idx])) + dic_single_datasets[n_workers-1].append(tuple([tensor[:, :, height * (i+1) : ], label, idx ])) #last part of the image + + idx += 1 + #each value of the dictionary is a list of triples + return dic_single_datasets + + +class BaseVerticalDataset(Dataset): + def __init__(self, datalist): + self.dataset = datalist + self.get_data_tensor() + self.worker_id = None + self.data_pointer = None + self.label_pointer = None + self.index_pointer = None + + def __len__(self): + return len(self.dataset) + + def __get_item__(self, idx): + return self.dataset[i] + + def get_data_tensor(self): + self.data_tensor = [] + self.label_tensor = [] + self.index_tensor = [] + for el in self.dataset: + self.data_tensor.append(el[0]) + self.label_tensor.append(el[1]) + self.index_tensor.append(el[2]) + self.data_tensor = torch.stack(self.data_tensor) + self.label_tensor = torch.Tensor(self.label_tensor) + self.index_tensor = torch.Tensor(self.index_tensor) + + def send_to_worker(self, worker): + self.worker_id = worker + self.data_pointer = self.data_tensor.send(worker) + self.label_pointer = self.label_tensor.send(worker) + self.index_pointer = self.index_tensor.send(worker) + return self.data_pointer, self.label_pointer, self.index_pointer + + +class VerticalFederatedDataset(): + #takes a list of BaseVerticalDatasets (already sent to workers) + def __init__(self, datasets): + + self.datasets = {} + + for dataset in datasets: + worker_id = dataset.worker_id + self.datasets[worker_id] = dataset + + + def workers(self): + """ + Returns: list of workers + """ + + return list(self.datasets.keys()) + + + class VerticalFederatedDataLoader(): + + def __init__(self, vertical_fed_dataset): + self.vertical_fed_dataset = vertical_fed_dataset \ No newline at end of file From 465a91c9887f5c089e12f9bb3bf6f892f9943e65 Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Fri, 4 Dec 2020 16:48:14 +0100 Subject: [PATCH 07/21] Added dataset parameter to split_data --- examples/dualheaded/verticalfederateddataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py index 068c135..d05acf3 100644 --- a/examples/dualheaded/verticalfederateddataset.py +++ b/examples/dualheaded/verticalfederateddataset.py @@ -4,13 +4,13 @@ from torch.utils.data import Dataset -def split_data(n_workers) +def split_data(n_workers, dataset) idx = 0 dic_single_datasets = {} for i in range(0, n_workers): dic_single_datasets[i] = [] - for tensor, label in train_dataset: + for tensor, label in dataset: height = tensor.shape[-1]//n_workers data_parts_list = [] #put in a list the parts to give to single workers @@ -80,4 +80,4 @@ def workers(self): class VerticalFederatedDataLoader(): def __init__(self, vertical_fed_dataset): - self.vertical_fed_dataset = vertical_fed_dataset \ No newline at end of file + 
self.vertical_fed_dataset = vertical_fed_dataset From d4615e373e799ed32a38c83a5caa5318ca930cf9 Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Fri, 4 Dec 2020 21:01:34 +0100 Subject: [PATCH 08/21] Reformatted and added create vertical method reformatted and added method to split and directly create vertical federated dataset --- .../dualheaded/verticalfederateddataset.py | 89 ++++++++++++------- 1 file changed, 58 insertions(+), 31 deletions(-) diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py index d05acf3..4f11ace 100644 --- a/examples/dualheaded/verticalfederateddataset.py +++ b/examples/dualheaded/verticalfederateddataset.py @@ -4,51 +4,78 @@ from torch.utils.data import Dataset -def split_data(n_workers, dataset) +def split_data(dataset, worker_list=None, n_workers=None): + + if worker_list == None: + if n_workers == None: + n_workers = 2 #default + worker_list = list(range(0, n_workers)) + idx = 0 + dic_single_datasets = {} - for i in range(0, n_workers): - dic_single_datasets[i] = [] - + for worker in worker_list: + dic_single_datasets[worker] = [[],[],[]] + for tensor, label in dataset: - height = tensor.shape[-1]//n_workers - data_parts_list = [] - #put in a list the parts to give to single workers - for i in range(0,n_workers-1): - dic_single_datasets[i].append(tuple([tensor[:, :, height * i : height * (i + 1)], label, idx])) - dic_single_datasets[n_workers-1].append(tuple([tensor[:, :, height * (i+1) : ], label, idx ])) #last part of the image - + height = tensor.shape[-1]//len(worker_list) + i = 0 + for worker in worker_list[:-1]: + dic_single_datasets[worker][0].append(tensor[:, :, height * i : height * (i + 1)]) + dic_single_datasets[worker][1].append(label) + dic_single_datasets[worker][2].append(idx) + i += 1 + + dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i+1) : ]) + dic_single_datasets[worker_list[-1]][1].append(label) + dic_single_datasets[worker_list[-1]][2].append(idx) + idx += 1 - #each value of the dictionary is a list of triples + return dic_single_datasets +def split_data_create_vertical_dataset(dataset, worker_list): + + dic_single_datasets = split_data(dataset, worker_list=worker_list) + + #create base datasets + base_datasets_list = [] + for worker in worker_list: + base_datasets_list.append(BaseVerticalDataset(dic_single_datasets[worker], worker_id=worker)) + + #create VerticalFederatedDataset + return VerticalFederatedDataset(base_datasets_list) + class BaseVerticalDataset(Dataset): - def __init__(self, datalist): - self.dataset = datalist - self.get_data_tensor() + def __init__(self, datatuples, worker_id=None): + + self.fill_tensors(datatuples) + self.worker_id = None - self.data_pointer = None - self.label_pointer = None - self.index_pointer = None + if worker_id != None: + self.send_to_worker(worker_id) + self.worker_id = worker_id + + self.dataset_tolist() + def __len__(self): - return len(self.dataset) + return self.data_tensor.shape[0] def __get_item__(self, idx): - return self.dataset[i] + return tuple([self.data_tensor[idx], self.label_tensor[idx], self.index_tensor[idx]]) + + def __fill_tensors(self, data_tuples): + self.data_tensor = torch.stack(data_tuples[0]) + self.label_tensor = torch.Tensor(data_tuples[1]) + self.index_tensor = torch.Tensor(data_tuples[2]) - def get_data_tensor(self): - self.data_tensor = [] - self.label_tensor = [] - self.index_tensor = [] - for el in self.dataset: - self.data_tensor.append(el[0]) - self.label_tensor.append(el[1]) - 
self.index_tensor.append(el[2])
-        self.data_tensor = torch.stack(self.data_tensor)
-        self.label_tensor = torch.Tensor(self.label_tensor)
-        self.index_tensor = torch.Tensor(self.index_tensor)
+    def __dataset_tolist(self):
+        flat_dataset = []
+        for i in range(0, self.__len__()):
+            flat_dataset.append(self.__get_item__(i))
+        self.dataset = flat_dataset
 
     def send_to_worker(self, worker):
         self.worker_id = worker

From ca1ff75afb53f7f5ed96f870a0931a21209c2149 Mon Sep 17 00:00:00 2001
From: daler3 
Date: Sat, 5 Dec 2020 10:47:22 +0100
Subject: [PATCH 09/21] added comments and TODOs to the dataset file

---
 .../dualheaded/verticalfederateddataset.py    | 151 +++++++++++++++---
 1 file changed, 132 insertions(+), 19 deletions(-)

diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py
index 4f11ace..6098fb4 100644
--- a/examples/dualheaded/verticalfederateddataset.py
+++ b/examples/dualheaded/verticalfederateddataset.py
@@ -3,20 +3,64 @@
 import torch
 from torch.utils.data import Dataset
 
+"""
+Utility functions to split and distribute the data across different workers,
+create vertical datasets and federate them. It also contains dataset and dataloader classes.
+This code is meant to be used with dual-headed Neural Networks, where there are a number of different workers,
+which agree on the labels, and there is a server with the labels only.
 
-def split_data(dataset, worker_list=None, n_workers=None):
+Code built upon:
+- Abbas Ismail's (@abbas5253) work on dual-headed NN. In particular, check Configuration 1:
+  https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb
+- Syft 2.0 Federated Learning dataset and dataloader: https://github.com/OpenMined/PySyft/tree/syft_0.2.x/syft/frameworks/torch/fl
+
+"""
+
+
+def split_data(dataset, worker_list=None, n_workers=2, label_server=None):
+    """
+    Utility function to create a vertical split of the data. It also creates a numerical index to keep
+    track of each single data point across the different splits.
+
+    Args:
+        dataset: an iterable object representing the dataset. Each element of the iterable
+        is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch.
+        #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples.
+
+        worker_list (optional): The list of VirtualWorkers to distribute the data vertically across.
+
+        n_workers (optional, default=2): The number of workers to split the data across. If worker_list is not passed, this is necessary to create the split.
+
+        label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting)
+        #TODO: add the code to send labels to the server
+
+    Returns:
+        a dictionary holding as keys the workers passed as parameters (or integers identifying the splits),
+        and as values a list of three lists: the first holds the single data tensors, the second the labels,
+        and the third the numerical indices used to keep track of the same data point across splits.
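+
+        Example (an illustrative sketch only -- it assumes two syft 0.2.x-style
+        VirtualWorkers, alice and bob, and a torchvision MNIST dataset, none of
+        which are created here):
+            >>> splits = split_data(mnist_train, worker_list=[alice, bob])
+            >>> data, labels, idxs = splits[alice]
+            >>> data[0].shape   # left 28x14 half of the first image
+            torch.Size([1, 28, 14])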
+    """
     if worker_list == None:
-        if n_workers == None:
-            n_workers = 2 #default
         worker_list = list(range(0, n_workers))
 
-    idx = 0
+    #counter to create the index of different data samples
+    idx = 0
+
+    #dictionary to accommodate the split data
     dic_single_datasets = {}
     for worker in worker_list:
-        dic_single_datasets[worker] = [[],[],[]]
-
+        """
+        Each value is a list of three elements, to accommodate, in order:
+        - data examples (as tensors)
+        - label
+        - index
+        """
+        dic_single_datasets[worker] = [[],[],[]]
+
+    """
+    Loop through the dataset to split the data and labels vertically across workers.
+    Splitting method from @abbas5253: https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/distribute_data.py
+    """
     for tensor, label in dataset:
         height = tensor.shape[-1]//len(worker_list)
         i = 0
@@ -26,6 +70,7 @@
             i += 1
 
+        #add the value of the last worker / split
         dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i+1) : ])
         dic_single_datasets[worker_list[-1]][1].append(label)
         dic_single_datasets[worker_list[-1]][2].append(idx)
 
         idx += 1
 
@@ -35,49 +80,101 @@
     return dic_single_datasets
 
 
-def split_data_create_vertical_dataset(dataset, worker_list):
-
-    dic_single_datasets = split_data(dataset, worker_list=worker_list)
+def split_data_create_vertical_dataset(dataset, worker_list, label_server=None):
+    """
+    Utility function to distribute the data vertically across workers and create a vertical federated dataset.
+
+    Args:
+        dataset: an iterable object representing the dataset. Each element of the iterable
+        is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch.
+        #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples.
+
+        worker_list: The list of VirtualWorkers to distribute the data vertically across.
+
+        label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting)
+
+    Returns:
+        a VerticalFederatedDataset.
+    """
+
+    #get a dictionary of workers --> list of triples (data, label, idx) representing the dataset.
+    dic_single_datasets = split_data(dataset, worker_list=worker_list, label_server=label_server)
 
-    #create base datasets
+    #create base vertical datasets list, to be passed to a vertical federated dataset
     base_datasets_list = []
     for worker in worker_list:
         base_datasets_list.append(BaseVerticalDataset(dic_single_datasets[worker], worker_id=worker))
 
-    #create VerticalFederatedDataset 
+    #create VerticalFederatedDataset
     return VerticalFederatedDataset(base_datasets_list)
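+
+#Illustrative usage sketch (assumes two syft 0.2.x VirtualWorkers, alice and bob,
+#and a torchvision MNIST training set, all created elsewhere; not executed here):
+#
+#    vfd = split_data_create_vertical_dataset(mnist_train, [alice, bob])
+#    vfd.datasets[alice]   # BaseVerticalDataset whose tensors now live on alice
+#    len(vfd.datasets)     # -> 2, one partition per worker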
 
 
 class BaseVerticalDataset(Dataset):
+    """
+    Base Vertical Dataset class, containing a portion of a vertically split dataset.
+
+    Args:
+        datatuples: a list (or tuple) of exactly 3 lists: the first holds the sample tensors,
+        the second the corresponding labels, and the third the indices
+        (necessary to keep track of the same vertically split examples across multiple workers)
+
+        worker_id (optional): the worker to which we want to send the dataset
+
+    """
     def __init__(self, datatuples, worker_id=None):
 
-        self.fill_tensors(datatuples)
+        self._fill_tensors(datatuples)
 
         self.worker_id = None
         if worker_id != None:
             self.send_to_worker(worker_id)
             self.worker_id = worker_id
 
-        self.dataset_tolist()
+        self._dataset_tolist()
 
 
     def __len__(self):
+        """
+        Returns: the number of samples in the dataset
+        """
         return self.data_tensor.shape[0]
 
     def __get_item__(self, idx):
+        """
+        Args:
+            idx: index of the example we want to get
+
+        Returns: a tuple with data, label and index of a single example.
+        """
         return tuple([self.data_tensor[idx], self.label_tensor[idx], self.index_tensor[idx]])
 
     def __fill_tensors(self, data_tuples):
+        """
+        Private method to fill the data, label and index tensors.
+        """
         self.data_tensor = torch.stack(data_tuples[0])
         self.label_tensor = torch.Tensor(data_tuples[1])
         self.index_tensor = torch.Tensor(data_tuples[2])
 
     def __dataset_tolist(self):
+        """
+        Private method to create a compact list version of the dataset, so that len(dataset) is the number of examples.
+        """
-        flat_dataset = []
+        list_dataset = []
         for i in range(0, self.__len__()):
-            flat_dataset.append(self.__get_item__(i))
+            list_dataset.append(self.__get_item__(i))
-        self.dataset = flat_dataset
+        self.dataset = list_dataset
 
     def send_to_worker(self, worker):
+        """
+        Send the dataset to a worker.
+
+        Args:
+            worker: the worker to which we want to send the dataset
+
+        Returns:
+            pointers to the remote data, label and index tensors
+        """
         self.worker_id = worker
         self.data_pointer = self.data_tensor.send(worker)
         self.label_pointer = self.label_tensor.send(worker)
         self.index_pointer = self.index_tensor.send(worker)
 
         return self.data_pointer, self.label_pointer, self.index_pointer
 
 
 
 class VerticalFederatedDataset():
+    """
+    VerticalFederatedDataset, which acts as a dictionary between BaseVerticalDatasets,
+    already sent to remote workers, and the corresponding workers.
+    This serves as an input to VerticalFederatedDataLoader.
+    Same principle as in Syft 2.0 for FederatedDataset:
+    https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataset.py
+
+    Args:
+        datasets: list of BaseVerticalDatasets.
+    """
     def __init__(self, datasets):
 
-        self.datasets = {}
+        self.datasets = {}  #dictionary to keep track of BaseVerticalDatasets and corresponding workers
 
         for dataset in datasets:
             worker_id = dataset.worker_id
             self.datasets[worker_id] = dataset
 
 
     def workers(self):
+        """
+        Returns: list of workers
+        """
         return list(self.datasets.keys())
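+
+    #Illustrative note (sketch only, assuming ds_alice and ds_bob are the two
+    #BaseVerticalDatasets of this collection): the shared numerical index is what
+    #keeps the vertical halves of the same example aligned across workers, e.g.:
+    #
+    #    for (x1, y1, i1), (x2, y2, i2) in zip(ds_alice.dataset, ds_bob.dataset):
+    #        assert int(i1) == int(i2)   # same underlying sample on both workers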
 
 
     class VerticalFederatedDataLoader():
+        """
+        Data loader class. It combines a VerticalFederatedDataset and a sampler.
+
+        TODO: implement the class. Check also FederatedDataLoader of Syft 2.0:
+        https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataloader.py
+        """
 
         def __init__(self, vertical_fed_dataset):
            self.vertical_fed_dataset = vertical_fed_dataset

From 693a334a8e064b8bbabd348b39ff9f2f4b4831ad Mon Sep 17 00:00:00 2001
From: Daniele Romanini 
Date: Sat, 5 Dec 2020 19:42:48 +0100
Subject: [PATCH 10/21] first version of verticalFederatedDataLoader

---
 .../dualheaded/verticalfederateddataset.py    | 106 ++++++++++++------
 1 file changed, 74 insertions(+), 32 deletions(-)

diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py
index 6098fb4..7c5ff5d 100644
--- a/examples/dualheaded/verticalfederateddataset.py
+++ b/examples/dualheaded/verticalfederateddataset.py
@@ -1,19 +1,26 @@
-import syft as sy
 from __future__ import print_function
+import syft as sy
 import torch
 from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from torch.utils.data._utils.collate import default_collate
 
 """
 Utility functions to split and distribute the data across different workers,
 create vertical datasets and federate them. It also contains dataset and dataloader classes.
 This code is meant to be used with dual-headed Neural Networks, where there are a number of different workers,
 which agree on the labels, and there is a server with the labels only.
-
 Code built upon:
 - Abbas Ismail's (@abbas5253) work on dual-headed NN. In particular, check Configuration 1:
   https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb
 - Syft 2.0 Federated Learning dataset and dataloader: https://github.com/OpenMined/PySyft/tree/syft_0.2.x/syft/frameworks/torch/fl
+
+TODO:
+ - replace ids with UUIDs
+ - there is a bug in creation of BaseDataset
+ - create class for splitting the data
+ - create LabelSet and SampleSet (to accommodate later different roles of workers)
+ - improve DataLoader to accommodate different sampler (e.g. random sampler when shuffle) and different batch size
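+
+Intended end-to-end flow (an illustrative sketch only -- worker creation is assumed
+to happen elsewhere and the loader below is still work in progress, see the TODOs):
+    vfd = split_data_create_vertical_dataset(dataset, [alice, bob])
+    loader = VerticalFederatedDataLoader(vfd, batch_size=8)
+    for batch_parts in loader:
+        ...  # one aligned batch per vertical partition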
- Args: dataset: an iterable object represent the dataset. Each element of the iterable is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch. #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples. - worker_list: The list of VirtualWorkers to distribute the data vertically across. - label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting) - Returns: a VerticalFederatedDataset. """ @@ -111,25 +109,22 @@ def split_data_create_vertical_dataset(dataset, worker_list, label_server=None): class BaseVerticalDataset(Dataset): """ Base Vertical Dataset class, containing a portion of a vertically splitted dataset. - Args: datatuples: a list where each element is another list (or tuple) of exactly 3 elements. The first one is a sample data, the second one is the corresponding label, and the third one the index (necessary to keep track of the same vertically splitted examples across multiple workers) - worker_id (optional): the worker to which we want to send the dataset - """ def __init__(self, datatuples, worker_id=None): - self._fill_tensors(datatuples) + self.__fill_tensors(datatuples) self.worker_id = None if worker_id != None: self.send_to_worker(worker_id) self.worker_id = worker_id - self._dataset_tolist() + self.__dataset_tolist() def __len__(self): @@ -138,14 +133,13 @@ def __len__(self): """ return self.data_tensor.shape[0] - def __get_item__(self, idx): + def __getitem__(self, index): """ Args: idx: index of the example we want to get - Returns: a tuple with data, label, index of a single example. """ - return tuple([self.data_tensor[idx], self.label_tensor[idx], self.index_tensor[idx]]) + return tuple([self.data_tensor[index], self.label_tensor[index], self.index_tensor[index]]) def __fill_tensors(self, data_tuples): """ @@ -161,17 +155,15 @@ def __dataset_tolist(self): """ list_dataset = [] for i in range(0, self.__len__()): - list_dataset.append(self.__get_item__(i)) + list_dataset.append(self.__getitem__(i)) self.dataset = list_dataset def send_to_worker(self, worker): """ Send the dataset to a worker. - Args: the worker to which we want to send the dataset - Returns: pointers to the remote data, the labels and the index tensors """ @@ -191,7 +183,6 @@ class VerticalFederatedDataset(): This serves as an input to VerticalFederatedDataLoader. Same principle as in Syft 2.0 for FederatedDataset: https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataset.py - Args: datasets: list of BaseVerticalDatasets. """ @@ -202,22 +193,73 @@ def __init__(self, datasets): for dataset in datasets: worker_id = dataset.worker_id self.datasets[worker_id] = dataset + + self.workers = self.__workers() - def workers(self): + def __workers(self): """ Returns: list of workers """ return list(self.datasets.keys()) + + def get_dataset(self, worker): + self[worker].federated = False + dataset = self[worker].get() + del self.datasets[worker] + return dataset + def __getitem__(self, worker_id): + """ + Args: + worker_id[str,int]: ID of respective worker + Returns: + Get Datasets from the respective worker + """ - class VerticalFederatedDataLoader(): - """ - Data loader class, It combines a VerticalFederatedDataset and a sampler. + return self.datasets[worker_id] - TODO: implement the class. 
Check also FederatedDataLoader of Syft 2.0: - https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataloader.py - """ + def __len__(self): + + return sum(len(dataset) for dataset in self.datasets.values()) + + def __repr__(self): + + fmt_str = "FederatedDataset\n" + fmt_str += f" Distributed accross: {', '.join(str(x) for x in self.workers)}\n" + fmt_str += f" Number of datapoints: {self.__len__()}\n" + return fmt_str + - def __init__(self, vertical_fed_dataset): - self.vertical_fed_dataset = vertical_fed_dataset +class SinglePartitionDataLoader(DataLoader): + """DataLoader for a single vertically-partitioned dataset""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.collate_fn = default_collate + +class VerticalFederatedDataLoader: + """Dataloader which batches data from a complete + set of vertically-partitioned datasets + """ + + def __init__(self, vf_dataset, batch_size=8, shuffle=False, *args, **kwargs): + + self.vf_dataset = vf_dataset + + single_loaders_list = [] + for d in vfd.datasets.values(): + single_loaders_list.append(SinglePartitionDataLoader(d)) + + self.workers = list(vf_dataset.keys()) + + + def __iter__(self): + return zip(*self.vf_dataset) + + def __len__(self): + l = 0 + for x in self.vf_dataset.datasets.values(): + l += len(x) + return l // len(self.workers) From f566e37772100b8a6aeb94ed4c2f2531a075c359 Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Sat, 5 Dec 2020 19:46:08 +0100 Subject: [PATCH 11/21] Added comments for TODOs --- examples/dualheaded/verticalfederateddataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py index 7c5ff5d..72bee5d 100644 --- a/examples/dualheaded/verticalfederateddataset.py +++ b/examples/dualheaded/verticalfederateddataset.py @@ -21,6 +21,8 @@ - create class for splitting the data - create LabelSet and SampleSet (to accomodate later different roles of workers) - improve DataLoader to accomodate different sampler (e.g. random sampler when shuffle) and different batch size + - split function should be able to take as an input a dataloader, and not only a dataset (i.e. 
single sample iteration) + - check that / modify such that it works on data different than images """ From e382c6df7b4f41c7b2bcb6c1d6d8b9cb750c266e Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Sun, 6 Dec 2020 01:47:08 +0100 Subject: [PATCH 12/21] corrected index of last tensor in split_data --- examples/dualheaded/verticalfederateddataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py index 72bee5d..987c24a 100644 --- a/examples/dualheaded/verticalfederateddataset.py +++ b/examples/dualheaded/verticalfederateddataset.py @@ -75,7 +75,7 @@ def split_data(dataset, worker_list=None, n_workers=2, label_server=None): i += 1 #add the value of the last worker / split - dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i+1) : ]) + dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i) : ]) dic_single_datasets[worker_list[-1]][1].append(label) dic_single_datasets[worker_list[-1]][2].append(idx) From 74bafec4c6f503b4c93c915161df341c240eab7f Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Sun, 6 Dec 2020 10:10:07 +0100 Subject: [PATCH 13/21] compacted sum of the len in verticalfeddataloader --- examples/dualheaded/verticalfederateddataset.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py index 987c24a..f64f49e 100644 --- a/examples/dualheaded/verticalfederateddataset.py +++ b/examples/dualheaded/verticalfederateddataset.py @@ -204,12 +204,6 @@ def __workers(self): Returns: list of workers """ return list(self.datasets.keys()) - - def get_dataset(self, worker): - self[worker].federated = False - dataset = self[worker].get() - del self.datasets[worker] - return dataset def __getitem__(self, worker_id): """ @@ -261,7 +255,5 @@ def __iter__(self): return zip(*self.vf_dataset) def __len__(self): - l = 0 - for x in self.vf_dataset.datasets.values(): - l += len(x) - return l // len(self.workers) + return sum(len(x) for x in self.datasets.values()) // len(self.workers) + From bb8cae7d205660abc639d5acb6d989d79f76affe Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Fri, 11 Dec 2020 15:58:33 +0100 Subject: [PATCH 14/21] Uploaded skeleton for dualheaded loaders --- examples/dh_examples/Example_.ipynb | 75 +++++++++++++++ examples/dh_examples/dataloaders.py | 57 ++++++++++++ examples/dh_examples/datasets.py | 136 ++++++++++++++++++++++++++++ examples/dh_examples/utils.py | 119 ++++++++++++++++++++++++ 4 files changed, 387 insertions(+) create mode 100644 examples/dh_examples/Example_.ipynb create mode 100644 examples/dh_examples/dataloaders.py create mode 100644 examples/dh_examples/datasets.py create mode 100644 examples/dh_examples/utils.py diff --git a/examples/dh_examples/Example_.ipynb b/examples/dh_examples/Example_.ipynb new file mode 100644 index 0000000..9154fe4 --- /dev/null +++ b/examples/dh_examples/Example_.ipynb @@ -0,0 +1,75 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import print_function\n", + "import syft as sy\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.data._utils.collate import default_collate\n", + "from typing import List, Tuple\n", + "from uuid import UUID\n", + "from uuid import uuid4\n", + "from 
torch.utils.data import SequentialSampler, RandomSampler, BatchSampler\n", + "\n", + "from abc import ABC, abstractmethod\n", + "from torchvision import datasets, transforms\n", + "\n", + "import utils\n", + "import dataloaders\n", + "\n", + "hook = sy.TorchHook(torch)\n", + "\n", + "transform = transforms.Compose([transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,)),\n", + " ])\n", + "trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)\n", + "#trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n", + "\n", + "# create some workers\n", + "client_1 = sy.VirtualWorker(hook, id=\"client_1\")\n", + "client_2 = sy.VirtualWorker(hook, id=\"client_2\")\n", + "\n", + "server = sy.VirtualWorker(hook, id= \"server\") \n", + "\n", + "data_owners = [client_1, client_2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#get a verticalFederatedDataser\n", + "vfd = split_data_create_vertical_dataset(trainset, data_owners)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/dh_examples/dataloaders.py b/examples/dh_examples/dataloaders.py new file mode 100644 index 0000000..908d818 --- /dev/null +++ b/examples/dh_examples/dataloaders.py @@ -0,0 +1,57 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler + +import datasets + + +class SinglePartitionDataLoader(DataLoader): + """DataLoader for a single vertically-partitioned dataset""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + #self.collate_fn = id_collate_fn + +class VerticalFederatedDataLoader: + """Dataloader which batches data from a complete + set of vertically-partitioned datasets + """ + + def __init__(self, vf_dataset, batch_size=8, shuffle=False, drop_last=False, *args, **kwargs): + + self.vf_dataset = vf_dataset + self.batch_size = batch_size + self.shuffle = shuffle + + self.workers = vf_dataset.workers + + self.batch_samplers = {} + for worker in self.workers: + data_range = range(len(list(self.vf_dataset.datasets.values()))) + if shuffle: + sampler = RandomSampler(data_range) + else: + sampler = SequentialSampler(data_range) + batch_sampler = BatchSampler(sampler, self.batch_size, drop_last) + self.batch_samplers[worker] = batch_sampler + + single_loaders = [] + for k in vfd.datasets.keys(): + single_loaders.append(SinglePartitionDataLoader(vfd.datasets[k], batch_sampler=self.batch_samplers[k])) + + self.single_loaders = single_loaders + + + def __iter__(self): + return zip(*self.single_loaders) + + def __len__(self): + return sum(len(x) for x in self.vf_dataset.datasets.values()) // len(self.workers) \ No newline at end of file diff --git a/examples/dh_examples/datasets.py b/examples/dh_examples/datasets.py new file mode 100644 index 0000000..3474bc4 --- /dev/null +++ 
b/examples/dh_examples/datasets.py @@ -0,0 +1,136 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler + + +class BaseSet(Dataset): + def __init__(self, ids, values, worker_id=None, is_labels=False): + self.values_dic = {} + for i, l in zip(ids, values): + self.values_dic[i] = l + self.is_labels = is_labels + + self.ids = torch.Tensor(ids) + self.values = torch.Tensor(values) if is_labels else torch.stack(values) + + self.worker_id = None + if worker_id != None: + self.send_to_worker(worker_id) + + def send_to_worker(self, worker): + self.worker_id = worker + self.values_pointer = self.values.send(worker) + self.index_pointer = self.ids.send(worker) + return self.values_pointer, self.index_pointer + + def __getitem__(self, index): + """ + Args: + idx: index of the example we want to get + Returns: a tuple with data, label, index of a single example. + """ + return tuple([self.values[index], self.ids[index]]) + + def __len__(self): + """ + Returns: amount of samples in the dataset + """ + return self.values.shape[0] + + +class SampleSetWithLabels(Dataset): + def __init__(self, labelset, sampleset, worker_id=None): + #TO-DO: drop non-intersecting, now just assuming they are overlapping + #TO-DO: make sure values are sorted + self.labelset = labelset + self.sampleset = sampleset + + self.labels = labelset.values + self.values = sampleset.values + self.ids = sampleset.ids + + self.values_dic = {} + for k in labelset.values_dic.keys(): + self.values_dic[k] = tuple([sampleset.values_dic[k], torch.Tensor(labelset.values_dic[k])]) + + self.worker_id = None + if worker_id != None: + self.send_to_worker(worker_id) + + def send_to_worker(self, worker): + self.worker_id = worker + self.label_point, self.label_ix_pointer = self.labelset.send_to_worker(worker) + self.value_point, self.values_ix_pointer = self.sampleset.send_to_worker(worker) + return self.label_point, self.label_ix_pointer, self.value_point, self.values_ix_pointer + + + def __getitem__(self, index): + """ + Args: + idx: index of the example we want to get + Returns: a tuple with data, label, index of a single example. + """ + return tuple([self.values[index], self.labels[index], self.ids[index]]) + + def __len__(self): + """ + Returns: amount of samples in the dataset + """ + return self.values.shape[0] + + + +class VerticalFederatedDataset(): + """ + VerticalFederatedDataset, which acts as a dictionary between BaseVerticalDatasets, + already sent to remote workers, and the corresponding workers. + This serves as an input to VerticalFederatedDataLoader. + Same principle as in Syft 2.0 for FederatedDataset: + https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataset.py + + Args: + datasets: list of BaseVerticalDatasets. 
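+
+    Example (a sketch only; ds_w1 and ds_w2 are hypothetical partition
+    datasets that already hold a worker_id, i.e. were already sent out):
+
+        vfd = VerticalFederatedDataset([ds_w1, ds_w2])
+        vfd.workers          # ids of the two owning workers
+        len(vfd)             # datapoints summed across both partitions
+        vfd[vfd.workers[0]]  # the partition held by the first worker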
+ """ + def __init__(self, datasets): + + self.datasets = {} #dictionary to keep track of BaseVerticalDatasets and corresponding workers + + for dataset in datasets: + worker_id = dataset.worker_id + self.datasets[worker_id] = dataset + + self.workers = self.__workers() + + def __workers(self): + """ + Returns: list of workers + """ + return list(self.datasets.keys()) + + def __getitem__(self, worker_id): + """ + Args: + worker_id[str,int]: ID of respective worker + Returns: + Get Datasets from the respective worker + """ + + return self.datasets[worker_id] + + def __len__(self): + + return sum(len(dataset) for dataset in self.datasets.values()) + + def __repr__(self): + + fmt_str = "FederatedDataset\n" + fmt_str += f" Distributed accross: {', '.join(str(x) for x in self.workers)}\n" + fmt_str += f" Number of datapoints: {self.__len__()}\n" + return fmt_str \ No newline at end of file diff --git a/examples/dh_examples/utils.py b/examples/dh_examples/utils.py new file mode 100644 index 0000000..fab9af5 --- /dev/null +++ b/examples/dh_examples/utils.py @@ -0,0 +1,119 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler + +import dataloaders +import datasets + +""" +Utility functions to split and distribute the data across different workers, +create vertical datasets and federate them. It also contains datasets and dataloader classes. +This code is meant to be used with dual-headed Neural Networks, where there are a bunch of different workers, +which agrees on the labels, and there is a server with the labels only. +Code built upon: +- Abbas Ismail's (@abbas5253) work on dual-headed NN. In particular, check Configuration 1: + https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb +- Syft 2.0 Federated Learning dataset and dataloader: https://github.com/OpenMined/PySyft/tree/syft_0.2.x/syft/frameworks/torch/fl +TODO: + - replace ids with UUIDs + - there is a bug in creation of BaseDataset X + - create class for splitting the data + - create LabelSet and SampleSet (to accomodate later different roles of workers) + - improve DataLoader to accomodate different sampler (e.g. random sampler when shuffle) and different batch size X + - split function should be able to take as an input a dataloader, and not only a dataset (i.e. single sample iteration) + - check that / modify such that it works on data different than images + - dictionary keys should be worker ids, not workers themselves +""" + + + + +def split_data(dataset, worker_list=None, n_workers=2): + """ + Utility function to create a vertical split of the data. It also creates a numerical index to keep + track of the single data across different split. + Args: + dataset: an iterable object represent the dataset. Each element of the iterable + is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch. + #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples. + worker_list (optional): The list of VirtualWorkers to distribute the data vertically across. + n_workers(optional, default=2): The number of workers to split the data across. 
If worker_list is not passed, this is necessary to create the split. + label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting) + #TODO: add the code to send labels to the server + Returns: + a dictionary holding as keys the workers passed as parameters, or integers corresponding to the split, + and as values a list of lists, where the first element are the single tensor of the data, the second the labels, + the third the index, which is to keep track of the same data point. + """ + + if worker_list == None: + worker_list = list(range(0, n_workers)) + + #counter to create the index of different data samples + idx = 0 + + #dictionary to accomodate the split data + dic_single_datasets = {} + for worker in worker_list: + """ + Each value is a list of three elements, to accomodate, in order: + - data examples (as tensors) + - label + - index + """ + dic_single_datasets[worker] = [[],[],[]] + + """ + Loop through the dataset to split the data and labels vertically across workers. + Splitting method from @abbas5253: https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/distribute_data.py + """ + for tensor, label in dataset: + height = tensor.shape[-1]//len(worker_list) + i = 0 + uuid_idx = uuid4() + for worker in worker_list[:-1]: + dic_single_datasets[worker][0].append(tensor[:, :, height * i : height * (i + 1)]) + dic_single_datasets[worker][1].append(label) + dic_single_datasets[worker][2].append(idx) + i += 1 + + #add the value of the last worker / split + dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i) : ]) + dic_single_datasets[worker_list[-1]][1].append(label) + dic_single_datasets[worker_list[-1]][2].append(idx) + + idx += 1 + + return dic_single_datasets + + +def split_data_create_vertical_dataset(dataset, worker_list, label_server=None): + """ + Utility function to distribute the data vertically across workers and create a vertical federated dataset. + Args: + dataset: an iterable object represent the dataset. Each element of the iterable + is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch. + #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples. + worker_list: The list of VirtualWorkers to distribute the data vertically across. + label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting) + Returns: + a VerticalFederatedDataset. + """ + + #get a dictionary of workers --> list of triples (data, label, idx) representing the dataset. 
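+    #(more precisely: each value is three parallel lists, [tensors], [labels],
+    # [indices]; e.g. two workers on 28x28 MNIST each receive 28x14 half-images)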
+ dic_single_datasets = split_data(dataset, worker_list=worker_list) + + #create base vertical datasets list, to be passed to a vertical federated dataset + base_datasets_list = [] + for worker in worker_list: + base_datasets_list.append(BaseVerticalDataset(dic_single_datasets[worker], worker_id=worker)) + + #create VerticalFederatedDataset + return VerticalFederatedDataset(base_datasets_list) \ No newline at end of file From 593a1ff33554d3484b745f0ac17db3ad83bd91ca Mon Sep 17 00:00:00 2001 From: daler3 Date: Tue, 22 Dec 2020 12:37:32 +0100 Subject: [PATCH 15/21] dataloaders and datasets for dualheaded --- examples/dualheaded/MNIST-DualHeaded.ipynb | 430 ----- .../dualheaded/verticalfederateddataset.py | 259 --- .../Example_.ipynb | 0 .../dataloaders.py | 0 .../datasets.py | 0 .../utils.py | 0 .../Diabetes_prediction_preprocessing.ipynb | 1647 ----------------- 7 files changed, 2336 deletions(-) delete mode 100644 examples/dualheaded/MNIST-DualHeaded.ipynb delete mode 100644 examples/dualheaded/verticalfederateddataset.py rename examples/{dh_examples => dualheaded_datautils}/Example_.ipynb (100%) rename examples/{dh_examples => dualheaded_datautils}/dataloaders.py (100%) rename examples/{dh_examples => dualheaded_datautils}/datasets.py (100%) rename examples/{dh_examples => dualheaded_datautils}/utils.py (100%) delete mode 100644 examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb diff --git a/examples/dualheaded/MNIST-DualHeaded.ipynb b/examples/dualheaded/MNIST-DualHeaded.ipynb deleted file mode 100644 index e62499e..0000000 --- a/examples/dualheaded/MNIST-DualHeaded.ipynb +++ /dev/null @@ -1,430 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# prerequisites\n", - "from __future__ import print_function\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "from torchvision import datasets, transforms\n", - "from torch.autograd import Variable\n", - "from torchvision.utils import save_image\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import argparse\n", - "from torch.optim.lr_scheduler import StepLR\n", - "\n", - "transform=transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))\n", - " ])\n", - "\n", - "bs = 64\n", - "# MNIST Dataset\n", - "train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n", - "test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)\n", - "\n", - "# Data Loader (Input Pipeline)\n", - "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)\n", - "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([64, 1, 28, 28])\n", - "torch.Size([64])\n" - ] - } - ], - "source": [ - "dataiter = iter(train_loader)\n", - "images, labels = dataiter.next()\n", - "\n", - "print(images.shape)\n", - "print(labels.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOLklEQVR4nO3df6xU9ZnH8c+zWPjDEsOPi16BSLfBuGaTpTghm6gNKtv4g4gktoE/GjZRb6NoijYR7aolhESz2jZrskFBSNkVrU2ogkJ2q0gwNaZxuEHBJXsva9hyy4U7aEzhH7vSZ/+4h+YCd75zmXNmzsDzfiWTmTnPfOc8TPjcMzPfmfmauwvAxe+vym4AQHsQdiAIwg4EQdiBIAg7EMQl7dzZ1KlTfdasWe3cJRDKoUOHdPz4cRutlivsZnarpH+RNE7SS+7+TOr2s2bNUrVazbNLAAmVSqVuremn8WY2TtK/SrpN0rWSlprZtc3eH4DWyvOafZ6kg+7+qbv/SdIvJS0qpi0ARcsT9umSDo+4PpBtO4OZ9ZhZ1cyqtVotx+4A5JEn7KO9CXDOZ2/dfZ27V9y90tXVlWN3APLIE/YBSTNHXJ8h6Ui+dgC0Sp6wfyhptpl9w8zGS1oiaVsxbQEoWtNTb+7+lZk9KOk/NTz1ttHdPymsMwCFyjXP7u47JO0oqBcALcTHZYEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0Ioq1LNqPzbN26NVnfvXt3sn7w4MFk/a233qpbcz9nAaEzTJ9+zmpiZ3jiiSeS9Xvvvbdu7ZJL4v3X58gOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0HEm2wM5siRI8n6ypUrk/X+/v5c+zezurUrr7wyOXZwcDBZX758ebK+Y0f9BYZ7enqSYxcuXJisX4hyhd3MDkk6IemUpK/cvVJEUwCKV8SR/SZ3P17A/QBoIV6zA0HkDbtL+o2Z7TGzUV8EmVmPmVXNrFqr1XLuDkCz8ob9enefK+k2ScvN7Ntn38Dd17l7xd0rXV1dOXcHoFm5wu7uR7LzIUmvS5pXRFMAitd02M3sUjObePqypO9I2l9UYwCKlefd+MslvZ7No14i6RV3/49CukJh1qxZk6z39fUl641eet10003Jeuo75bNnz06O7e3tTdZfe+21ZH379u11azNmzEiOZZ59BHf/VNLfFdgLgBZi6g0IgrADQRB2IAjCDgRB2IEg+IrrRWD16tV1axs2bEiOnThxYrK+ZcuWZP2GG25I1vO46qqrkvXFixc3Xe/u7m6qpwsZR3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIJ59gvABx98kKyvWrWq6ft++umnk/VWzqM38uWXXybrQ0NDyfrzzz/f9NiLEUd2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCefaLQGpZ5EbK/MnkRvPoDz/8cLL+4osvNr3vU6dONT32QsWRHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCYJ49uKuvvrq0fb///vvJep55dEl66qmnco2/2DQ8spvZRjMbMrP9I7ZNNrO3zaw/O5/U2jYB5DWWp/G/kHTrWdsek7TT3WdL2pldB9DBGobd3d+T9PlZmxdJ2pRd3iTproL7AlCwZt+gu9zdByUpO59W74Zm1mNmVTOr1mq1JncHIK+Wvxvv7uvcveLula6urlbvDkAdzYb9mJl1S1J2Hu+nOoELTLNh3yZpWXZ5maStxbQDoFUazrOb2auS5kuaamYDkn4i6RlJvzKzeyT9XtJ3W9lkdHPnzm263tvbmxy7fv36ZP2+++5L1hu5//7769ZeeeWVXPe9dOnSZH3FihW57v9i0zDs7l7vEb2l4F4AtBAflwWCIOxAEIQdCIKwA0EQdiAIvuJ6AZgwYUKyvnnz5rq1BQsWJMf29PQk62+++WayXq1Wk/WjR4/WrY0bNy459s4770zWX3755WQdZ+LIDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBMM9+EUj9HPQ777yTHHvNNdck69u3b2+qp9NSy0k3+qnnJ598Mte+cSaO7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBPPsF7m+vr5S95/6DADz6O3FkR0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgmCe/SK3fPnyUvc/ODhYt7Znz57k2Ouuu67odkJreGQ3s41mNmRm+0dsW2VmfzCzvdnp9ta2CSCvsTyN/4WkW0fZ/nN3n5OddhTbFoCiNQy7u78n6fM29AKghfK8QfegmX2cPc2fVO9GZtZjZlUzq9ZqtRy7A5BHs2FfK+mbkuZIGpT003o3dPd17l5x90pXV1eTuwOQV1Nhd/dj7n7K3f8sab2kecW2BaBoTYXdzLpHXF0saX+92wLoDA3n2c3sVUnzJU01swFJP5E038zmSHJJhyT9oIU9hvfZZ58l648//njd2uHDh5NjL7vssmT9kUceSdbdPVlfs2ZN3dru3buTY5lnL1bDsLv70lE2b2hBLwBaiI/LAkEQdiAIwg4EQdiBIAg7EARfce0A+/btS9YXLlyYrA8MDNStpZZMlqS1a9cm60uWLEnWG9mwof7EzRtvvJEcu2zZsmR9ypQpTfUUFUd2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCefY2WL16dbL+wgsvJOvHjh1L1mfOnFm39sADDyTH5p1HbyR1/88++2xy7K5du5L1u+++u6meouLIDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBMM/eBqnvdEvS0aNHk/VG30l/6aWX6tYWLFiQHNtqV1xxRdNjU/8uiXn288WRHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCYJ69AHv37k3Wv/jii2S90bLHjZYuvvHGG+vWTpw4kRzbqLft27cn642+q5/6DEGjfzeK1fDIbmYzzWyXmR0ws0/M7IfZ9slm9raZ9Wfnk1rfLoBmjeVp/FeSfuTufyPp7yUtN7NrJT0maae7z5a0M7sOoEM1DLu7D7p7b3b5hKQDkqZLWiRpU3azTZLualWTAPI7rzfozGyWpG9J+p2ky919UBr+gyBpWp0xPWZWNbNqrVbL1y2Apo057Gb2dUlbJK1w9z+OdZy7r3P3irtXurq6mukRQAHGFHYz+5qGg77Z3X+dbT5mZt1ZvVvSUGtaBFCEhlNvNvz9yg2SDrj7z0aUtklaJumZ7HxrSzq8APT39yfrJ0+eTNYbfYW1t7c3WU8t6Xz8+PHk2I8++ihZb9RbI6nxEyZMSI5duXJlrn3jTGOZZ79e0vcl7TOz0xPKP9ZwyH9lZvdI+r2k77amRQBFaBh2d/+tpHp/nm8pth0ArcLHZYEgCDsQBGEHgiDsQBCEHQiCr7gW4Oabb07Wp00b9ZPEf9FoSeZG3n333Vzj87jllvSEzKOPPlq3Nnny5OTYuXPnNtUTRseRHQiCsANBEHYgCMIOBE
HYgSAIOxAEYQeCYJ69AFOmTEnWG30f/bnnnkvWd+zYkaz39fUl6ynz589P1u+4445k/aGHHkrWx48ff74toUU4sgNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAENbOZXMrlYpXq9W27Q+IplKpqFqtjvpr0BzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiCIhmE3s5lmtsvMDpjZJ2b2w2z7KjP7g5ntzU63t75dAM0ay49XfCXpR+7ea2YTJe0xs7ez2s/dPf3LCwA6wljWZx+UNJhdPmFmByRNb3VjAIp1Xq/ZzWyWpG9J+l226UEz+9jMNprZpDpjesysambVWq2Wq1kAzRtz2M3s65K2SFrh7n+UtFbSNyXN0fCR/6ejjXP3de5ecfdKV1dXAS0DaMaYwm5mX9Nw0De7+68lyd2Pufspd/+zpPWS5rWuTQB5jeXdeJO0QdIBd//ZiO3dI262WNL+4tsDUJSxvBt/vaTvS9pnZnuzbT+WtNTM5khySYck/aAlHQIoxFjejf+tpNG+H5v+MXMAHYVP0AFBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Jo65LNZlaT9L8jNk2VdLxtDZyfTu2tU/uS6K1ZRfZ2lbuP+vtvbQ37OTs3q7p7pbQGEjq1t07tS6K3ZrWrN57GA0EQdiCIssO+ruT9p3Rqb53al0RvzWpLb6W+ZgfQPmUf2QG0CWEHgigl7GZ2q5n9t5kdNLPHyuihHjM7ZGb7smWoqyX3stHMhsxs/4htk83sbTPrz85HXWOvpN46YhnvxDLjpT52ZS9/3vbX7GY2TlKfpH+QNCDpQ0lL3f2/2tpIHWZ2SFLF3Uv/AIaZfVvSSUn/5u5/m237Z0mfu/sz2R/KSe6+skN6WyXpZNnLeGerFXWPXGZc0l2S/lElPnaJvr6nNjxuZRzZ50k66O6fuvufJP1S0qIS+uh47v6epM/P2rxI0qbs8iYN/2dpuzq9dQR3H3T33uzyCUmnlxkv9bFL9NUWZYR9uqTDI64PqLPWe3dJvzGzPWbWU3Yzo7jc3Qel4f88kqaV3M/ZGi7j3U5nLTPeMY9dM8uf51VG2EdbSqqT5v+ud/e5km6TtDx7uoqxGdMy3u0yyjLjHaHZ5c/zKiPsA5Jmjrg+Q9KREvoYlbsfyc6HJL2uzluK+tjpFXSz86GS+/mLTlrGe7RlxtUBj12Zy5+XEfYPJc02s2+Y2XhJSyRtK6GPc5jZpdkbJzKzSyV9R523FPU2Scuyy8skbS2xlzN0yjLe9ZYZV8mPXenLn7t720+SbtfwO/L/I+mfyuihTl9/Lemj7PRJ2b1JelXDT+v+T8PPiO6RNEXSTkn92fnkDurt3yXtk/SxhoPVXVJvN2j4peHHkvZmp9vLfuwSfbXlcePjskAQfIIOCIKwA0EQdiAIwg4EQdiBIAg7EARhB4L4f8J+RLlbURtyAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.imshow(images[0].numpy().squeeze(), cmap='gray_r');" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super(Net, self).__init__()\n", - " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", - " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", - " self.dropout1 = nn.Dropout(0.25)\n", - " self.dropout2 = nn.Dropout(0.5)\n", - " self.fc1 = nn.Linear(9216, 128)\n", - " self.fc2 = nn.Linear(128, 10)\n", - "\n", - " def forward(self, x):\n", - " x = self.conv1(x)\n", - " x = F.relu(x)\n", - " x = self.conv2(x)\n", - " x = F.relu(x)\n", - " x = F.max_pool2d(x, 2)\n", - " x = self.dropout1(x)\n", - " x = torch.flatten(x, 1)\n", - " x = self.fc1(x)\n", - " x = F.relu(x)\n", - " x = self.dropout2(x)\n", - " x = self.fc2(x)\n", - " output = F.log_softmax(x, dim=1)\n", - " return output\n", - "\n", - "\n", - "def train(args, model, device, train_loader, optimizer, epoch):\n", - " model.train()\n", - " for batch_idx, (data, target) in enumerate(train_loader):\n", - " data, target = data.to(device), target.to(device)\n", - " optimizer.zero_grad()\n", - " output = model(data)\n", - " loss = F.nll_loss(output, target)\n", - " loss.backward()\n", - " optimizer.step()\n", - " if batch_idx % args.log_interval == 0:\n", - " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", - " epoch, batch_idx * len(data), len(train_loader.dataset),\n", - " 100. * batch_idx / len(train_loader), loss.item()))\n", - " if args.dry_run:\n", - " break\n", - "\n", - "\n", - "def test(model, device, test_loader):\n", - " model.eval()\n", - " test_loss = 0\n", - " correct = 0\n", - " with torch.no_grad():\n", - " for data, target in test_loader:\n", - " data, target = data.to(device), target.to(device)\n", - " output = model(data)\n", - " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", - " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", - " correct += pred.eq(target.view_as(pred)).sum().item()\n", - "\n", - " test_loss /= len(test_loader.dataset)\n", - "\n", - " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", - " test_loss, correct, len(test_loader.dataset),\n", - " 100. 
* correct / len(test_loader.dataset)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "lr = 0.001\n", - "epochs = 10\n", - "\n", - "model = Net().to(device)\n", - "optimizer = optim.Adadelta(model.parameters(), lr=lr)\n", - "\n", - "scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n", - "for epoch in range(1, epochs + 1):\n", - " train(args, model, device, train_loader, optimizer, epoch)\n", - " test(model, device, test_loader)\n", - " scheduler.step()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**DUAL HEADED**\n", - "\n", - "\n", - "* Client_1\n", - " * Has Model Segment 1\n", - " * Has the handwritten images segment 1 (vertical distr: left part)\n", - "* Client_2\n", - " * Has model Segment 1\n", - " * Has the handwritten images segment 2 (vertical distr: right part)\n", - "* Server\n", - " * Has Model Segment 2\n", - " * Has the image labels\n" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import pytest\n", - "import syft as sy\n", - "from syft.core.node.common.service.auth import AuthorizationException\n", - "\n", - "import torch as th\n", - "from torchvision import datasets, transforms\n", - "from torch import nn, optim\n", - "import syft as sy\n", - "import numpy as np\n", - "from syft.util import key_emoji" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [], - "source": [ - "# create some workers\n", - "dev_1 = sy.Device(name=\"client_1\")\n", - "dev_2 = sy.Device(name=\"client_2\")\n", - "client_1 = dev_1.get_client()\n", - "client_2 = dev_2.get_client()\n", - "\n", - "server_dev = sy.Device(name=\"server\")\n", - "server = server_dev.get_client()\n", - "\n", - "data_owners = (client_1, client_2)\n", - "model_locations = [client_1, client_2, server]" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">\n", - ">\n", - ">\n" - ] - } - ], - "source": [ - "for location in model_locations:\n", - " print(location)\n", - " x = location" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'client_2 Client'" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model_locations[1].name.split()" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [], - "source": [ - "input_size= [28*14, 28*14]\n", - "hidden_sizes= {\"client_1\": [32, 64], \"client_2\":[32, 64], \"server\":[128, 64]}\n", - "\n", - "#create model segment for each worker\n", - "models = {\n", - " \"client_1\": nn.Sequential(\n", - " nn.Linear(input_size[0], hidden_sizes[\"client_1\"][0]),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_sizes[\"client_1\"][0], hidden_sizes[\"client_1\"][1]),\n", - " nn.ReLU(),\n", - " ),\n", - " \"client_2\": nn.Sequential(\n", - " nn.Linear(input_size[1], hidden_sizes[\"client_2\"][0]),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_sizes[\"client_2\"][0], hidden_sizes[\"client_2\"][1]),\n", - " nn.ReLU(),\n", - " ),\n", - " \"server\": nn.Sequential(\n", - " nn.Linear(hidden_sizes[\"server\"][0], hidden_sizes[\"server\"][1]),\n", - " nn.ReLU(),\n", - " 
nn.Linear(hidden_sizes[\"server\"][1], 10),\n", - " nn.LogSoftmax(dim=1)\n", - " )\n", - "}\n", - "\n", - "# Create optimisers for each segment and link to their segment\n", - "optimizers = [\n", - " optim.SGD(models[location.name.split(\" \")[0]].parameters(), lr=0.05,)\n", - " for location in model_locations\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "ename": "Exception", - "evalue": "Object has no serializable_wrapper_type", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#send model segement to each client and server\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlocation\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel_locations\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmodels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlocation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\" \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/ast/klass.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, client, searchable)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0;31m# Step 3: send message\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 189\u001b[0;31m \u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_immediate_msg_without_reply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;31m# Step 4: return pointer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/decorators/syft_decorator_impl.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
function(*args, **kwargs) [... further ANSI-escaped stack frames omitted: the call repeats through syft's decorator and typeguard wrappers, then descends client.send_immediate_msg_without_reply -> message.sign -> serializable.serialize -> save_object_action._object2proto -> klass.serialize; the last omitted frame ends mid-statement at typeguard's retval = func(*args,
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 892\u001b[0m \u001b[0mcheck_return_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 893\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/syft/core/common/serde/serialize.py\u001b[0m in \u001b[0;36m_serialize\u001b[0;34m(obj, to_proto, to_bytes)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mis_serializable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserializable_wrapper_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Object {type(obj)} has no serializable_wrapper_type\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0mis_serializable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mException\u001b[0m: Object has no serializable_wrapper_type" - ] - } - ], - "source": [ - "\n", - "\n", - "\n", - "#send model segement to each client and server\n", - "for location in model_locations:\n", - " models[location.name.split(\" \")[0]].send(location)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "bob_vm = client_1.get_client()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "xp = th.tensor([1,2,3]).tag(\"some\", \"diabetes\", \"data\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "xp.send(vm_sever_cli)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/dualheaded/verticalfederateddataset.py b/examples/dualheaded/verticalfederateddataset.py deleted file mode 100644 index f64f49e..0000000 --- a/examples/dualheaded/verticalfederateddataset.py +++ /dev/null @@ -1,259 +0,0 @@ -from __future__ import print_function -import syft as sy -import torch -from torch.utils.data import Dataset -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate - -""" -Utility functions to split and distribute the 
data across different workers, -create vertical datasets and federate them. It also contains datasets and dataloader classes. -This code is meant to be used with dual-headed Neural Networks, where there are a bunch of different workers, -which agrees on the labels, and there is a server with the labels only. -Code built upon: -- Abbas Ismail's (@abbas5253) work on dual-headed NN. In particular, check Configuration 1: - https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb -- Syft 2.0 Federated Learning dataset and dataloader: https://github.com/OpenMined/PySyft/tree/syft_0.2.x/syft/frameworks/torch/fl - -TODO: - - replace ids with UUIDs - - there is a bug in creation of BaseDataset - - create class for splitting the data - - create LabelSet and SampleSet (to accomodate later different roles of workers) - - improve DataLoader to accomodate different sampler (e.g. random sampler when shuffle) and different batch size - - split function should be able to take as an input a dataloader, and not only a dataset (i.e. single sample iteration) - - check that / modify such that it works on data different than images -""" - - -def split_data(dataset, worker_list=None, n_workers=2, label_server=None): - """ - Utility function to create a vertical split of the data. It also creates a numerical index to keep - track of the single data across different split. - Args: - dataset: an iterable object represent the dataset. Each element of the iterable - is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch. - #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples. - worker_list (optional): The list of VirtualWorkers to distribute the data vertically across. - n_workers(optional, default=2): The number of workers to split the data across. If worker_list is not passed, this is necessary to create the split. - label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting) - #TODO: add the code to send labels to the server - Returns: - a dictionary holding as keys the workers passed as parameters, or integers corresponding to the split, - and as values a list of lists, where the first element are the single tensor of the data, the second the labels, - the third the index, which is to keep track of the same data point. - """ - - if worker_list == None: - worker_list = list(range(0, n_workers)) - - #counter to create the index of different data samples - idx = 0 - - #dictionary to accomodate the split data - dic_single_datasets = {} - for worker in worker_list: - """ - Each value is a list of three elements, to accomodate, in order: - - data examples (as tensors) - - label - - index - """ - dic_single_datasets[worker] = [[],[],[]] - - """ - Loop through the dataset to split the data and labels vertically across workers. 
- Splitting method from @abbas5253: https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/distribute_data.py - """ - for tensor, label in dataset: - height = tensor.shape[-1]//len(worker_list) - i = 0 - for worker in worker_list[:-1]: - dic_single_datasets[worker][0].append(tensor[:, :, height * i : height * (i + 1)]) - dic_single_datasets[worker][1].append(label) - dic_single_datasets[worker][2].append(idx) - i += 1 - - #add the value of the last worker / split - dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i) : ]) - dic_single_datasets[worker_list[-1]][1].append(label) - dic_single_datasets[worker_list[-1]][2].append(idx) - - idx += 1 - - return dic_single_datasets - - -def split_data_create_vertical_dataset(dataset, worker_list, label_server=None): - """ - Utility function to distribute the data vertically across workers and create a vertical federated dataset. - Args: - dataset: an iterable object represent the dataset. Each element of the iterable - is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch. - #TODO: add support for taking a Dataloader as input, to iterate over batches instead of single examples. - worker_list: The list of VirtualWorkers to distribute the data vertically across. - label_server (optional): the server which owns only the labels (e.g. in a dual-headed NN setting) - Returns: - a VerticalFederatedDataset. - """ - - #get a dictionary of workers --> list of triples (data, label, idx) representing the dataset. - dic_single_datasets = split_data(dataset, worker_list=worker_list, label_server=label_server) - - #create base vertical datasets list, to be passed to a vertical federated dataset - base_datasets_list = [] - for worker in worker_list: - base_datasets_list.append(BaseVerticalDataset(dic_single_datasets[worker], worker_id=worker)) - - #create VerticalFederatedDataset - return VerticalFederatedDataset(base_datasets_list) - -class BaseVerticalDataset(Dataset): - """ - Base Vertical Dataset class, containing a portion of a vertically splitted dataset. - Args: - datatuples: a list where each element is another list (or tuple) of exactly 3 elements. - The first one is a sample data, the second one is the corresponding label, and the third one the index - (necessary to keep track of the same vertically splitted examples across multiple workers) - worker_id (optional): the worker to which we want to send the dataset - """ - def __init__(self, datatuples, worker_id=None): - - self.__fill_tensors(datatuples) - - self.worker_id = None - if worker_id != None: - self.send_to_worker(worker_id) - self.worker_id = worker_id - - self.__dataset_tolist() - - - def __len__(self): - """ - Returns: amount of samples in the dataset - """ - return self.data_tensor.shape[0] - - def __getitem__(self, index): - """ - Args: - idx: index of the example we want to get - Returns: a tuple with data, label, index of a single example. - """ - return tuple([self.data_tensor[index], self.label_tensor[index], self.index_tensor[index]]) - - def __fill_tensors(self, data_tuples): - """ - Private method to fill the tensors of the tuples, labels and index. - """ - self.data_tensor = torch.stack(data_tuples[0]) - self.label_tensor = torch.Tensor(data_tuples[1]) - self.index_tensor = torch.Tensor(data_tuples[2]) - - def __dataset_tolist(self): - """ - Private method to create a compact list version of the dataset, so that len(dataset) is the number of examples. 
- """ - list_dataset = [] - for i in range(0, self.__len__()): - list_dataset.append(self.__getitem__(i)) - self.dataset = list_dataset - - - def send_to_worker(self, worker): - """ - Send the dataset to a worker. - Args: - the worker to which we want to send the dataset - Returns: - pointers to the remote data, the labels and the index tensors - """ - self.worker_id = worker - self.data_pointer = self.data_tensor.send(worker) - self.label_pointer = self.label_tensor.send(worker) - self.index_pointer = self.index_tensor.send(worker) - return self.data_pointer, self.label_pointer, self.index_pointer - - - - -class VerticalFederatedDataset(): - """ - VerticalFederatedDataset, which acts as a dictionary between BaseVerticalDatasets, - already sent to remote workers, and the corresponding workers. - This serves as an input to VerticalFederatedDataLoader. - Same principle as in Syft 2.0 for FederatedDataset: - https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataset.py - Args: - datasets: list of BaseVerticalDatasets. - """ - def __init__(self, datasets): - - self.datasets = {} #dictionary to keep track of BaseVerticalDatasets and corresponding workers - - for dataset in datasets: - worker_id = dataset.worker_id - self.datasets[worker_id] = dataset - - self.workers = self.__workers() - - - def __workers(self): - """ - Returns: list of workers - """ - return list(self.datasets.keys()) - - def __getitem__(self, worker_id): - """ - Args: - worker_id[str,int]: ID of respective worker - Returns: - Get Datasets from the respective worker - """ - - return self.datasets[worker_id] - - def __len__(self): - - return sum(len(dataset) for dataset in self.datasets.values()) - - def __repr__(self): - - fmt_str = "FederatedDataset\n" - fmt_str += f" Distributed accross: {', '.join(str(x) for x in self.workers)}\n" - fmt_str += f" Number of datapoints: {self.__len__()}\n" - return fmt_str - - -class SinglePartitionDataLoader(DataLoader): - """DataLoader for a single vertically-partitioned dataset""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.collate_fn = default_collate - -class VerticalFederatedDataLoader: - """Dataloader which batches data from a complete - set of vertically-partitioned datasets - """ - - def __init__(self, vf_dataset, batch_size=8, shuffle=False, *args, **kwargs): - - self.vf_dataset = vf_dataset - - single_loaders_list = [] - for d in vfd.datasets.values(): - single_loaders_list.append(SinglePartitionDataLoader(d)) - - self.workers = list(vf_dataset.keys()) - - - def __iter__(self): - return zip(*self.vf_dataset) - - def __len__(self): - return sum(len(x) for x in self.datasets.values()) // len(self.workers) - diff --git a/examples/dh_examples/Example_.ipynb b/examples/dualheaded_datautils/Example_.ipynb similarity index 100% rename from examples/dh_examples/Example_.ipynb rename to examples/dualheaded_datautils/Example_.ipynb diff --git a/examples/dh_examples/dataloaders.py b/examples/dualheaded_datautils/dataloaders.py similarity index 100% rename from examples/dh_examples/dataloaders.py rename to examples/dualheaded_datautils/dataloaders.py diff --git a/examples/dh_examples/datasets.py b/examples/dualheaded_datautils/datasets.py similarity index 100% rename from examples/dh_examples/datasets.py rename to examples/dualheaded_datautils/datasets.py diff --git a/examples/dh_examples/utils.py b/examples/dualheaded_datautils/utils.py similarity index 100% rename from examples/dh_examples/utils.py rename to 
examples/dualheaded_datautils/utils.py diff --git a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb b/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb deleted file mode 100644 index 55e3697..0000000 --- a/examples/synthea_examples/Diabetes_prediction_preprocessing.ipynb +++ /dev/null @@ -1,1647 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Diabetes prediction with synthea data\n", - "\n", - "###### Mostly from https://github.com/IBM/example-health-machine-learning/blob/master/diabetes-prediction.ipynb" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import data in pandas dataframes " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd \n", - "import numpy as np\n", - "\n", - "#load data into pandas dataframes\n", - "data_dir = \"../../data/synthea/\"\n", - "conditions_file = data_dir+\"conditions.csv\"\n", - "medications_file = data_dir+\"medications.csv\"\n", - "observatios_file = data_dir+\"observations.csv\"\n", - "patients_file = data_dir+\"patients.csv\"\n", - "\n", - "df_cond = pd.read_csv(conditions_file)\n", - "df_med = pd.read_csv(medications_file)\n", - "df_obs = pd.read_csv(observatios_file)\n", - "df_pat = pd.read_csv(patients_file)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
DATEPATIENTENCOUNTERCODEDESCRIPTIONVALUEUNITSTYPE
02011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c91498302-2Body Height167.0cmnumeric
12011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914972514-3Pain severity - 0-10 verbal numeric rating [Sc...3.0{score}numeric
22011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914929463-7Body Weight71.1kgnumeric
32011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914939156-5Body Mass Index25.5kg/m2numeric
42011-03-31T02:00:17Z7d3e489a-7789-9cd6-2a1b-711074af481b814174f3-2e0e-1625-de48-9c40732c914959576-9Body mass index (BMI) [Percentile] Per age and...83.6%numeric
\n", - "
" - ], - "text/plain": [ - " DATE PATIENT \\\n", - "0 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "1 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "2 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "3 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "4 2011-03-31T02:00:17Z 7d3e489a-7789-9cd6-2a1b-711074af481b \n", - "\n", - " ENCOUNTER CODE \\\n", - "0 814174f3-2e0e-1625-de48-9c40732c9149 8302-2 \n", - "1 814174f3-2e0e-1625-de48-9c40732c9149 72514-3 \n", - "2 814174f3-2e0e-1625-de48-9c40732c9149 29463-7 \n", - "3 814174f3-2e0e-1625-de48-9c40732c9149 39156-5 \n", - "4 814174f3-2e0e-1625-de48-9c40732c9149 59576-9 \n", - "\n", - " DESCRIPTION VALUE UNITS TYPE \n", - "0 Body Height 167.0 cm numeric \n", - "1 Pain severity - 0-10 verbal numeric rating [Sc... 3.0 {score} numeric \n", - "2 Body Weight 71.1 kg numeric \n", - "3 Body Mass Index 25.5 kg/m2 numeric \n", - "4 Body mass index (BMI) [Percentile] Per age and... 83.6 % numeric " - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_obs.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
IdBIRTHDATEDEATHDATESSNDRIVERSPASSPORTPREFIXFIRSTLASTSUFFIX...BIRTHPLACEADDRESSCITYSTATECOUNTYZIPLATLONHEALTHCARE_EXPENSESHEALTHCARE_COVERAGE
07d3e489a-7789-9cd6-2a1b-711074af481b1993-01-28NaN999-95-8631S99916705X24646789XMr.Jon665Pacocha935NaN...Lawrence Massachusetts US942 Fahey Overpass Apt 21NatickMassachusettsMiddlesex CountyNaN42.309347-71.349633569019.692293.12
1a3795ec8-54f3-e99e-a4b1-4c067f3141d71971-12-01NaN999-62-4431S99941017X38787090XMr.Dick869Streich926NaN...Swansea Massachusetts US1064 Hickle View Apt 7ChicopeeMassachusettsHampden County1020.042.198239-72.55475218755.460.00
23829c803-1f4c-74ed-0d8f-36e502cadd0f2005-01-07NaN999-21-2332NaNNaNNaNCordell41Eichmann909NaN...Chelmsford Massachusetts US560 Ritchie Way Suite 68SwanseaMassachusettsBristol CountyNaN41.748125-71.182914361770.002768.96
3d7acfddb-f4c2-69f4-2081-ad1fb84904481990-07-04NaN999-53-1990S99932677X67053099XMrs.Cheri871Oberbrunner298NaN...Cambridge Massachusetts US268 Hansen Loaf Apt 62LowellMassachusettsMiddlesex County1850.042.662520-71.368933703332.775551.19
4474766f3-ee93-f5d6-84c3-db38ba8033942012-04-03NaN999-57-2653NaNNaNNaNDesmond566O'Conner199NaN...Cohasset Massachusetts US831 Schumm Lock Apt 62WestboroughMassachusettsWorcester CountyNaN42.253951-71.563825206450.272284.86
\n", - "

5 rows × 25 columns

\n", - "
" - ], - "text/plain": [ - " Id BIRTHDATE DEATHDATE SSN \\\n", - "0 7d3e489a-7789-9cd6-2a1b-711074af481b 1993-01-28 NaN 999-95-8631 \n", - "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 1971-12-01 NaN 999-62-4431 \n", - "2 3829c803-1f4c-74ed-0d8f-36e502cadd0f 2005-01-07 NaN 999-21-2332 \n", - "3 d7acfddb-f4c2-69f4-2081-ad1fb8490448 1990-07-04 NaN 999-53-1990 \n", - "4 474766f3-ee93-f5d6-84c3-db38ba803394 2012-04-03 NaN 999-57-2653 \n", - "\n", - " DRIVERS PASSPORT PREFIX FIRST LAST SUFFIX ... \\\n", - "0 S99916705 X24646789X Mr. Jon665 Pacocha935 NaN ... \n", - "1 S99941017 X38787090X Mr. Dick869 Streich926 NaN ... \n", - "2 NaN NaN NaN Cordell41 Eichmann909 NaN ... \n", - "3 S99932677 X67053099X Mrs. Cheri871 Oberbrunner298 NaN ... \n", - "4 NaN NaN NaN Desmond566 O'Conner199 NaN ... \n", - "\n", - " BIRTHPLACE ADDRESS CITY \\\n", - "0 Lawrence Massachusetts US 942 Fahey Overpass Apt 21 Natick \n", - "1 Swansea Massachusetts US 1064 Hickle View Apt 7 Chicopee \n", - "2 Chelmsford Massachusetts US 560 Ritchie Way Suite 68 Swansea \n", - "3 Cambridge Massachusetts US 268 Hansen Loaf Apt 62 Lowell \n", - "4 Cohasset Massachusetts US 831 Schumm Lock Apt 62 Westborough \n", - "\n", - " STATE COUNTY ZIP LAT LON \\\n", - "0 Massachusetts Middlesex County NaN 42.309347 -71.349633 \n", - "1 Massachusetts Hampden County 1020.0 42.198239 -72.554752 \n", - "2 Massachusetts Bristol County NaN 41.748125 -71.182914 \n", - "3 Massachusetts Middlesex County 1850.0 42.662520 -71.368933 \n", - "4 Massachusetts Worcester County NaN 42.253951 -71.563825 \n", - "\n", - " HEALTHCARE_EXPENSES HEALTHCARE_COVERAGE \n", - "0 569019.69 2293.12 \n", - "1 18755.46 0.00 \n", - "2 361770.00 2768.96 \n", - "3 703332.77 5551.19 \n", - "4 206450.27 2284.86 \n", - "\n", - "[5 rows x 25 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_pat.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Feature selection" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Select the features of interests: \n", - "\n", - "- Systolic blood pressure readings from the observations (code 8480-6).\n", - "- Select diastolic blood pressure readings (code 8462-4).\n", - "- Select HDL cholesterol readings (code 2085-9).\n", - "- Select LDL cholesterol readings (code 18262-6).\n", - "- Select BMI (body mass index) readings (code 39156-5).\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def feature_selection_obs(df, code):\n", - " return df[df[\"CODE\"]==code][[\"PATIENT\", \"DATE\", \"VALUE\"]].drop_duplicates().reset_index(drop=True)\n", - "\n", - "#select feautures from observations\n", - "df_systolic = feature_selection_obs(df_obs, \"8480-6\").rename(columns={\"VALUE\": \"SYSTOLIC_BP\"})\n", - "df_diastolic = feature_selection_obs(df_obs, \"8462-4\").rename(columns={\"VALUE\": \"DIASTOLIC_BP\"})\n", - "df_hdl = feature_selection_obs(df_obs, \"2085-9\").rename(columns={\"VALUE\": \"HDL\"})\n", - "df_ldl = feature_selection_obs(df_obs, \"18262-6\").rename(columns={\"VALUE\": \"LDL\"})\n", - "df_bmi = feature_selection_obs(df_obs, \"39156-5\").rename(columns={\"VALUE\": \"BMI\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(83540, 83541, 26900, 26900, 57880)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - 
"len(df_systolic), len(df_diastolic), len(df_hdl), len(df_ldl), len(df_bmi)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Merge the dataframes (inner join for now, to avoid dealing with missing values)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "df1 = pd.merge(df_systolic, df_diastolic, on=[\"PATIENT\", \"DATE\"], how='inner')\n", - "df2 = pd.merge(df1, df_hdl, on=[\"PATIENT\", \"DATE\"], how='inner')\n", - "df3 = pd.merge(df2, df_ldl, on=[\"PATIENT\", \"DATE\"], how='inner')\n", - "df4 = pd.merge(df3, df_bmi, on=[\"PATIENT\", \"DATE\"], how='inner')" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "21224" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(df4)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
PATIENTDATESYSTOLIC_BPDIASTOLIC_BPHDLLDLBMI
0a3795ec8-54f3-e99e-a4b1-4c067f3141d72013-01-16T22:06:58Z128.088.064.589.222.4
1a3795ec8-54f3-e99e-a4b1-4c067f3141d72017-12-20T22:06:58Z116.071.064.978.022.4
29bafdf36-6e60-e93e-7925-c8d15a49ea622012-11-25T09:32:01Z125.082.072.997.327.6
39bafdf36-6e60-e93e-7925-c8d15a49ea622015-12-13T09:32:01Z104.089.064.371.327.6
49bafdf36-6e60-e93e-7925-c8d15a49ea622018-12-30T09:32:01Z121.077.061.277.827.6
\n", - "
" - ], - "text/plain": [ - " PATIENT DATE SYSTOLIC_BP \\\n", - "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2013-01-16T22:06:58Z 128.0 \n", - "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2017-12-20T22:06:58Z 116.0 \n", - "2 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2012-11-25T09:32:01Z 125.0 \n", - "3 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2015-12-13T09:32:01Z 104.0 \n", - "4 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2018-12-30T09:32:01Z 121.0 \n", - "\n", - " DIASTOLIC_BP HDL LDL BMI \n", - "0 88.0 64.5 89.2 22.4 \n", - "1 71.0 64.9 78.0 22.4 \n", - "2 82.0 72.9 97.3 27.6 \n", - "3 89.0 64.3 71.3 27.6 \n", - "4 77.0 61.2 77.8 27.6 " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df4.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Join also the age (derived from birth date in PATIENT dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "df5 = pd.merge(df4, df_pat[[\"Id\", \"BIRTHDATE\"]].rename(columns={\"Id\": \"PATIENT\"}), on =[\"PATIENT\"], how='inner')" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "df5[\"DATE\"] = [x.split(\"T\")[0] for x in list(df5[\"DATE\"])]" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from datetime import datetime\n", - "#from https://stackoverflow.com/questions/8419564/difference-between-two-dates-in-python\n", - "def days_between(d1, d2):\n", - " d1 = datetime.strptime(d1, \"%Y-%m-%d\")\n", - " d2 = datetime.strptime(d2, \"%Y-%m-%d\")\n", - " return abs((d2 - d1).days)\n", - "\n", - "def age_calculation(l1, l2):\n", - " age_list = []\n", - " i = 0\n", - " for i in range(0, len(l1)):\n", - " age_list.append(days_between(l1[i], l2[i]) / 365.00)\n", - " return age_list\n", - "\n", - "df5[\"AGE\"] = age_calculation(list(df5[\"DATE\"]), list(df5[\"BIRTHDATE\"]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "df5.drop([\"BIRTHDATE\"], axis=1, inplace=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
PATIENTDATESYSTOLIC_BPDIASTOLIC_BPHDLLDLBMIAGE
0a3795ec8-54f3-e99e-a4b1-4c067f3141d72013-01-16128.088.064.589.222.441.156164
1a3795ec8-54f3-e99e-a4b1-4c067f3141d72017-12-20116.071.064.978.022.446.084932
29bafdf36-6e60-e93e-7925-c8d15a49ea622012-11-25125.082.072.997.327.658.167123
39bafdf36-6e60-e93e-7925-c8d15a49ea622015-12-13104.089.064.371.327.661.216438
49bafdf36-6e60-e93e-7925-c8d15a49ea622018-12-30121.077.061.277.827.664.265753
\n", - "
" - ], - "text/plain": [ - " PATIENT DATE SYSTOLIC_BP DIASTOLIC_BP \\\n", - "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2013-01-16 128.0 88.0 \n", - "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2017-12-20 116.0 71.0 \n", - "2 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2012-11-25 125.0 82.0 \n", - "3 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2015-12-13 104.0 89.0 \n", - "4 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2018-12-30 121.0 77.0 \n", - "\n", - " HDL LDL BMI AGE \n", - "0 64.5 89.2 22.4 41.156164 \n", - "1 64.9 78.0 22.4 46.084932 \n", - "2 72.9 97.3 27.6 58.167123 \n", - "3 64.3 71.3 27.6 61.216438 \n", - "4 61.2 77.8 27.6 64.265753 " - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df5.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we find the patient with diabetes diagnosis, and select the start date column (equivalent to first diagnosis), in the CONDITION dataset. " - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "df_pat_diab = df_cond[df_cond.DESCRIPTION == \"Diabetes\"][[\"PATIENT\", \"START\"]]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "df6 = pd.merge(df5, df_pat_diab, on=[\"PATIENT\"], how=\"left\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "df6[\"HAS_DIABETES\"] = [(0 if (type(el) == float and np.isnan(el)) else 1) for el in list(df6[\"START\"])]" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
PATIENTDATESYSTOLIC_BPDIASTOLIC_BPHDLLDLBMIAGESTARTHAS_DIABETES
0a3795ec8-54f3-e99e-a4b1-4c067f3141d72013-01-16128.088.064.589.222.441.156164NaN0
1a3795ec8-54f3-e99e-a4b1-4c067f3141d72017-12-20116.071.064.978.022.446.084932NaN0
29bafdf36-6e60-e93e-7925-c8d15a49ea622012-11-25125.082.072.997.327.658.167123NaN0
39bafdf36-6e60-e93e-7925-c8d15a49ea622015-12-13104.089.064.371.327.661.216438NaN0
49bafdf36-6e60-e93e-7925-c8d15a49ea622018-12-30121.077.061.277.827.664.265753NaN0
\n", - "
" - ], - "text/plain": [ - " PATIENT DATE SYSTOLIC_BP DIASTOLIC_BP \\\n", - "0 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2013-01-16 128.0 88.0 \n", - "1 a3795ec8-54f3-e99e-a4b1-4c067f3141d7 2017-12-20 116.0 71.0 \n", - "2 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2012-11-25 125.0 82.0 \n", - "3 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2015-12-13 104.0 89.0 \n", - "4 9bafdf36-6e60-e93e-7925-c8d15a49ea62 2018-12-30 121.0 77.0 \n", - "\n", - " HDL LDL BMI AGE START HAS_DIABETES \n", - "0 64.5 89.2 22.4 41.156164 NaN 0 \n", - "1 64.9 78.0 22.4 46.084932 NaN 0 \n", - "2 72.9 97.3 27.6 58.167123 NaN 0 \n", - "3 64.3 71.3 27.6 61.216438 NaN 0 \n", - "4 61.2 77.8 27.6 64.265753 NaN 0 " - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df6.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Filter data\n", - "For this example, we filter the positive observations taken before diagnosis and then we reduce the observations to a single one per patient. In the future, it might be valuable keeping the time and try to predict diabetes before it occurs (e.g. with RNN). In this case, however, we need to check better the generative model underlying synthea, as in the notebook we are trying to reproduce here: \"The impact of the condition (diabetes) is not reflected in the observations until the patient is diagnosed with the condition in a wellness visit\". However, there is a condition called \"Prediabetes\" which we could take into account. " - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "def date_to_int(string_date):\n", - " #a date can also be nan (float type)\n", - " return int(string_date.replace(\"-\", \"\")) if type(string_date) == str else 0\n", - "def col_date_to_int(col_date):\n", - " return list(map(date_to_int, col_date))\n", - "\n", - "df6[\"temp_date\"] = col_date_to_int(df6[\"DATE\"])\n", - "df6[\"temp_start\"] = col_date_to_int(df6[\"START\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "57" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_with_diab = df6[df6.HAS_DIABETES == 1]\n", - "df_to_discard = df_with_diab[df_with_diab[\"temp_start\"] > df_with_diab[\"temp_date\"]]\n", - "len(df_to_discard)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "df7 = df6.drop(index = df_to_discard.index, inplace=False).reset_index().drop(columns=[\"index\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's now reduce the observations to a single observation per patient (the earliest available observation)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
PATIENTSYSTOLIC_BPDIASTOLIC_BPHDLLDLBMIAGEHAS_DIABETES
028de1ba3-efdc-1797-5a32-d4f8d3be0936109.078.075.174.425.230.3561640
16686010d-3d7b-d69d-d2cd-8bbe4b3e6041160.0104.062.182.628.230.3561640
242217d99-02e1-2cc8-83ed-5101c246a559112.084.069.294.630.131.2219180
3c429d985-225b-f380-4462-57852cf61186121.074.063.1109.029.631.2219180
406365dfa-6203-5413-2c6d-553a4a988c1f117.082.064.694.134.334.2328770
...........................
3821391d2527-19ca-8b83-2cd0-5c484946b2b7156.095.076.576.434.130.3589040
3822964603c6-e5bd-6d86-5b21-34056a4a651a124.078.068.0102.125.331.2219180
38233d44d71e-2241-29c2-6aed-7f7f4d65d8d4123.084.075.066.530.030.3589040
3824bbcfa21f-caa2-4540-d2b6-929f48ae27c2130.071.069.199.329.331.2219180
3825d9f4d286-8db4-201d-c1ed-80909f6b3927121.083.068.1104.422.730.3589040
\n", - "

3826 rows × 8 columns

\n", - "
" - ], - "text/plain": [ - " PATIENT SYSTOLIC_BP DIASTOLIC_BP HDL \\\n", - "0 28de1ba3-efdc-1797-5a32-d4f8d3be0936 109.0 78.0 75.1 \n", - "1 6686010d-3d7b-d69d-d2cd-8bbe4b3e6041 160.0 104.0 62.1 \n", - "2 42217d99-02e1-2cc8-83ed-5101c246a559 112.0 84.0 69.2 \n", - "3 c429d985-225b-f380-4462-57852cf61186 121.0 74.0 63.1 \n", - "4 06365dfa-6203-5413-2c6d-553a4a988c1f 117.0 82.0 64.6 \n", - "... ... ... ... ... \n", - "3821 391d2527-19ca-8b83-2cd0-5c484946b2b7 156.0 95.0 76.5 \n", - "3822 964603c6-e5bd-6d86-5b21-34056a4a651a 124.0 78.0 68.0 \n", - "3823 3d44d71e-2241-29c2-6aed-7f7f4d65d8d4 123.0 84.0 75.0 \n", - "3824 bbcfa21f-caa2-4540-d2b6-929f48ae27c2 130.0 71.0 69.1 \n", - "3825 d9f4d286-8db4-201d-c1ed-80909f6b3927 121.0 83.0 68.1 \n", - "\n", - " LDL BMI AGE HAS_DIABETES \n", - "0 74.4 25.2 30.356164 0 \n", - "1 82.6 28.2 30.356164 0 \n", - "2 94.6 30.1 31.221918 0 \n", - "3 109.0 29.6 31.221918 0 \n", - "4 94.1 34.3 34.232877 0 \n", - "... ... ... ... ... \n", - "3821 76.4 34.1 30.358904 0 \n", - "3822 102.1 25.3 31.221918 0 \n", - "3823 66.5 30.0 30.358904 0 \n", - "3824 99.3 29.3 31.221918 0 \n", - "3825 104.4 22.7 30.358904 0 \n", - "\n", - "[3826 rows x 8 columns]" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df7.sort_values(by=[\"temp_date\"], axis=0, inplace=True)\n", - "df7.reset_index(inplace=True)\n", - "df7 = df7.drop(columns=\"index\")\n", - "df7[\"OBS_INDEX\"] = df7.groupby([\"PATIENT\"]).cumcount()+1\n", - "df8 = df7[df7.OBS_INDEX == 1]\n", - "df8.reset_index(inplace=True)\n", - "df8 = df8.drop(columns=[\"index\", \"temp_date\", \"temp_start\", \"OBS_INDEX\", \"START\", \"DATE\"])\n", - "df8" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(3480, 346)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(df8[\"HAS_DIABETES\"]).count(0), list(df8[\"HAS_DIABETES\"]).count(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "#prepare dataset for pytorch\n", - "df9 = df8.drop(columns=\"PATIENT\")" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['SYSTOLIC_BP', 'DIASTOLIC_BP', 'HDL', 'LDL', 'BMI', 'AGE', 'HAS_DIABETES']" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(df9.columns)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Divide into training and test set, and define train and test dataloader" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [], - "source": [ - "df_input = df9.drop(columns=\"HAS_DIABETES\")\n", - "df_y = df9[\"HAS_DIABETES\"]\n", - "\n", - "for col in list(df_input.columns):\n", - " df_input[col] = list(map(float, df_input[col]))\n", - " \n", - "train_ratio = 0.7\n", - "\n", - "msk = np.random.rand(len(df9)) < train_ratio\n", - "train_set = df_input[msk].values\n", - "train_labels = df_y[msk].values\n", - "test_set = df_input[~msk].values\n", - "test_labels = df_y[~msk].values" - ] - }, - { - "cell_type": "code", - "execution_count": 121, - "metadata": {}, - "outputs": [], - "source": [ - "import torch \n", - "\n", - "train_target = torch.tensor(train_labels.astype(np.float32))\n", - "train = torch.tensor(train_set.astype(np.float32)) \n", - "\n", - 
"test_target = torch.tensor(test_labels.astype(np.float32))\n", - "test = torch.tensor(test_set.astype(np.float32)) \n", - "\n", - "bs = 10\n", - "train_tensor = torch.utils.data.TensorDataset(train, train_target) \n", - "train_loader = torch.utils.data.DataLoader(dataset = train_tensor, batch_size = bs, shuffle = True)\n", - "\n", - "test_tensor = torch.utils.data.TensorDataset(test, test_target) \n", - "test_loader = torch.utils.data.DataLoader(dataset = test_tensor, batch_size = bs, shuffle = False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Simple logistic regression model" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "metadata": {}, - "outputs": [], - "source": [ - "class LogisticRegression(torch.nn.Module):\n", - " def __init__(self, input_dim, output_dim):\n", - " super(LogisticRegression, self).__init__()\n", - " self.linear = torch.nn.Linear(input_dim, output_dim)\n", - "\n", - " def forward(self, x):\n", - " outputs = self.linear(x)\n", - " return outputs" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "metadata": {}, - "outputs": [], - "source": [ - "epochs = 10\n", - "input_dim = 6\n", - "output_dim = 2\n", - "lr_rate = 0.001\n", - "\n", - "model = LogisticRegression(input_dim, output_dim)\n", - "criterion = torch.nn.CrossEntropyLoss()\n", - "\n", - "optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training loop and evaluation" - ] - }, - { - "cell_type": "code", - "execution_count": 177, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0. Iteration: 200. Loss: 7.707485929131508e-05. Accuracy: 90.78590785907859.\n", - "Epoch: 1. Iteration: 400. Loss: 0.0004317985731177032. Accuracy: 94.76061427280939.\n", - "Epoch: 2. Iteration: 600. Loss: 0.0. Accuracy: 90.60523938572719.\n", - "Epoch: 2. Iteration: 800. Loss: 1.2545890808105469. Accuracy: 45.347786811201445.\n", - "Epoch: 3. Iteration: 1000. Loss: 0.07240111380815506. Accuracy: 94.03794037940379.\n", - "Epoch: 4. Iteration: 1200. Loss: 0.0018771663308143616. Accuracy: 93.13459801264679.\n", - "Epoch: 5. Iteration: 1400. Loss: 1.3109453916549683. Accuracy: 79.67479674796748.\n", - "Epoch: 5. Iteration: 1600. Loss: 1.7740905284881592. Accuracy: 92.6829268292683.\n", - "Epoch: 6. Iteration: 1800. Loss: 0.0184627752751112. Accuracy: 94.2186088527552.\n", - "Epoch: 7. Iteration: 2000. Loss: 0.6591505408287048. Accuracy: 75.51942186088527.\n", - "Epoch: 8. Iteration: 2200. Loss: 1.10377836227417. Accuracy: 90.60523938572719.\n", - "Epoch: 8. Iteration: 2400. Loss: 0.2253536880016327. Accuracy: 92.773261065944.\n", - "Epoch: 9. Iteration: 2600. Loss: 3.0842161178588867. 
Accuracy: 90.51490514905149.\n" - ] - } - ], - "source": [ - "from torch.autograd import Variable\n", - "iter = 0\n", - "loss_v = []\n", - "for epoch in range(int(epochs)):\n", - " for i, (point, labels) in enumerate(train_loader):\n", - " points = Variable(point.view(-1, 6))\n", - " labels = Variable(labels.type(torch.LongTensor))\n", - " \n", - " optimizer.zero_grad()\n", - " outputs = model(points)\n", - " loss = criterion(outputs, labels)\n", - " loss_v.append(loss.detach().numpy())\n", - " loss.backward()\n", - " optimizer.step()\n", - " \n", - " iter+=1\n", - " if iter%200==0:\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for points, labels in test_loader:\n", - " points = Variable(points.view(-1, 6))\n", - " outputs = model(points)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * float(correct)/float(total)\n", - " print(\"Epoch: {}. Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, iter, loss.item(), accuracy))" - ] - }, - { - "cell_type": "code", - "execution_count": 141, - "metadata": {}, - "outputs": [], - "source": [ - "#final test\n", - "labels = test_target\n", - "outputs = model(test)\n", - "_, predicted = torch.max(outputs.data,1)\n", - "total = labels.size(0)\n", - "correct = (predicted == labels).sum()\n", - "accuracy = float(correct)/float(total)" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.metrics import confusion_matrix\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sn\n", - "\n", - "cm = confusion_matrix(labels, predicted, labels=[0,1])\n", - "tn, fp, fn, tp = cm.ravel()" - ] - }, - { - "cell_type": "code", - "execution_count": 170, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1002, 0],\n", - " [ 105, 0]])" - ] - }, - "execution_count": 170, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cm" - ] - }, - { - "cell_type": "code", - "execution_count": 171, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1002, 0, 105, 0)" - ] - }, - "execution_count": 171, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tn, fp, fn, tp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": 
"code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From a3ad412bf8cbe065344182ea289cc70422839cc2 Mon Sep 17 00:00:00 2001 From: Daniele Romanini Date: Sun, 27 Dec 2020 11:19:55 +0100 Subject: [PATCH 16/21] Updated utility functions Updated split_data_create_vertical_dataset to match with current dataset classes (i.e. samplesetwithlabels). --- examples/dualheaded_datautils/utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/dualheaded_datautils/utils.py b/examples/dualheaded_datautils/utils.py index fab9af5..2b04217 100644 --- a/examples/dualheaded_datautils/utils.py +++ b/examples/dualheaded_datautils/utils.py @@ -107,13 +107,15 @@ def split_data_create_vertical_dataset(dataset, worker_list, label_server=None): a VerticalFederatedDataset. """ - #get a dictionary of workers --> list of triples (data, label, idx) representing the dataset. 
- dic_single_datasets = split_data(dataset, worker_list=worker_list) - - #create base vertical datasets list, to be passed to a vertical federated dataset + #get a dictionary of workers --> data , label_list, index_list, ordered + dic_single_datasets, label_list, index_list = split_data(dataset, worker_list=worker_list) + + #instantiate BaseSets + label_set = BaseSet(index_list, label_list, is_labels=True) base_datasets_list = [] - for worker in worker_list: - base_datasets_list.append(BaseVerticalDataset(dic_single_datasets[worker], worker_id=worker)) + for w in dic_single_datasets.keys(): + bs = BaseSet(index_list, dic_single_datasets[w], is_labels=False) + base_datasets_list.append(SampleSetWithLabels(label_set, bs, worker_id=w)) #create VerticalFederatedDataset - return VerticalFederatedDataset(base_datasets_list) \ No newline at end of file + return VerticalFederatedDataset(base_datasets_list) From 901dff8fd32cad556252490c022c47f54b7ab7d2 Mon Sep 17 00:00:00 2001 From: daler3 Date: Thu, 7 Jan 2021 17:16:36 +0100 Subject: [PATCH 17/21] updated datasets --- examples/dualheaded_datautils/dataloaders.py | 6 ++-- examples/dualheaded_datautils/datasets.py | 31 +++++++++++++++----- examples/dualheaded_datautils/utils.py | 17 ++++++----- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/examples/dualheaded_datautils/dataloaders.py b/examples/dualheaded_datautils/dataloaders.py index 908d818..94c1ab0 100644 --- a/examples/dualheaded_datautils/dataloaders.py +++ b/examples/dualheaded_datautils/dataloaders.py @@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs): #self.collate_fn = id_collate_fn -class VerticalFederatedDataLoader: +class VerticalFederatedDataLoader(): """Dataloader which batches data from a complete set of vertically-partitioned datasets """ @@ -44,8 +44,8 @@ def __init__(self, vf_dataset, batch_size=8, shuffle=False, drop_last=False, *ar self.batch_samplers[worker] = batch_sampler single_loaders = [] - for k in vfd.datasets.keys(): - single_loaders.append(SinglePartitionDataLoader(vfd.datasets[k], batch_sampler=self.batch_samplers[k])) + for k in vf_dataset.datasets.keys(): + single_loaders.append(SinglePartitionDataLoader(vf_dataset.datasets[k], batch_sampler=self.batch_samplers[k])) self.single_loaders = single_loaders diff --git a/examples/dualheaded_datautils/datasets.py b/examples/dualheaded_datautils/datasets.py index 3474bc4..959b93d 100644 --- a/examples/dualheaded_datautils/datasets.py +++ b/examples/dualheaded_datautils/datasets.py @@ -58,8 +58,9 @@ def __init__(self, labelset, sampleset, worker_id=None): self.values_dic = {} for k in labelset.values_dic.keys(): - self.values_dic[k] = tuple([sampleset.values_dic[k], torch.Tensor(labelset.values_dic[k])]) - + self.values_dic[k] = tuple([sampleset.values_dic[k], labelset.values_dic[k]]) + + print("ciao") self.worker_id = None if worker_id != None: self.send_to_worker(worker_id) @@ -102,27 +103,41 @@ def __init__(self, datasets): self.datasets = {} #dictionary to keep track of BaseVerticalDatasets and corresponding workers + indices_list = set() + + #take intersecting items for dataset in datasets: - worker_id = dataset.worker_id - self.datasets[worker_id] = dataset + indices_list.update(dataset.ids) + self.datasets[dataset.worker_id] = dataset self.workers = self.__workers() - + + #create a list of dictionaries + self.dict_items_list = [] + + for index in indices_list: + curr_dict = {} + for w in self.workers: + curr_dict[w] = tuple(list(self.datasets[w].values_dic[index.item()])+[index.item()]) + + 
self.dict_items_list.append(curr_dict) + + def __workers(self): """ Returns: list of workers """ return list(self.datasets.keys()) - def __getitem__(self, worker_id): + def __getitem__(self, idx): """ Args: worker_id[str,int]: ID of respective worker Returns: - Get Datasets from the respective worker + Get dataset item from different workers """ - return self.datasets[worker_id] + return self.dict_items_list[idx] def __len__(self): diff --git a/examples/dualheaded_datautils/utils.py b/examples/dualheaded_datautils/utils.py index 2b04217..4fddea8 100644 --- a/examples/dualheaded_datautils/utils.py +++ b/examples/dualheaded_datautils/utils.py @@ -11,6 +11,7 @@ import dataloaders import datasets +from datasets import * """ Utility functions to split and distribute the data across different workers, @@ -68,30 +69,30 @@ def split_data(dataset, worker_list=None, n_workers=2): - label - index """ - dic_single_datasets[worker] = [[],[],[]] + dic_single_datasets[worker] = [] """ Loop through the dataset to split the data and labels vertically across workers. Splitting method from @abbas5253: https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/distribute_data.py """ + label_list = [] + index_list = [] for tensor, label in dataset: height = tensor.shape[-1]//len(worker_list) i = 0 uuid_idx = uuid4() for worker in worker_list[:-1]: - dic_single_datasets[worker][0].append(tensor[:, :, height * i : height * (i + 1)]) - dic_single_datasets[worker][1].append(label) - dic_single_datasets[worker][2].append(idx) + dic_single_datasets[worker].append(tensor[:, :, height * i : height * (i + 1)]) i += 1 #add the value of the last worker / split - dic_single_datasets[worker_list[-1]][0].append(tensor[:, :, height * (i) : ]) - dic_single_datasets[worker_list[-1]][1].append(label) - dic_single_datasets[worker_list[-1]][2].append(idx) + dic_single_datasets[worker_list[-1]].append(tensor[:, :, height * (i) : ]) + label_list.append(label) + index_list.append(idx) idx += 1 - return dic_single_datasets + return dic_single_datasets, label_list, index_list def split_data_create_vertical_dataset(dataset, worker_list, label_server=None): From 8bdc9f3dc3095815c3f47020e1dedba7c56c5267 Mon Sep 17 00:00:00 2001 From: daler3 Date: Thu, 7 Jan 2021 18:59:14 +0100 Subject: [PATCH 18/21] updated to use custom dataloaders, updated example notebook --- examples/dualheaded_datautils/Example_.ipynb | 117 ++++++++++++++++++- examples/dualheaded_datautils/dataloaders.py | 37 ++++-- examples/dualheaded_datautils/datasets.py | 6 +- examples/dualheaded_datautils/utils.py | 6 +- 4 files changed, 143 insertions(+), 23 deletions(-) diff --git a/examples/dualheaded_datautils/Example_.ipynb b/examples/dualheaded_datautils/Example_.ipynb index 9154fe4..b0790fb 100644 --- a/examples/dualheaded_datautils/Example_.ipynb +++ b/examples/dualheaded_datautils/Example_.ipynb @@ -2,9 +2,17 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Torch was already hooked... 
skipping hooking process\n"
+     ]
+    }
+   ],
    "source": [
     "from __future__ import print_function\n",
     "import syft as sy\n",
@@ -46,9 +54,110 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#get a verticalFederatedDataser\n",
-    "vfd = split_data_create_vertical_dataset(trainset, data_owners)"
+    "#get a VerticalFederatedDataset\n",
+    "vfd = utils.split_data_create_vertical_dataset(trainset, data_owners)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loader = DataLoader(vfd, batch_size=4, shuffle=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{VirtualWorker client_1: [tensor([...four half-images of shape (1, 28, 14), values elided...]), tensor([1, 7, 3, 4]), tensor([48949., 24503., 52281., 33117.], dtype=torch.float64)], VirtualWorker client_2: [tensor([...the matching four half-images, values elided...]), tensor([1, 7, 3, 4]), tensor([48949., 24503., 52281., 33117.], dtype=torch.float64)]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "for el in loader: \n",
+    "    print(el)\n",
+    "    break"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
  ],
 "metadata": {
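
Read against the printed batch above, a minimal sketch of how the output of the standard DataLoader unpacks (assuming the vfd, loader and data_owners names from the notebook; at this point in the series the batch dict is keyed by the worker objects themselves, a later patch switches the keys to worker ids):

    for batch in loader:
        for owner in data_owners:
            images, labels, indices = batch[owner]
            # images: (4, 1, 28, 14) half-images held by this owner;
            # labels and indices are identical across owners, which is
            # what keeps the rows of the two partitions aligned
        break
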
diff --git a/examples/dualheaded_datautils/dataloaders.py b/examples/dualheaded_datautils/dataloaders.py
index 94c1ab0..8e445d5 100644
--- a/examples/dualheaded_datautils/dataloaders.py
+++ b/examples/dualheaded_datautils/dataloaders.py
@@ -12,6 +12,9 @@
 import datasets
 
+"""I think this is not needed anymore"""
+
+
 class SinglePartitionDataLoader(DataLoader):
     """DataLoader for a single vertically-partitioned dataset"""
 
@@ -20,22 +23,36 @@ def __init__(self, *args, **kwargs):
         #self.collate_fn = id_collate_fn
 
 
-class VerticalFederatedDataLoader():
+
+
+class VerticalFederatedDataLoader(DataLoader):
     """Dataloader which batches data from a complete set
     of vertically-partitioned datasets
+
+
+    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
+               batch_sampler=None, num_workers=0, collate_fn=None,
+               pin_memory=False, drop_last=False, timeout=0,
+               worker_init_fn=None, *, prefetch_factor=2,
+               persistent_workers=False)
     """
 
-    def __init__(self, vf_dataset, batch_size=8, shuffle=False, drop_last=False, *args, **kwargs):
+    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+                 batch_sampler=None, num_workers=0, collate_fn=None,
+                 pin_memory=False, drop_last=False, timeout=0,
+                 worker_init_fn=None, *, prefetch_factor=2,
+                 persistent_workers=False):
 
-        self.vf_dataset = vf_dataset
+        self.dataset = dataset
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.num_workers = num_workers
 
-        self.workers = vf_dataset.workers
+        self.workers = dataset.workers
 
         self.batch_samplers = {}
         for worker in self.workers:
-            data_range = range(len(list(self.vf_dataset.datasets.values())))
+            data_range = range(len(self.dataset))
             if shuffle:
                 sampler = RandomSampler(data_range)
             else:
@@ -44,14 +61,10 @@ def __init__(self, vf_dataset, batch_size=8, shuffle=False, drop_last=False, *ar
             self.batch_samplers[worker] = batch_sampler
 
         single_loaders = []
-        for k in vf_dataset.datasets.keys():
-            single_loaders.append(SinglePartitionDataLoader(vf_dataset.datasets[k], batch_sampler=self.batch_samplers[k]))
+        for k in self.dataset.datasets.keys():
+            single_loaders.append(SinglePartitionDataLoader(self.dataset.datasets[k], batch_sampler=self.batch_samplers[k]))
 
         self.single_loaders = single_loaders
-
-
-    def __iter__(self):
-        return zip(*self.single_loaders)
 
     def __len__(self):
-        return sum(len(x) for x in self.vf_dataset.datasets.values()) // len(self.workers)
\ No newline at end of file
+        return sum(len(x) for x in self.dataset.datasets.values()) // len(self.workers)
\ No newline at end of file
diff --git a/examples/dualheaded_datautils/datasets.py b/examples/dualheaded_datautils/datasets.py
index 959b93d..9497752 100644
--- a/examples/dualheaded_datautils/datasets.py
+++ b/examples/dualheaded_datautils/datasets.py
@@ -60,7 +60,6 @@ def __init__(self, labelset, sampleset, worker_id=None):
         for k in labelset.values_dic.keys():
             self.values_dic[k] = tuple([sampleset.values_dic[k], labelset.values_dic[k]])
 
-        print("ciao")
         self.worker_id = None
         if worker_id != None:
             self.send_to_worker(worker_id)
@@ -121,6 +120,8 @@ def __init__(self, datasets):
                 curr_dict[w] = tuple(list(self.datasets[w].values_dic[index.item()])+[index.item()])
 
             self.dict_items_list.append(curr_dict)
+
+        self.indices = list(indices_list)
 
 
     def __workers(self):
@@ -140,8 +141,7 @@ def __getitem__(self, idx):
         return self.dict_items_list[idx]
 
     def __len__(self):
-
-        return sum(len(dataset) for dataset in self.datasets.values())
+        return len(self.indices)
 
 
     def __repr__(self):
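
One caveat about the loader above, sketched here because it is easy to miss: with shuffle=True each worker is given its own RandomSampler, so the per-worker batch orders are only guaranteed to stay aligned if the samplers agree. A hypothetical helper (not part of this patch) that keeps all partitions row-aligned by sharing a single index order:

    import torch
    from torch.utils.data import BatchSampler

    def make_aligned_batch_samplers(vf_dataset, batch_size, shuffle=False, drop_last=False):
        # draw one permutation (or the identity order) for the whole dataset
        if shuffle:
            order = torch.randperm(len(vf_dataset)).tolist()
        else:
            order = list(range(len(vf_dataset)))
        shared = BatchSampler(order, batch_size, drop_last)
        # reuse the same BatchSampler instance for every worker, so every
        # partition is batched in exactly the same order
        return {worker: shared for worker in vf_dataset.workers}
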
diff --git a/examples/dualheaded_datautils/utils.py b/examples/dualheaded_datautils/utils.py
index 4fddea8..5ed714d 100644
--- a/examples/dualheaded_datautils/utils.py
+++ b/examples/dualheaded_datautils/utils.py
@@ -24,13 +24,11 @@
 - Syft 2.0 Federated Learning dataset and dataloader: https://github.com/OpenMined/PySyft/tree/syft_0.2.x/syft/frameworks/torch/fl
 
 TODO: 
     - replace ids with UUIDs
-    - there is a bug in creation of BaseDataset X
     - create class for splitting the data
-    - create LabelSet and SampleSet (to accomodate later different roles of workers)
-    - improve DataLoader to accomodate different sampler (e.g. random sampler when shuffle) and different batch size X
-    - split function should be able to take as an input a dataloader, and not only a dataset (i.e. single sample iteration)
     - check that / modify such that it works on data different than images
     - dictionary keys should be worker ids, not workers themselves
+
+    - the custom dataloader class is probably not needed anymore (Discuss)
 """

From 54a4f164ec878e73dada400001865c6d4e110f48 Mon Sep 17 00:00:00 2001
From: daler3
Date: Thu, 7 Jan 2021 19:41:34 +0100
Subject: [PATCH 19/21] added enhanced worker class (wip)

---
 .../enhancedSplitWorkers.py | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 examples/dualheaded_datautils/enhancedSplitWorkers.py

diff --git a/examples/dualheaded_datautils/enhancedSplitWorkers.py b/examples/dualheaded_datautils/enhancedSplitWorkers.py
new file mode 100644
index 0000000..9e9c788
--- /dev/null
+++ b/examples/dualheaded_datautils/enhancedSplitWorkers.py
@@ -0,0 +1,54 @@
+from __future__ import print_function
+import syft as sy
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from torch.utils.data._utils.collate import default_collate
+from typing import List, Tuple
+from uuid import UUID
+from uuid import uuid4
+from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler
+
+import datasets
+
+
+"""This is an experimental work-in-progress feature"""
+
+
+class EnhancedWorker():
+    """Single worker with a role (label / data holder) and a model"""
+
+    def __init__(self, worker, dataset, model, level=1):
+
+        self.worker = worker
+        self.dataset = dataset  #can also be None, in which case the worker is compute-only
+        self.model = model
+
+        self.level = level if level >= 0 else 0  #it should start from zero, otherwise throw error  #TODO: implement error throwing
+
+
+
+class FederatedWorkerChain():
+
+    """Class wrapping all the workers with their corresponding model"""
+    def __init__(self, enhancedWorkersList):
+        self.enhancedWorkersList = enhancedWorkersList
+        dic_workers = {}
+        for ew in enhancedWorkersList:
+            if ew.level not in dic_workers.keys():
+                dic_workers[ew.level] = []
+
+            dic_workers[ew.level].append(ew)
+
+        self.dic_workers = dic_workers
+
+
+    #TODO: implement check that the level passed is valid
+    def get_same_level_en_workers(self, level):
+        return self.dic_workers[level]
+
+    def get_previous_level_en_workers(self, level):
+        return self.dic_workers[level-1]
+
+    def get_next_level_en_workers(self, level):
+        return self.dic_workers[level+1]
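
As a usage sketch for the work-in-progress classes above (assuming the EnhancedWorker / FederatedWorkerChain names from this patch; client_1, client_2 and server are the Syft workers from the example notebook, while the dataset and model arguments are hypothetical placeholders):

    from enhancedSplitWorkers import EnhancedWorker, FederatedWorkerChain

    # level 0: data holders; level 1: the label/compute side
    ew_1 = EnhancedWorker(client_1, dataset_client_1, model_client_1, level=0)
    ew_2 = EnhancedWorker(client_2, dataset_client_2, model_client_2, level=0)
    ew_s = EnhancedWorker(server, None, model_server, level=1)  # compute-only

    chain = FederatedWorkerChain([ew_1, ew_2, ew_s])
    chain.get_same_level_en_workers(0)   # -> [ew_1, ew_2]
    chain.get_next_level_en_workers(0)   # -> [ew_s]
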
From f21d8af2888010bdad2ea3071b05d143c55f265c Mon Sep 17 00:00:00 2001
From: daler3
Date: Fri, 8 Jan 2021 13:47:31 +0100
Subject: [PATCH 20/21] changed workers with worker's ids in dictionaries;
 added models' segments

---
 examples/dualheaded_datautils/Example_.ipynb | 111 ++++++++++++++++---
 examples/dualheaded_datautils/utils.py       |   6 +-
 2 files changed, 98 insertions(+), 19 deletions(-)

diff --git a/examples/dualheaded_datautils/Example_.ipynb b/examples/dualheaded_datautils/Example_.ipynb
index b0790fb..b290a0c 100644
--- a/examples/dualheaded_datautils/Example_.ipynb
+++ b/examples/dualheaded_datautils/Example_.ipynb
@@ -1,18 +1,19 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Testing split functions and dataloading\n",
+    "\n",
+    "In this section, we test the split functions (utils), the custom dataset classes and dataloading (with the standard pytorch dataloader). "
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "WARNING:root:Torch was already hooked... skipping hooking process\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from __future__ import print_function\n",
     "import syft as sy\n",
@@ -50,7 +51,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -60,7 +61,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -69,14 +70,14 @@
 },
 {
  "cell_type": "code",
- "execution_count": 16,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-     "{VirtualWorker client_1: [tensor([...half-image values elided...]), tensor([1, 7, 3, 4]), tensor([48949., 24503., 52281., 33117.], dtype=torch.float64)], VirtualWorker client_2: [tensor([...half-image values elided...]), tensor([1, 7, 3, 4]), tensor([48949., 24503., 52281., 33117.], dtype=torch.float64)]}\n"
+     "{'client_1': [tensor([...half-image values elided...]), tensor([4, 5, 2, 3]), tensor([41319., 27699., 36485., 3835.], dtype=torch.float64)], 'client_2': [tensor([...half-image values elided...]), tensor([4, 5, 2, 3]), tensor([41319., 27699., 36485., 3835.], dtype=torch.float64)]}\n"
    ]
   }
  ],
  "source": [
   "for el in loader: \n",
   "    print(el)\n",
   "    break"
  ]
 },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#as in https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb\n",
+    "from torch import nn, optim\n",
+    "\n",
+    "model_locations = [client_1, client_2, server]\n",
+    "\n",
+    "input_size= [28*14, 28*14]\n",
+    "hidden_sizes= {\"client_1\": [32, 64], \"client_2\":[32, 64], \"server\":[128, 64]}\n",
+    "\n",
+    "#create model segment for each worker\n",
+    "models = {\n",
+    "    \"client_1\": nn.Sequential(\n",
+    "        nn.Linear(input_size[0], hidden_sizes[\"client_1\"][0]),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(hidden_sizes[\"client_1\"][0], hidden_sizes[\"client_1\"][1]),\n",
+    "        nn.ReLU(),\n",
+    "    ),\n",
+    "    \"client_2\": nn.Sequential(\n",
+    "        nn.Linear(input_size[1], hidden_sizes[\"client_2\"][0]),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(hidden_sizes[\"client_2\"][0], hidden_sizes[\"client_2\"][1]),\n",
+    "        nn.ReLU(),\n",
+    "    ),\n",
+    "    \"server\": nn.Sequential(\n",
+    "        nn.Linear(hidden_sizes[\"server\"][0], hidden_sizes[\"server\"][1]),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(hidden_sizes[\"server\"][1], 10),\n",
+    "        nn.LogSoftmax(dim=1)\n",
+    "    )\n",
+    "}\n",
+    "\n",
+    "\n",
+    "\n",
+    "# Create optimisers for each segment and link to their segment\n",
+    "optimizers = [\n",
+    "    optim.SGD(models[location.id].parameters(), lr=0.05,)\n",
+    "    for location in model_locations\n",
+    "]\n",
+    "\n",
+    "\n",
+    "#send model segment to each client and server\n",
+    "for location in model_locations:\n",
+    "    models[location.id].send(location)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "code",
    "execution_count": null,
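
To connect the pieces, a minimal sketch of one training step over the segments defined in the notebook cell above. Assumptions: the loader, models, optimizers and data_owners names from the notebook; the two 64-unit client outputs are concatenated into the server segment's 128-unit input; tensors are kept local here (i.e. before the .send(location) calls), whereas the SplitNN notebook linked in the cell moves activation pointers between the Syft workers instead:

    import torch
    import torch.nn.functional as F

    def train_step(batch):
        for opt in optimizers:
            opt.zero_grad()

        activations = []
        for owner in data_owners:
            images, labels, _ = batch[owner.id]  # labels identical across owners
            flat = images.view(images.shape[0], -1)      # (B, 28*14)
            activations.append(models[owner.id](flat))   # (B, 64)

        server_input = torch.cat(activations, dim=1)     # (B, 128)
        log_probs = models["server"](server_input)       # (B, 10) log-probabilities
        loss = F.nll_loss(log_probs, labels)

        loss.backward()
        for opt in optimizers:
            opt.step()
        return loss.item()
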
nn.Linear(hidden_sizes[\"client_2\"][0], hidden_sizes[\"client_2\"][1]),\n", + " nn.ReLU(),\n", + " ),\n", + " \"server\": nn.Sequential(\n", + " nn.Linear(hidden_sizes[\"server\"][0], hidden_sizes[\"server\"][1]),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_sizes[\"server\"][1], 10),\n", + " nn.LogSoftmax(dim=1)\n", + " )\n", + "}\n", + "\n", + "\n", + "\n", + "# Create optimisers for each segment and link to their segment\n", + "optimizers = [\n", + " optim.SGD(models[location.id].parameters(), lr=0.05,)\n", + " for location in model_locations\n", + "]\n", + "\n", + "\n", + "#send model segement to each client and server\n", + "for location in model_locations:\n", + " models[location.id].send(location)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/dualheaded_datautils/utils.py b/examples/dualheaded_datautils/utils.py index 5ed714d..f2d9635 100644 --- a/examples/dualheaded_datautils/utils.py +++ b/examples/dualheaded_datautils/utils.py @@ -67,7 +67,7 @@ def split_data(dataset, worker_list=None, n_workers=2): - label - index """ - dic_single_datasets[worker] = [] + dic_single_datasets[worker.id] = [] """ Loop through the dataset to split the data and labels vertically across workers. @@ -80,11 +80,11 @@ def split_data(dataset, worker_list=None, n_workers=2): i = 0 uuid_idx = uuid4() for worker in worker_list[:-1]: - dic_single_datasets[worker].append(tensor[:, :, height * i : height * (i + 1)]) + dic_single_datasets[worker.id].append(tensor[:, :, height * i : height * (i + 1)]) i += 1 #add the value of the last worker / split - dic_single_datasets[worker_list[-1]].append(tensor[:, :, height * (i) : ]) + dic_single_datasets[worker_list[-1].id].append(tensor[:, :, height * (i) : ]) label_list.append(label) index_list.append(idx) From 96ed8e913fc9f7527398d9918bd6089ab1896b1c Mon Sep 17 00:00:00 2001 From: daler3 Date: Fri, 15 Jan 2021 10:58:36 +0100 Subject: [PATCH 21/21] Addressed Tom's comments --- examples/dualheaded_datautils/Example_.ipynb | 48 +++++++++---------- examples/dualheaded_datautils/datasets.py | 2 +- .../enhancedSplitWorkers.py | 2 +- examples/dualheaded_datautils/utils.py | 2 +- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/examples/dualheaded_datautils/Example_.ipynb b/examples/dualheaded_datautils/Example_.ipynb index b290a0c..a8a0b0a 100644 --- a/examples/dualheaded_datautils/Example_.ipynb +++ b/examples/dualheaded_datautils/Example_.ipynb @@ -110,40 +110,40 @@ " ...,\n", " [-1., -1., -1., ..., -1., -1., -1.],\n", " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.]]]]), tensor([4, 5, 2, 3]), tensor([41319., 27699., 36485., 3835.], dtype=torch.float64)], 'client_2': [tensor([[[[-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)], 'client_2': [tensor([[[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", " ...,\n", - " 
[-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.]]],\n", + " [ 0.1686, 0.9922, 0.5137, ..., -1.0000, -1.0000, -1.0000],\n", + " [ 0.1686, 0.9922, 0.3412, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n", "\n", "\n", - " [[[-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", " ...,\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.]]],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n", "\n", "\n", - " [[[-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", " ...,\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.]]],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n", "\n", "\n", - " [[[-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", " ...,\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.],\n", - " [-1., -1., -1., ..., -1., -1., -1.]]]]), tensor([4, 5, 2, 3]), tensor([41319., 27699., 36485., 3835.], dtype=torch.float64)]}\n" + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)]}\n" ] } ], diff --git a/examples/dualheaded_datautils/datasets.py b/examples/dualheaded_datautils/datasets.py index 9497752..3f27416 100644 --- a/examples/dualheaded_datautils/datasets.py +++ b/examples/dualheaded_datautils/datasets.py @@ -21,7 +21,7 @@ def __init__(self, ids, values, worker_id=None, is_labels=False): self.values = torch.Tensor(values) if is_labels else torch.stack(values) self.worker_id = None - if worker_id != None: + if worker_id: self.send_to_worker(worker_id) def send_to_worker(self, worker): diff --git a/examples/dualheaded_datautils/enhancedSplitWorkers.py b/examples/dualheaded_datautils/enhancedSplitWorkers.py index 9e9c788..5c5909b 100644 --- a/examples/dualheaded_datautils/enhancedSplitWorkers.py +++ b/examples/dualheaded_datautils/enhancedSplitWorkers.py @@ -24,7 +24,7 @@ def __init__(self, worker, dataset, model, level=1): self.dataset = dataset #It can also be 
         self.model = model
 
-        self.level = level if level >= 0 else 0  #it should start from zero, otherwise throw error  #TODO: implement error throwing
+        self.level = max(level, 0)  #level must be non-negative; clamped to zero for now  #TODO: raise an error instead
 
diff --git a/examples/dualheaded_datautils/utils.py b/examples/dualheaded_datautils/utils.py
index f2d9635..afe2cfd 100644
--- a/examples/dualheaded_datautils/utils.py
+++ b/examples/dualheaded_datautils/utils.py
@@ -52,7 +52,7 @@ def split_data(dataset, worker_list=None, n_workers=2):
     the third the index, which is to keep track of the same data point.
     """
-    if worker_list == None:
+    if worker_list is None:
         worker_list = list(range(0, n_workers))
 
     #counter to create the index of different data samples