Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions devito/arch/archinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ def supports(self, query, language=None):
warning(f"Couldn't establish if `query={query}` is supported on this "
"system. Assuming it is not.")
return False
elif query == 'async-loads' and cc >= 80:
elif query == 'async-pipe' and cc >= 80:
# Asynchronous pipeline loads -- introduced in Ampere
return True
elif query in ('tma', 'thread-block-cluster') and cc >= 90:
Expand All @@ -1055,7 +1055,7 @@ class Volta(NvidiaDevice):
class Ampere(Volta):

def supports(self, query, language=None):
if query == 'async-loads':
if query == 'async-pipe':
return True
else:
return super().supports(query, language)
Expand Down
10 changes: 10 additions & 0 deletions devito/ir/cgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,16 @@ def _print_Abs(self, expr):
return f"fabs({self._print(arg)})"
return self._print_fmath_func('abs', expr)

def _print_BitwiseNot(self, expr):
# Unary function, single argument
arg = expr.args[0]
return f'~{self._print(arg)}'

def _print_BitwiseXor(self, expr):
# Binary function
arg0, arg1 = expr.args
return f'{self._print(arg0)} ^ {self._print(arg1)}'

def _print_Add(self, expr, order=None):
""""
Print an addition.
Expand Down
6 changes: 4 additions & 2 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def __repr__(self):
if not self.is_Reduction:
return super().__repr__()
elif self.operation is OpInc:
return '%s += %s' % (self.lhs, self.rhs)
return f'Inc({self.lhs}, {self.rhs})'
else:
return '%s = %s(%s)' % (self.lhs, self.operation, self.rhs)
return f'Eq({self.lhs}, {self.operation}({self.rhs}))'

__str__ = __repr__

# Pickling support
__reduce_ex__ = Pickable.__reduce_ex__
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def _gen_value(self, obj, mode=1, masked=()):
qualifiers = [v for k, v in self._qualifiers_mapper.items()
if getattr(obj.function, k, False) and v not in masked]

if obj.is_LocalObject and mode == 2:
qualifiers.extend(as_tuple(obj._C_tag))

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = self.ccode(obj._C_typedata)
strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape)
Expand Down
17 changes: 9 additions & 8 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,17 +775,19 @@ def __init__(self, intervals, sub_iterators=None, directions=None):
super().__init__(intervals)

# Normalize sub-iterators
sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v))))
for k, v in (sub_iterators or {}).items()])
sub_iterators = sub_iterators or {}
sub_iterators = {d: tuple(filter_ordered(as_tuple(v)))
for d, v in sub_iterators.items() if d in self.intervals}
sub_iterators.update({i.dim: () for i in self.intervals
if i.dim not in sub_iterators})
self._sub_iterators = frozendict(sub_iterators)

# Normalize directions
if directions is None:
self._directions = frozendict([(i.dim, Any) for i in self.intervals])
else:
self._directions = frozendict(directions)
directions = directions or {}
directions = {d: v for d, v in directions.items() if d in self.intervals}
directions.update({i.dim: Any for i in self.intervals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth renaming the direction Any to avoid potential squatting on typing.Any?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should rather ask the python developers to revisit their type hinting crazyness 😂

if i.dim not in directions})
self._directions = frozendict(directions)

def __repr__(self):
ret = ', '.join(["%s%s" % (repr(i), repr(self.directions[i.dim]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover non-fstring

Expand All @@ -807,8 +809,7 @@ def __lt__(self, other):
return len(self.itintervals) < len(other.itintervals)

def __hash__(self):
return hash((super().__hash__(), self.sub_iterators,
self.directions))
return hash((super().__hash__(), self.sub_iterators, self.directions))

def __contains__(self, d):
try:
Expand Down
3 changes: 2 additions & 1 deletion devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass
from devito.symbolics import estimate_cost, q_leaf, q_terminal
from devito.symbolics import Reserved, estimate_cost, q_leaf, q_terminal
from devito.symbolics.search import search
from devito.symbolics.manipulation import _uxreplace
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
Expand Down Expand Up @@ -401,6 +401,7 @@ def _(expr):

@_catch.register(Indexed)
@_catch.register(Symbol)
@_catch.register(Reserved)
def _(expr):
"""
Handler for objects preventing CSE to propagate through their arguments.
Expand Down
46 changes: 30 additions & 16 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sympy import S
import numpy as np

from devito.finite_differences import IndexDerivative
from devito.finite_differences import IndexDerivative, Weights
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
from devito.passes.clusters.misc import fuse
from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace
Expand Down Expand Up @@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):


@_core.register(Symbol)
@_core.register(Indexed)
@_core.register(BasicWrapperMixin)
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
return expr, []


@_core.register(Indexed)
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
if not isinstance(expr.function, Weights):
return expr, []

# Lower or reuse a previously lowered Weights array
sregistry = kwargs['sregistry']
subs_user = kwargs['subs']

w0 = expr.function
k = tuple(w0.weights)
try:
w = weights[k]
except KeyError:
name = sregistry.make_name(prefix='w')
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
initvalue = tuple(i.subs(subs_user) for i in k)
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)

rebuilt = expr._subs(w0.indexed, w.indexed)

return rebuilt, []


@_core.register(IndexDerivative)
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
sregistry = kwargs['sregistry']
options = kwargs['options']
subs_user = kwargs['subs']

try:
cbk0 = deriv_schedule_registry[options['deriv-schedule']]
Expand All @@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):

# Create the concrete Weights array, or reuse an already existing one
# if possible
name = sregistry.make_name(prefix='w')
w0 = ideriv.weights.function
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
k = tuple(w0.weights)
try:
w = weights[k]
except KeyError:
initvalue = tuple(i.subs(subs_user) for i in k)
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs)

# Replace the abstract Weights array with the concrete one
subs = {w0.indexed: w.indexed}
subs = {ideriv.weights.base: w.base}
init = uxreplace(init, subs)
ideriv = uxreplace(ideriv, subs)

Expand All @@ -155,13 +169,13 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
ispace1 = IterationSpace.union(ispace, ispace0, relations=extra)

# The Symbol that will hold the result of the IndexDerivative computation
# NOTE: created before recurring so that we ultimately get a sound ordering
# NOTE: created before recursing so that we ultimately get a sound ordering
try:
s = reusables.pop()
assert np.can_cast(s.dtype, dtype)
assert np.can_cast(s.dtype, w.dtype)
except KeyError:
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=dtype)
s = Symbol(name=name, dtype=w.dtype)

# Go inside `expr` and recursively lower any nested IndexDerivatives
expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs)
Expand Down
75 changes: 43 additions & 32 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from collections import OrderedDict
from ctypes import c_uint64
from functools import singledispatch
from operator import itemgetter

import numpy as np
Expand Down Expand Up @@ -97,17 +96,29 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage):
"""
decl = Definition(obj)

if obj._C_init:
definition = (decl, obj._C_init)
init = obj._C_init
if not init:
definition = decl
efuncs = ()
elif isinstance(init, (list, tuple)):
assert len(init) == 2, "Expected (efunc, call)"
init, definition = init
efuncs = (init,)
elif init.is_Callable:
definition = Call(init.name, init.parameters,
retobj=obj if init.retval else None)
efuncs = (init,)
else:
definition = (decl)
definition = (decl, init)
efuncs = ()

frees = obj._C_free

if obj.free_symbols - {obj}:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs = {'objs' if obj.free_symbols - {obj} else 'standalones': definition,
          efuncs': efuncs, 'frees': frees}
storage.update(obj, site, **kwargs)

perhaps?

storage.update(obj, site, objs=definition, frees=frees)
storage.update(obj, site, objs=definition, efuncs=efuncs, frees=frees)
else:
storage.update(obj, site, standalones=definition, frees=frees)
storage.update(obj, site, standalones=definition, efuncs=efuncs,
frees=frees)

def _alloc_array_on_low_lat_mem(self, site, obj, storage):
"""
Expand Down Expand Up @@ -554,7 +565,7 @@ class DeviceAwareDataManager(DataManager):
def __init__(self, options=None, **kwargs):
self.gpu_fit = options['gpu-fit']
self.gpu_create = options['gpu-create']
self.pmode = options.get('place-transfers')
self.gpu_place_transfers = options.get('place-transfers')

super().__init__(**kwargs)

Expand Down Expand Up @@ -587,7 +598,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage):

storage.update(obj, site, maps=mmap, unmaps=unmap)

def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False):
def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm,
read_only=False, **kwargs):
"""
Map a Function already defined in the host memory in to the device high
bandwidth memory.
Expand Down Expand Up @@ -620,42 +632,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs)

@iet_pass
def place_transfers(self, iet, data_movs=None, **kwargs):
def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs):
"""
Create a new IET with host-device data transfers. This requires mapping
symbols to the suitable memory spaces.
"""
if not self.pmode:
if not self.gpu_place_transfers:
return iet, {}

@singledispatch
def _place_transfers(iet, data_movs):
if not isinstance(iet, EntryFunction):
return iet, {}

@_place_transfers.register(EntryFunction)
def _(iet, data_movs):
reads, writes = data_movs
reads, writes = data_movs

# Special symbol which gives user code control over data deallocations
devicerm = DeviceRM()
# Special symbol which gives user code control over data deallocations
devicerm = DeviceRM()

storage = Storage()
for i in filter_sorted(writes):
if i.is_Array:
self._map_array_on_high_bw_mem(iet, i, storage)
else:
self._map_function_on_high_bw_mem(iet, i, storage, devicerm)
for i in filter_sorted(reads - writes):
if i.is_Array:
self._map_array_on_high_bw_mem(iet, i, storage)
else:
self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True)

iet, efuncs = self._inject_definitions(iet, storage)
storage = Storage()
for i in filter_sorted(writes):
if i.is_Array:
self._map_array_on_high_bw_mem(iet, i, storage)
else:
self._map_function_on_high_bw_mem(
iet, i, storage, devicerm, ctx=ctx
)
for i in filter_sorted(reads - writes):
if i.is_Array:
self._map_array_on_high_bw_mem(iet, i, storage)
else:
self._map_function_on_high_bw_mem(
iet, i, storage, devicerm, read_only=True, ctx=ctx
)

return iet, {'efuncs': efuncs}
iet, efuncs = self._inject_definitions(iet, storage)

return _place_transfers(iet, data_movs=data_movs)
return iet, {'efuncs': efuncs}

@iet_pass
def place_devptr(self, iet, **kwargs):
Expand Down
27 changes: 22 additions & 5 deletions devito/passes/iet/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
search)
from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass
from devito.types import (
Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension,
Array, Bundle, ComponentAccess, CompositeObject, IncrDimension, FunctionMap,
ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp,
NPThreads, NThreadsBase, Wildcard
)
Expand Down Expand Up @@ -528,12 +528,19 @@ def _(i, mapper, sregistry):

@abstract_object.register(Array)
def _(i, mapper, sregistry):
if isinstance(i, Lock):
name = sregistry.make_name(prefix='lock')
name = sregistry.make_name(prefix=i._symbol_prefix)

if i.initvalue is not None:
initvalue = []
for v in i.initvalue:
try:
initvalue.append(v.xreplace(mapper))
except AttributeError:
initvalue.append(v)
else:
name = sregistry.make_name(prefix='a')
initvalue = None

v = i._rebuild(name=name, alias=True)
v = i._rebuild(name=name, initvalue=initvalue, alias=True)

mapper.update({
i: v,
Expand Down Expand Up @@ -640,6 +647,16 @@ def _(i, mapper, sregistry):
mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ptr'))


@abstract_object.register(FunctionMap)
def _(i, mapper, sregistry):
name = sregistry.make_name(prefix=i._symbol_prefix)
tensor = mapper.get(i.tensor, i.tensor)

v = i._rebuild(name, tensor)

mapper[i] = v


@abstract_object.register(NPThreads)
def _(i, mapper, sregistry):
mapper[i] = i._rebuild(name=sregistry.make_name(prefix='npthreads'))
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def _make_parallel(self, iet, sync_mapper=None):

return iet, {'includes': [self.langbb['header']]}

def make_parallel(self, graph):
def make_parallel(self, graph, **kwargs):
return self._make_parallel(graph, sync_mapper=graph.sync_mapper)


Expand Down
4 changes: 2 additions & 2 deletions devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from devito.tools.dtypes_lowering import dtype_mapper

__all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex',
'LONG']
'DOUBLE', 'VOID', 'LONG', 'ULONG', 'NoDeclStruct', 'c_complex',
'c_double_complex']


limits_mapper = {
Expand Down
Loading
Loading