from celery import current_app as celery
import requests
from io import BytesIO
import shapefile
import zipfile


# (connect, read) timeouts in seconds, so that a stalled ArcGIS server cannot
# block a worker indefinitely.
_REQUEST_TIMEOUT = (10, 300)


def _fetch_json(url, params):
    """GET *url* with query *params* and return the decoded JSON body."""
    response = requests.get(url, params=params, timeout=_REQUEST_TIMEOUT)
    response.raise_for_status()
    return response.json()


def _write_shape(writer, geometry):
    """Append exactly one shape to *writer* for an Esri JSON *geometry*.

    A null shape is written for missing, empty or unrecognised geometries so
    that the number of shapes always matches the number of records.
    """
    geometry = geometry or {}
    if geometry.get('x') is not None and geometry.get('y') is not None:
        writer.point(geometry['x'], geometry['y'])
    elif geometry.get('rings'):
        writer.poly(geometry['rings'])
    elif geometry.get('paths'):
        writer.line(geometry['paths'])
    elif geometry.get('points'):
        writer.multipoint(geometry['points'])
    else:
        writer.null()


@celery.task(bind=True)
def download_layer_task(self, arcgis_url, top_layer_name, layer_name):
    """Export one ArcGIS MapServer layer as a zipped shapefile.

    Args:
        arcgis_url: Base URL of the ArcGIS server; ``/rest/services/...`` is
            appended to it. A trailing slash is tolerated.
        top_layer_name: Name of the map service that contains the layer.
        layer_name: Name of the layer to export (matched case-insensitively).
            Also used as the base name of the files inside the archive.

    Returns:
        The bytes of a ZIP archive holding ``<layer_name>.shp``, ``.shx`` and
        ``.dbf`` on success, or a ``{"error": message}`` dict on any failure.
        No exception is propagated to the caller.
    """
    # NOTE(review): the success value is raw bytes; confirm the configured
    # Celery result serializer can carry binary payloads.
    try:
        service_url = (
            f"{arcgis_url.rstrip('/')}/rest/services/{top_layer_name}/MapServer"
        )

        layers_data = _fetch_json(service_url, {'f': 'json'})
        if 'layers' not in layers_data:
            return {"error": "Invalid response from ArcGIS server"}

        wanted = layer_name.lower()
        layer_id = next(
            (layer['id'] for layer in layers_data['layers']
             if layer['name'].lower() == wanted),
            None,
        )
        if layer_id is None:
            return {"error": "Layer name not found"}

        # NOTE(review): ArcGIS servers cap the number of features returned per
        # query; when the response carries `exceededTransferLimit` the export
        # is truncated. Paging is not implemented here -- confirm whether the
        # target layers can exceed the server limit.
        data = _fetch_json(
            f"{service_url}/{layer_id}/query",
            {
                'where': '1=1',
                'outFields': '*',
                'f': 'json',
                'returnGeometry': 'true',
                'outSR': '4326',
            },
        )
        if 'fields' not in data or 'features' not in data:
            return {"error": "Invalid response from ArcGIS server"}

        field_names = [field['name'] for field in data['fields']]

        shp_io = BytesIO()
        shx_io = BytesIO()
        dbf_io = BytesIO()

        with shapefile.Writer(shp=shp_io, shx=shx_io, dbf=dbf_io) as writer:
            for name in field_names:
                writer.field(name, 'C')

            for feature in data['features']:
                # Attributes absent from a feature (e.g. a geometry-typed
                # field) are written as empty values instead of aborting.
                attributes = feature.get('attributes') or {}
                writer.record(*[attributes.get(name) for name in field_names])
                _write_shape(writer, feature.get('geometry'))

        zip_io = BytesIO()
        with zipfile.ZipFile(zip_io, 'w') as zip_file:
            zip_file.writestr(f"{layer_name}.shp", shp_io.getvalue())
            zip_file.writestr(f"{layer_name}.shx", shx_io.getvalue())
            zip_file.writestr(f"{layer_name}.dbf", dbf_io.getvalue())

        return zip_io.getvalue()
    except Exception as e:
        # Task boundary: every failure (network, malformed payload, shapefile
        # writing) is reported to the caller as an error dict.
        return {"error": str(e)}
